diff --git a/src/api/_api_internal/_api_internal.cc b/src/api/_api_internal/_api_internal.cc index ae14f1fefeeb..dc0dac811037 100644 --- a/src/api/_api_internal/_api_internal.cc +++ b/src/api/_api_internal/_api_internal.cc @@ -35,80 +35,72 @@ namespace mxnet { MXNET_REGISTER_GLOBAL("_Integer") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - if (args[0].type_code() == kDLInt) { - *ret = Integer(args[0].operator int64_t()); - } else { - LOG(FATAL) << "only accept int"; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + if (args[0].type_code() == kDLInt) { + *ret = Integer(args[0].operator int64_t()); + } else { + LOG(FATAL) << "only accept int"; + } + }); -MXNET_REGISTER_GLOBAL("_Float") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - if (args[0].type_code() == kDLFloat) { - *ret = Float(args[0].operator double()); - } else { - LOG(FATAL) << "only accept float"; - } +MXNET_REGISTER_GLOBAL("_Float").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + if (args[0].type_code() == kDLFloat) { + *ret = Float(args[0].operator double()); + } else { + LOG(FATAL) << "only accept float"; + } }); -MXNET_REGISTER_GLOBAL("_ADT") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() == kNDArrayHandle) { - mxnet::NDArray *array = args[i].operator mxnet::NDArray*(); - ObjectRef input = NDArrayHandle(array); - data.push_back(input); - } else if (args[i].type_code() != kNull) { - ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String() - : args[i].operator ObjectRef(); - data.push_back(input); - } else { - data.emplace_back(nullptr); - } +MXNET_REGISTER_GLOBAL("_ADT").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() == kNDArrayHandle) { + mxnet::NDArray* array = args[i].operator mxnet::NDArray*(); + ObjectRef input = NDArrayHandle(array); + data.push_back(input); + } else if (args[i].type_code() != kNull) { + ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String() + : args[i].operator ObjectRef(); + data.push_back(input); + } else { + data.emplace_back(nullptr); } - *ret = ADT(0, data.begin(), data.end()); + } + *ret = ADT(0, data.begin(), data.end()); }); -MXNET_REGISTER_GLOBAL("_Map") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - CHECK_EQ(args.size() % 2, 0); - std::unordered_map data; - for (int i = 0; i < args.num_args; i += 2) { - ObjectRef k = - String::CanConvertFrom(args[i]) ? args[i].operator String() - : args[i].operator ObjectRef(); - ObjectRef v; - if (args[i + 1].type_code() == kNDArrayHandle) { - mxnet::NDArray *array = args[i + 1].operator mxnet::NDArray*(); - v = NDArrayHandle(array); - } else { - v = args[i + 1]; - } - data.emplace(std::move(k), std::move(v)); +MXNET_REGISTER_GLOBAL("_Map").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + CHECK_EQ(args.size() % 2, 0); + std::unordered_map data; + for (int i = 0; i < args.num_args; i += 2) { + ObjectRef k = + String::CanConvertFrom(args[i]) ? 
args[i].operator String() : args[i].operator ObjectRef(); + ObjectRef v; + if (args[i + 1].type_code() == kNDArrayHandle) { + mxnet::NDArray* array = args[i + 1].operator mxnet::NDArray*(); + v = NDArrayHandle(array); + } else { + v = args[i + 1]; } - *ret = Map(data); + data.emplace(std::move(k), std::move(v)); + } + *ret = Map(data); }); -MXNET_REGISTER_GLOBAL("_String") -.set_body([] (runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_GLOBAL("_String").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; std::string str = args[0].operator std::string(); - *ret = String(std::move(str)); + *ret = String(std::move(str)); }); -MXNET_REGISTER_GLOBAL("_echo") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_GLOBAL("_echo").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { *ret = args[0]; }); -MXNET_REGISTER_API("_nop") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { -}); +MXNET_REGISTER_API("_nop").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {}); } // namespace mxnet diff --git a/src/api/cached_op_api.cc b/src/api/cached_op_api.cc index 1c325d229da3..79494ea80bcf 100644 --- a/src/api/cached_op_api.cc +++ b/src/api/cached_op_api.cc @@ -30,119 +30,118 @@ namespace mxnet { MXNET_REGISTER_GLOBAL("cached_op.invoke") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - CachedOpPtr op_shared = *static_cast(args[0].value().v_handle); - // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX - // was called with thread_safe=true - CachedOp* op = dynamic_cast(op_shared.get()); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + CachedOpPtr op_shared = *static_cast(args[0].value().v_handle); + // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX + // was called with thread_safe=true + CachedOp* op = dynamic_cast(op_shared.get()); - int num_inputs = args[1]; - int args_size = args.size(); - std::vector ndinputs; - ndinputs.reserve(num_inputs); - for (int i = 2; i < num_inputs + 2; ++i) { - ndinputs.push_back(static_cast(args[i])); - } + int num_inputs = args[1]; + int args_size = args.size(); + std::vector ndinputs; + ndinputs.reserve(num_inputs); + for (int i = 2; i < num_inputs + 2; ++i) { + ndinputs.push_back(static_cast(args[i])); + } - std::vector ndoutputs; - ndoutputs.reserve(op->num_outputs()); - if (args[num_inputs + 4].type_code() == kNull) { - for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray()); - } else { - int array_size = args_size - num_inputs - 4; - CHECK_EQ(array_size, op->num_outputs()) - << "CachedOp expects " << op->num_outputs() << " outputs, but " - << array_size << " was given."; - for (int i = num_inputs + 4; i < array_size; ++i) { - ndoutputs.push_back(args[i].operator mxnet::NDArray*()); - } - } + std::vector ndoutputs; + ndoutputs.reserve(op->num_outputs()); + if (args[num_inputs + 4].type_code() == kNull) { + for (int i = 0; i < op->num_outputs(); ++i) + ndoutputs.push_back(new NDArray()); + } else { + int array_size = args_size - num_inputs - 4; + CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs() + << " outputs, but " << array_size << " was given."; + for (int i = num_inputs + 4; i < array_size; ++i) { + ndoutputs.push_back(args[i].operator mxnet::NDArray*()); + } + } - int default_dev_type; - int default_dev_id; - if (args[num_inputs + 2].type_code() != kNull) { - default_dev_type = 
args[num_inputs + 2]; - default_dev_id = args[num_inputs + 3]; - } else { - const Context &ctx = ndinputs[0]->ctx(); - default_dev_type = ctx.dev_type; - default_dev_id = ctx.dev_id; - } + int default_dev_type; + int default_dev_id; + if (args[num_inputs + 2].type_code() != kNull) { + default_dev_type = args[num_inputs + 2]; + default_dev_id = args[num_inputs + 3]; + } else { + const Context& ctx = ndinputs[0]->ctx(); + default_dev_type = ctx.dev_type; + default_dev_id = ctx.dev_id; + } - // construct default context - Context ctx = Context::Create(static_cast(default_dev_type), - default_dev_id); - op->Forward(op_shared, ndinputs, ndoutputs, ctx); + // construct default context + Context ctx = + Context::Create(static_cast(default_dev_type), default_dev_id); + op->Forward(op_shared, ndinputs, ndoutputs, ctx); - if (op->num_outputs() == 1) { - *ret = ndoutputs[0]; - } else { - std::vector outputs; - outputs.reserve(op->num_outputs()); - for (int i = 0; i < op->num_outputs(); ++i) { - ObjectRef out = NDArrayHandle(ndoutputs[i]); - outputs.push_back(out); - delete ndoutputs[i]; - } - *ret = runtime::ADT(0, outputs.begin(), outputs.end()); - } -}); + if (op->num_outputs() == 1) { + *ret = ndoutputs[0]; + } else { + std::vector outputs; + outputs.reserve(op->num_outputs()); + for (int i = 0; i < op->num_outputs(); ++i) { + ObjectRef out = NDArrayHandle(ndoutputs[i]); + outputs.push_back(out); + delete ndoutputs[i]; + } + *ret = runtime::ADT(0, outputs.begin(), outputs.end()); + } + }); MXNET_REGISTER_GLOBAL("cached_op.create") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - nnvm::Symbol* sym = static_cast(args[0].value().v_handle); - Object* flags_ptr = static_cast(args[1].value().v_handle); - auto* n = static_cast(flags_ptr); - int num_flags = static_cast(n->size()); - bool thread_safe = args[2]; - std::vector > flags; - flags.reserve(num_flags); - for (const auto& kv : *n) { - flags.emplace_back(std::string(runtime::Downcast(kv.first)), - std::string(runtime::Downcast(kv.second))); - } - mxnet::CachedOpPtr *out; - if (!thread_safe) { - out = new CachedOpPtr(new CachedOp(*sym, flags)); - } else { - out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags)); - } - *ret = static_cast(out); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + nnvm::Symbol* sym = static_cast(args[0].value().v_handle); + Object* flags_ptr = static_cast(args[1].value().v_handle); + auto* n = static_cast(flags_ptr); + int num_flags = static_cast(n->size()); + bool thread_safe = args[2]; + std::vector > flags; + flags.reserve(num_flags); + for (const auto& kv : *n) { + flags.emplace_back(std::string(runtime::Downcast(kv.first)), + std::string(runtime::Downcast(kv.second))); + } + mxnet::CachedOpPtr* out; + if (!thread_safe) { + out = new CachedOpPtr(new CachedOp(*sym, flags)); + } else { + out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags)); + } + *ret = static_cast(out); + }); MXNET_REGISTER_GLOBAL("cached_op.free") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - CachedOpPtr* g = static_cast(args[0].value().v_handle); - delete g; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + CachedOpPtr* g = static_cast(args[0].value().v_handle); + delete g; + }); MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - auto s = new nnvm::Symbol(); - CachedOpPtr op = *static_cast(args[0].value().v_handle); - *s = op->GetOptimizedSymbol(); - *ret = 
static_cast(static_cast(s)); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + auto s = new nnvm::Symbol(); + CachedOpPtr op = *static_cast(args[0].value().v_handle); + *s = op->GetOptimizedSymbol(); + *ret = static_cast(static_cast(s)); + }); MXNET_REGISTER_GLOBAL("cached_op.register_op_hook") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - CachedOpHandle handle = static_cast(args[0].value().v_handle); - CachedOpMonitorCallback callback = reinterpret_cast( - reinterpret_cast(args[1].value().v_handle)); - bool monitor_all = args[2]; - CachedOpMonitorCallback callback_temp = nullptr; - std::function clbk; - if (callback) { - callback_temp = callback; - clbk = [callback_temp](const char *name, const char *opr_name, - void *handle) { - callback_temp(name, opr_name, handle); - }; - } else { - clbk = nullptr; - } - CachedOpPtr op = *static_cast(handle); - op->RegisterOpHook(clbk, monitor_all); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + CachedOpHandle handle = static_cast(args[0].value().v_handle); + CachedOpMonitorCallback callback = reinterpret_cast( + reinterpret_cast(args[1].value().v_handle)); + bool monitor_all = args[2]; + CachedOpMonitorCallback callback_temp = nullptr; + std::function clbk; + if (callback) { + callback_temp = callback; + clbk = [callback_temp](const char* name, const char* opr_name, void* handle) { + callback_temp(name, opr_name, handle); + }; + } else { + clbk = nullptr; + } + CachedOpPtr op = *static_cast(handle); + op->RegisterOpHook(clbk, monitor_all); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_det.cc b/src/api/operator/numpy/linalg/np_det.cc index 2a415d1d56b5..62fcf28b4b04 100644 --- a/src/api/operator/numpy/linalg/np_det.cc +++ b/src/api/operator/numpy/linalg/np_det.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,17 +27,16 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.det") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.det").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_det"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; - int num_outputs = 0; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = NDArrayHandle(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = NDArrayHandle(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_eig.cc b/src/api/operator/numpy/linalg/np_eig.cc index 69f92a4762a1..05cfa6c71a9d 100644 --- a/src/api/operator/numpy/linalg/np_eig.cc +++ b/src/api/operator/numpy/linalg/np_eig.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,36 +29,32 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.eig") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.eig").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_eig"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; + attrs.op = op; + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); }); -MXNET_REGISTER_API("_npi.eigh") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.eigh").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_eigh"); nnvm::NodeAttrs attrs; op::EighParam param; - param.UPLO = *((args[1].operator std::string()).c_str()); + param.UPLO = *((args[1].operator std::string()).c_str()); attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_eigvals.cc b/src/api/operator/numpy/linalg/np_eigvals.cc index acde49f87b74..04982ded7d06 100644 --- a/src/api/operator/numpy/linalg/np_eigvals.cc +++ b/src/api/operator/numpy/linalg/np_eigvals.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,33 +29,33 @@ namespace mxnet { MXNET_REGISTER_API("_npi.eigvals") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_eigvals"); - nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_eigvals"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); MXNET_REGISTER_API("_npi.eigvalsh") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_eigvalsh"); - nnvm::NodeAttrs attrs; - op::EigvalshParam param; - param.UPLO = *((args[1].operator std::string()).c_str()); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_eigvalsh"); + nnvm::NodeAttrs attrs; + op::EigvalshParam param; + param.UPLO = *((args[1].operator std::string()).c_str()); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_gesvd.cc b/src/api/operator/numpy/linalg/np_gesvd.cc index a4517849cbaf..5feb5ae8c8d1 100644 --- a/src/api/operator/numpy/linalg/np_gesvd.cc +++ b/src/api/operator/numpy/linalg/np_gesvd.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,21 +27,19 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.svd") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.svd").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npi_svd"); - attrs.op = op; + attrs.op = op; // inputs NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; + int num_inputs = 1; // outputs int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1]), - NDArrayHandle(ndoutputs[2])}); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT( + 0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1]), NDArrayHandle(ndoutputs[2])}); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_inv.cc b/src/api/operator/numpy/linalg/np_inv.cc index 238f666f29bd..bc9853e96616 100644 --- a/src/api/operator/numpy/linalg/np_inv.cc +++ b/src/api/operator/numpy/linalg/np_inv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,17 +27,16 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.inv") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.inv").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_inv"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; - int num_outputs = 0; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_lstsq.cc b/src/api/operator/numpy/linalg/np_lstsq.cc index fbeafbee6054..e2ac7673c38b 100644 --- a/src/api/operator/numpy/linalg/np_lstsq.cc +++ b/src/api/operator/numpy/linalg/np_lstsq.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,8 +28,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.lstsq") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.lstsq").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_lstsq"); nnvm::NodeAttrs attrs; @@ -46,20 +45,21 @@ MXNET_REGISTER_API("_npi.lstsq") } else { param.rcond = args[2].operator double(); } - param.finfoEps32 = args[3].operator double(); - param.finfoEps64 = args[4].operator double(); + param.finfoEps32 = args[3].operator double(); + param.finfoEps64 = args[4].operator double(); param.new_default = args[2].type_code() == kNull ? true : false; - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 2; - int num_outputs = 0; + int num_inputs = 2; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1]), - NDArrayHandle(ndoutputs[2]), - NDArrayHandle(ndoutputs[3])}); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, + {NDArrayHandle(ndoutputs[0]), + NDArrayHandle(ndoutputs[1]), + NDArrayHandle(ndoutputs[2]), + NDArrayHandle(ndoutputs[3])}); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_matrix_rank.cc b/src/api/operator/numpy/linalg/np_matrix_rank.cc index 4bfe66664ef8..5849973c5333 100644 --- a/src/api/operator/numpy/linalg/np_matrix_rank.cc +++ b/src/api/operator/numpy/linalg/np_matrix_rank.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,49 +28,47 @@ namespace mxnet { -inline static void _npi_matrix_rank_none_tol(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_matrix_rank_none_tol(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_matrix_rank_none_tol"); op::MatrixRankNoneTolParam param; nnvm::NodeAttrs attrs; - param.hermitian = args[2].operator bool(); + param.hermitian = args[2].operator bool(); param.finfoEps32 = args[3].operator double(); param.finfoEps64 = args[4].operator double(); - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } -inline static void _npi_matrix_rank(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_matrix_rank(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_matrix_rank"); op::MatrixRankParam param; nnvm::NodeAttrs attrs; param.hermitian = args[2].operator bool(); - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 2; - int num_outputs = 0; + int num_inputs = 2; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } MXNET_REGISTER_API("_npi.matrix_rank") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - if (args[1].type_code() == kNull) { - _npi_matrix_rank_none_tol(args, ret); - } else { - _npi_matrix_rank(args, ret); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + if (args[1].type_code() == kNull) { + _npi_matrix_rank_none_tol(args, ret); + } else { + _npi_matrix_rank(args, ret); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_norm.cc b/src/api/operator/numpy/linalg/np_norm.cc index 1928321ad206..b3a45701fd68 100644 --- a/src/api/operator/numpy/linalg/np_norm.cc +++ b/src/api/operator/numpy/linalg/np_norm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,8 +28,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.norm") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.norm").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npi_norm"); @@ -41,19 +40,19 @@ MXNET_REGISTER_API("_npi.norm") param.axis = mxnet::TShape(args[2].operator ObjectRef()); } param.keepdims = args[3].operator bool(); - param.flag = args[4].operator int(); + param.flag = args[4].operator int(); - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); // inputs NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; + int num_inputs = 1; // outputs int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_pinv.cc b/src/api/operator/numpy/linalg/np_pinv.cc index b14407c7b69f..531d7c0f8d44 100644 --- a/src/api/operator/numpy/linalg/np_pinv.cc +++ b/src/api/operator/numpy/linalg/np_pinv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -34,14 +34,14 @@ inline static void _npi_pinv(runtime::MXNetArgs args, runtime::MXNetRetValue* re op::PinvParam param; nnvm::NodeAttrs attrs; param.hermitian = args[2].operator bool(); - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 2; - int num_outputs = 0; + int num_inputs = 2; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { @@ -49,20 +49,19 @@ inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNe const nnvm::Op* op = Op::Get("_npi_pinv_scalar_rcond"); op::PinvScalarRcondParam param; nnvm::NodeAttrs attrs; - param.rcond = args[1].operator double(); + param.rcond = args[1].operator double(); param.hermitian = args[2].operator bool(); - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = 
Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } -MXNET_REGISTER_API("_npi.pinv") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.pinv").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { _npi_pinv_scalar_rcond(args, ret); } else { diff --git a/src/api/operator/numpy/linalg/np_potrf.cc b/src/api/operator/numpy/linalg/np_potrf.cc index 811ce74f8692..bd11a56d4796 100644 --- a/src/api/operator/numpy/linalg/np_potrf.cc +++ b/src/api/operator/numpy/linalg/np_potrf.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,20 +29,20 @@ namespace mxnet { MXNET_REGISTER_API("_npi.cholesky") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_cholesky"); - nnvm::NodeAttrs attrs; - op::LaCholeskyParam param; - param.lower = args[1].operator bool(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_cholesky"); + nnvm::NodeAttrs attrs; + op::LaCholeskyParam param; + param.lower = args[1].operator bool(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_qr.cc b/src/api/operator/numpy/linalg/np_qr.cc index e9c0ec5d66d3..359b5bca1394 100644 --- a/src/api/operator/numpy/linalg/np_qr.cc +++ b/src/api/operator/numpy/linalg/np_qr.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,18 +27,16 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.qr") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.qr").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_qr"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; + attrs.op = op; + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_slogdet.cc b/src/api/operator/numpy/linalg/np_slogdet.cc index 28c90265cdc7..8b2c36ddb6b0 100644 --- a/src/api/operator/numpy/linalg/np_slogdet.cc +++ b/src/api/operator/numpy/linalg/np_slogdet.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,17 +28,16 @@ namespace mxnet { MXNET_REGISTER_API("_npi.slogdet") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_slogdet"); - nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_slogdet"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_solve.cc b/src/api/operator/numpy/linalg/np_solve.cc index d0d263881701..e3e13b06a3a9 100644 --- a/src/api/operator/numpy/linalg/np_solve.cc +++ b/src/api/operator/numpy/linalg/np_solve.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,17 +27,16 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.solve") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.solve").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_solve"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 2; - int num_outputs = 0; + attrs.op = op; + int num_inputs = 2; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_tensorinv.cc b/src/api/operator/numpy/linalg/np_tensorinv.cc index c3062eee637f..9392f2e8c9bc 100644 --- a/src/api/operator/numpy/linalg/np_tensorinv.cc +++ b/src/api/operator/numpy/linalg/np_tensorinv.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,20 +29,20 @@ namespace mxnet { MXNET_REGISTER_API("_npi.tensorinv") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_tensorinv"); - nnvm::NodeAttrs attrs; - op::TensorinvParam param; - param.ind = args[1].operator int(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensorinv"); + nnvm::NodeAttrs attrs; + op::TensorinvParam param; + param.ind = args[1].operator int(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_tensorsolve.cc b/src/api/operator/numpy/linalg/np_tensorsolve.cc index 5a50c22ea94e..9d1224063ee4 100644 --- a/src/api/operator/numpy/linalg/np_tensorsolve.cc +++ b/src/api/operator/numpy/linalg/np_tensorsolve.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,28 +29,28 @@ namespace mxnet { MXNET_REGISTER_API("_npi.tensorsolve") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_tensorsolve"); - nnvm::NodeAttrs attrs; - op::TensorsolveParam param; - if (args[2].type_code() == kNull) { - param.a_axes = Tuple(); - } else { - if (args[2].type_code() == kDLInt) { - param.a_axes = Tuple(1, args[2].operator int64_t()); - } else { - param.a_axes = Tuple(args[2].operator ObjectRef()); - } - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = 2; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensorsolve"); + nnvm::NodeAttrs attrs; + op::TensorsolveParam param; + if (args[2].type_code() == kNull) { + param.a_axes = Tuple(); + } else { + if (args[2].type_code() == kDLInt) { + param.a_axes = Tuple(1, args[2].operator int64_t()); + } else { + param.a_axes = Tuple(args[2].operator ObjectRef()); + } + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_bincount_op.cc b/src/api/operator/numpy/np_bincount_op.cc index 7be884aefb1a..27495e98182d 100644 --- a/src/api/operator/numpy/np_bincount_op.cc +++ b/src/api/operator/numpy/np_bincount_op.cc @@ -28,34 +28,35 @@ namespace mxnet { MXNET_REGISTER_API("_npi.bincount") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_bincount"); - nnvm::NodeAttrs attrs; - op::NumpyBincountParam param; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_bincount"); + nnvm::NodeAttrs attrs; + op::NumpyBincountParam param; - int num_outputs = 0; - if (args[1].type_code() == kNull) { - param.minlength = args[2].operator int64_t(); - param.has_weights = false; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); - } else { - param.minlength = args[2].operator int64_t(); - param.has_weights = true; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - int num_inputs = 2; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + int num_outputs = 0; + if (args[1].type_code() == kNull) { + param.minlength = 
args[2].operator int64_t(); + param.has_weights = false; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } else { + param.minlength = args[2].operator int64_t(); + param.has_weights = true; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), + args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc index c3e186195dca..f2494f0d5672 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_boolean.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,7 +19,8 @@ /*! * \file np_broadcast_reduce_op_boolean.cc - * \brief Implementation of the API of functions in src/operator/numpy/np_broadcast_reduce_op_boolean.cc + * \brief Implementation of the API of functions in + * src/operator/numpy/np_broadcast_reduce_op_boolean.cc */ #include #include @@ -28,16 +29,15 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.all") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.all").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_all"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesBoolParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -45,11 +45,11 @@ MXNET_REGISTER_API("_npi.all") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { @@ -59,16 +59,15 @@ MXNET_REGISTER_API("_npi.all") } }); -MXNET_REGISTER_API("_npi.any") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.any").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_any"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesBoolParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -76,11 +75,11 @@ MXNET_REGISTER_API("_npi.any") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_index.cc b/src/api/operator/numpy/np_broadcast_reduce_op_index.cc index 83e16999417b..1d46ec037aef 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_index.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_index.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,69 +30,69 @@ namespace mxnet { MXNET_REGISTER_API("_npi.argmax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_argmax"); - nnvm::NodeAttrs attrs; - op::ReduceAxisParam param; - // param.axis - if (args[1].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[1].operator int(); - } - // param.keepdims - param.keepdims = args[2].operator bool(); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_argmax"); + nnvm::NodeAttrs attrs; + op::ReduceAxisParam param; + // param.axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + // param.keepdims + param.keepdims = args[2].operator bool(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - // outputs - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + // outputs + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); MXNET_REGISTER_API("_npi.argmin") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_argmin"); - nnvm::NodeAttrs attrs; - op::ReduceAxisParam param; - // param.axis - if (args[1].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[1].operator int(); - } - // param.keepdims - param.keepdims = args[2].operator bool(); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_argmin"); + nnvm::NodeAttrs attrs; + op::ReduceAxisParam param; + // param.axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + // param.keepdims + param.keepdims = args[2].operator bool(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - // outputs - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + // outputs + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc index 277bf4a65b42..f7238e8b24d2 100644 --- a/src/api/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/api/operator/numpy/np_broadcast_reduce_op_value.cc @@ -30,29 +30,28 @@ namespace mxnet { MXNET_REGISTER_API("_npi.broadcast_to") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_broadcast_to"); - nnvm::NodeAttrs attrs; - op::BroadcastToParam param; - if (args[1].type_code() == kDLInt) { - param.shape = TShape(1, args[1].operator int64_t()); - } else { - param.shape = TShape(args[1].operator ObjectRef()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_broadcast_to"); + nnvm::NodeAttrs attrs; + op::BroadcastToParam param; + if (args[1].type_code() == kDLInt) { + param.shape = TShape(1, args[1].operator int64_t()); + } else { + param.shape = TShape(args[1].operator ObjectRef()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); -MXNET_REGISTER_API("_npi.sum") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.sum").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_sum"); op::NumpyReduceAxesParam param; @@ -96,12 +95,12 @@ MXNET_REGISTER_API("_npi.sum") SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; + int num_inputs = 1; NDArray* outputs[] = {args[5].operator NDArray*()}; - NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; - int num_outputs = (outputs[0] != nullptr); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + NDArray** out = (outputs[0] == nullptr) ? 
nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); if (out) { *ret = PythonArg(5); @@ -110,8 +109,7 @@ MXNET_REGISTER_API("_npi.sum") } }); -MXNET_REGISTER_API("_npi.mean") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.mean").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_mean"); nnvm::NodeAttrs attrs; @@ -135,15 +133,15 @@ MXNET_REGISTER_API("_npi.mean") param.keepdims = args[3].operator bool(); } param.initial = dmlc::optional(); - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - NDArray* out = args[4].operator mxnet::NDArray*(); + NDArray* out = args[4].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { *ret = PythonArg(4); } else { @@ -151,8 +149,7 @@ MXNET_REGISTER_API("_npi.mean") } }); -MXNET_REGISTER_API("_npi.prod") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.prod").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_prod"); nnvm::NodeAttrs attrs; @@ -180,14 +177,14 @@ MXNET_REGISTER_API("_npi.prod") param.initial = args[4].operator double(); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); - int num_inputs = 1; + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray* out = args[5].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { *ret = PythonArg(5); } else { @@ -195,16 +192,15 @@ MXNET_REGISTER_API("_npi.prod") } }); -MXNET_REGISTER_API("_npi.max") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.max").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_max"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesNoDTypeParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -212,11 +208,11 @@ MXNET_REGISTER_API("_npi.max") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { @@ -226,16 +222,15 @@ MXNET_REGISTER_API("_npi.max") } }); -MXNET_REGISTER_API("_npi.min") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.min").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_min"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesNoDTypeParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -243,11 +238,11 @@ MXNET_REGISTER_API("_npi.min") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { @@ -257,16 +252,15 @@ MXNET_REGISTER_API("_npi.min") } }); -MXNET_REGISTER_API("_npi.amax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.amax").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_amax"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesNoDTypeParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -274,11 +268,11 @@ MXNET_REGISTER_API("_npi.amax") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { @@ -288,16 +282,15 @@ MXNET_REGISTER_API("_npi.amax") } }); -MXNET_REGISTER_API("_npi.amin") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.amin").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_amin"); nnvm::NodeAttrs attrs; op::NumpyReduceAxesNoDTypeParam param; - NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray* out = args[3].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; } else if (args[1].type_code() == kDLInt) { @@ -305,11 +298,11 @@ MXNET_REGISTER_API("_npi.amin") } else { param.axis = Tuple(args[1].operator ObjectRef()); } - param.keepdims = args[2].operator bool(); + param.keepdims = args[2].operator bool(); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { diff --git a/src/api/operator/numpy/np_cross.cc b/src/api/operator/numpy/np_cross.cc index 0dd4644cad59..2bf9675148ca 100644 --- a/src/api/operator/numpy/np_cross.cc +++ b/src/api/operator/numpy/np_cross.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,23 +28,22 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.cross") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.cross").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npi_cross"); op::NumpyCrossParam param; - param.axisa = args[2].operator int(); - param.axisb = args[3].operator int(); - param.axisc = args[4].operator int(); - attrs.op = op; + param.axisa = args[2].operator int(); + param.axisb = args[3].operator int(); + param.axisc = args[4].operator int(); + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - int num_inputs = 2; + int num_inputs = 2; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_cumsum.cc b/src/api/operator/numpy/np_cumsum.cc index a0f68cca1b6b..227ac0531e0d 100644 --- a/src/api/operator/numpy/np_cumsum.cc +++ b/src/api/operator/numpy/np_cumsum.cc @@ -29,39 +29,39 @@ namespace mxnet { MXNET_REGISTER_API("_npi.cumsum") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npi_cumsum"); - op::CumsumParam param; - // axis - if (args[1].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[1].operator int(); - } - // dtype - if (args[2].type_code() == kNull) { - param.dtype = dmlc::nullopt; - } else { - param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; - // outputs - NDArray* outputs[] = {args[3].operator NDArray*()}; - NDArray** out = outputs[0] == nullptr ? nullptr : outputs; - int num_outputs = outputs[0] != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); - if (out) { - *ret = PythonArg(3); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_cumsum"); + op::CumsumParam param; + // axis + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + // dtype + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + // outputs + NDArray* outputs[] = {args[3].operator NDArray*()}; + NDArray** out = outputs[0] == nullptr ? 
nullptr : outputs; + int num_outputs = outputs[0] != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(3); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_delete_op.cc b/src/api/operator/numpy/np_delete_op.cc index 76b2a709a270..dd5746994a29 100644 --- a/src/api/operator/numpy/np_delete_op.cc +++ b/src/api/operator/numpy/np_delete_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,73 +30,72 @@ namespace mxnet { MXNET_REGISTER_API("_npi.delete") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_delete"); - nnvm::NodeAttrs attrs; - op::NumpyDeleteParam param; - int num_inputs = 0; - param.start = dmlc::nullopt; - param.step = dmlc::nullopt; - param.stop = dmlc::nullopt; - param.int_ind = dmlc::nullopt; - param.axis = dmlc::nullopt; - if (args.num_args == 3) { - if (args[1].type_code() == kDLInt || - args[1].type_code() == kDLFloat) { - if (args[1].type_code() == kDLInt) { - param.int_ind = args[1].operator int64_t(); - } else if (args[1].type_code() == kDLFloat) { - param.int_ind = static_cast(args[1].operator double()); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_delete"); + nnvm::NodeAttrs attrs; + op::NumpyDeleteParam param; + int num_inputs = 0; + param.start = dmlc::nullopt; + param.step = dmlc::nullopt; + param.stop = dmlc::nullopt; + param.int_ind = dmlc::nullopt; + param.axis = dmlc::nullopt; + if (args.num_args == 3) { + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLFloat) { + if (args[1].type_code() == kDLInt) { + param.int_ind = args[1].operator int64_t(); + } else if (args[1].type_code() == kDLFloat) { + param.int_ind = static_cast(args[1].operator double()); + } + if (args[2].type_code() == kDLInt) { + param.axis = args[2].operator int(); + } else if (args[2].type_code() == kDLFloat) { + param.axis = static_cast(args[2].operator double()); + } + num_inputs = 1; + } else { + if (args[2].type_code() == kDLInt) { + param.axis = args[2].operator int(); + } else if (args[2].type_code() == kDLFloat) { + param.axis = static_cast(args[2].operator double()); + } + num_inputs = 2; + } + } else { + num_inputs = 1; + if (args[1].type_code() == kDLInt) { + param.start = args[1].operator int64_t(); + } else if (args[1].type_code() == kDLFloat) { + param.start = static_cast(args[1].operator double()); + } + if (args[2].type_code() == kDLInt) { + param.stop = args[2].operator int64_t(); + } else if (args[2].type_code() == kDLFloat) { + param.stop = static_cast(args[2].operator double()); + } + if (args[3].type_code() == kDLInt) { + param.step = args[3].operator int64_t(); + } else if (args[3].type_code() == kDLFloat) { + param.step = static_cast(args[3].operator double()); + } + if (args[4].type_code() == kDLInt) { + param.axis = args[4].operator int(); + } else if (args[4].type_code() == kDLFloat) { + param.axis = static_cast(args[4].operator double()); + } } - if 
(args[2].type_code() == kDLInt) { - param.axis = args[2].operator int(); - } else if (args[2].type_code() == kDLFloat) { - param.axis = static_cast(args[2].operator double()); + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); } - num_inputs = 1; - } else { - if (args[2].type_code() == kDLInt) { - param.axis = args[2].operator int(); - } else if (args[2].type_code() == kDLFloat) { - param.axis = static_cast(args[2].operator double()); - } - num_inputs = 2; - } - } else { - num_inputs = 1; - if (args[1].type_code() == kDLInt) { - param.start = args[1].operator int64_t(); - } else if (args[1].type_code() == kDLFloat) { - param.start = static_cast(args[1].operator double()); - } - if (args[2].type_code() == kDLInt) { - param.stop = args[2].operator int64_t(); - } else if (args[2].type_code() == kDLFloat) { - param.stop = static_cast(args[2].operator double()); - } - if (args[3].type_code() == kDLInt) { - param.step = args[3].operator int64_t(); - } else if (args[3].type_code() == kDLFloat) { - param.step = static_cast(args[3].operator double()); - } - if (args[4].type_code() == kDLInt) { - param.axis = args[4].operator int(); - } else if (args[4].type_code() == kDLFloat) { - param.axis = static_cast(args[4].operator double()); - } - } - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_diff_op.cc b/src/api/operator/numpy/np_diff_op.cc index 7be5b804eade..a89063b93eb2 100644 --- a/src/api/operator/numpy/np_diff_op.cc +++ b/src/api/operator/numpy/np_diff_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,25 +27,24 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.diff") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.diff").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_diff"); nnvm::NodeAttrs attrs; op::DiffParam param; - param.n = args[1].operator int(); + param.n = args[1].operator int(); param.axis = args[2].operator int(); // we directly copy DiffParam, which is trivially-copyable attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); - int num_outputs = 0; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_dot_op.cc b/src/api/operator/numpy/np_dot_op.cc index 66a0f1116052..1ce67c40c5d5 100644 --- a/src/api/operator/numpy/np_dot_op.cc +++ b/src/api/operator/numpy/np_dot_op.cc @@ -27,19 +27,17 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.dot") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.dot").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_dot"); nnvm::NodeAttrs attrs; - attrs.op = op; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), - args[1].operator mxnet::NDArray*()}; - NDArray* out = args[2].operator NDArray*(); + attrs.op = op; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + NDArray* out = args[2].operator NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_inputs = 2; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_inputs = 2; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { diff --git a/src/api/operator/numpy/np_ediff1d_op.cc b/src/api/operator/numpy/np_ediff1d_op.cc index 64e15064889a..ee88eac54908 100644 --- a/src/api/operator/numpy/np_ediff1d_op.cc +++ b/src/api/operator/numpy/np_ediff1d_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,48 +28,48 @@ namespace mxnet { MXNET_REGISTER_API("_npi.ediff1d") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_ediff1d"); - nnvm::NodeAttrs attrs; - op::EDiff1DParam param; - int num_inputs = 1; - NDArray* inputs[3]; - inputs[0] = args[0].operator mxnet::NDArray*(); - // the order of `to_end` and `to_begin` array in the backend is different from the front-end - if (args[2].type_code() == kDLFloat || args[2].type_code() == kDLInt) { - param.to_begin_scalar = args[2].operator double(); - param.to_begin_arr_given = false; - } else if (args[2].type_code() == kNull) { - param.to_begin_scalar = dmlc::nullopt; - param.to_begin_arr_given = false; - } else { - param.to_begin_scalar = dmlc::nullopt; - param.to_begin_arr_given = true; - inputs[num_inputs] = args[2].operator mxnet::NDArray*(); - num_inputs++; - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_ediff1d"); + nnvm::NodeAttrs attrs; + op::EDiff1DParam param; + int num_inputs = 1; + NDArray* inputs[3]; + inputs[0] = args[0].operator mxnet::NDArray*(); + // the order of `to_end` and `to_begin` array in the backend is different from the front-end + if (args[2].type_code() == kDLFloat || args[2].type_code() == kDLInt) { + param.to_begin_scalar = args[2].operator double(); + param.to_begin_arr_given = false; + } else if (args[2].type_code() == kNull) { + param.to_begin_scalar = dmlc::nullopt; + param.to_begin_arr_given = false; + } else { + param.to_begin_scalar = dmlc::nullopt; + param.to_begin_arr_given = true; + inputs[num_inputs] = args[2].operator mxnet::NDArray*(); + num_inputs++; + } - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { - param.to_end_scalar = args[1].operator double(); - param.to_end_arr_given = false; - } else if (args[1].type_code() == kNull) { - param.to_end_scalar = dmlc::nullopt; - param.to_end_arr_given = false; - } else { - param.to_end_scalar = dmlc::nullopt; - param.to_end_arr_given = true; - inputs[num_inputs] = args[1].operator mxnet::NDArray*(); - num_inputs++; - } + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + param.to_end_scalar = args[1].operator double(); + param.to_end_arr_given = false; + } else if (args[1].type_code() == kNull) { + param.to_end_scalar = dmlc::nullopt; + param.to_end_arr_given = false; + } else { + param.to_end_scalar = dmlc::nullopt; + param.to_end_arr_given = true; + inputs[num_inputs] = args[1].operator mxnet::NDArray*(); + num_inputs++; + } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_einsum_op.cc b/src/api/operator/numpy/np_einsum_op.cc index 900739ac10ab..8c96297a4433 100644 --- a/src/api/operator/numpy/np_einsum_op.cc +++ b/src/api/operator/numpy/np_einsum_op.cc @@ -6,9 
+6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,42 +30,42 @@ namespace mxnet { MXNET_REGISTER_API("_npi.einsum") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_einsum"); - nnvm::NodeAttrs attrs; - op::NumpyEinsumParam param; - int args_size = args.size(); - // param.num_args - param.num_args = args_size - 3; - // param.subscripts - param.subscripts = args[args_size - 3].operator std::string(); - // param.optimize - param.optimize = args[args_size - 1].operator int(); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_einsum"); + nnvm::NodeAttrs attrs; + op::NumpyEinsumParam param; + int args_size = args.size(); + // param.num_args + param.num_args = args_size - 3; + // param.subscripts + param.subscripts = args[args_size - 3].operator std::string(); + // param.optimize + param.optimize = args[args_size - 1].operator int(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - // inputs - int num_inputs = param.num_args; - std::vector inputs_vec(num_inputs, nullptr); - for (int i = 0; i < num_inputs; ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); + // inputs + int num_inputs = param.num_args; + std::vector inputs_vec(num_inputs, nullptr); + for (int i = 0; i < num_inputs; ++i) { + inputs_vec[i] = args[i].operator mxnet::NDArray*(); + } + NDArray** inputs = inputs_vec.data(); - // outputs - NDArray* out = args[args_size - 2].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; + // outputs + NDArray* out = args[args_size - 2].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(args_size - 2); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(args_size - 2); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc index 224843358526..20e0b83b750f 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,7 +19,8 @@ /*! 
* \file np_elemwise_broadcast_logic_op.cc - * \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_logic_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy/np_elemwise_broadcast_logic_op.cc */ #include #include @@ -28,27 +29,27 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.equal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.equal").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_equal"); + const nnvm::Op* op = Op::Get("_npi_equal"); const nnvm::Op* op_scalar = Op::Get("_npi_equal_scalar"); UFuncHelper(args, ret, op, op_scalar, nullptr); }); MXNET_REGISTER_API("_npi.not_equal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_not_equal"); - const nnvm::Op* op_scalar = Op::Get("_npi_not_equal_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_not_equal"); + const nnvm::Op* op_scalar = Op::Get("_npi_not_equal_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); -void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, - const nnvm::Op* op, const nnvm::Op* op_scalar, - const nnvm::Op* op_rscalar) { - if (args[0].type_code() == kNDArrayHandle && - args[1].type_code() == kNDArrayHandle) { +void SetUFuncHelper(runtime::MXNetArgs args, + runtime::MXNetRetValue* ret, + const nnvm::Op* op, + const nnvm::Op* op_scalar, + const nnvm::Op* op_rscalar) { + if (args[0].type_code() == kNDArrayHandle && args[1].type_code() == kNDArrayHandle) { UFuncHelper(args, ret, op, nullptr, nullptr); } else if (args[0].type_code() == kNDArrayHandle) { UFuncHelper(args, ret, nullptr, op_scalar, nullptr); @@ -58,39 +59,38 @@ void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, } MXNET_REGISTER_API("_npi.greater") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_greater"); - const nnvm::Op* op_scalar = Op::Get("_npi_greater_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar"); - SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_greater"); + const nnvm::Op* op_scalar = Op::Get("_npi_greater_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); -MXNET_REGISTER_API("_npi.less") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.less").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_less"); - const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar"); + const nnvm::Op* op = Op::Get("_npi_less"); + const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar"); const nnvm::Op* op_rscalar = Op::Get("_npi_greater_scalar"); SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); }); MXNET_REGISTER_API("_npi.greater_equal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_greater_equal"); - const nnvm::Op* op_scalar = 
Op::Get("_npi_greater_equal_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar"); - SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_greater_equal"); + const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); MXNET_REGISTER_API("_npi.less_equal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_less_equal"); - const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar"); - SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_less_equal"); + const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc index a411b067f1c0..184a4e241eff 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,153 +28,146 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.add") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.add").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_add"); + const nnvm::Op* op = Op::Get("_npi_add"); const nnvm::Op* op_scalar = Op::Get("_npi_add_scalar"); UFuncHelper(args, ret, op, op_scalar, nullptr); }); MXNET_REGISTER_API("_npi.subtract") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_subtract"); - const nnvm::Op* op_scalar = Op::Get("_npi_subtract_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_rsubtract_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_subtract"); + const nnvm::Op* op_scalar = Op::Get("_npi_subtract_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rsubtract_scalar"); + UFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); MXNET_REGISTER_API("_npi.multiply") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_multiply"); - const nnvm::Op* op_scalar = Op::Get("_npi_multiply_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = 
Op::Get("_npi_multiply"); + const nnvm::Op* op_scalar = Op::Get("_npi_multiply_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.true_divide") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_true_divide"); - const nnvm::Op* op_scalar = Op::Get("_npi_true_divide_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_rtrue_divide_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); - -MXNET_REGISTER_API("_npi.mod") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_mod"); - const nnvm::Op* op_scalar = Op::Get("_npi_mod_scalar"); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_true_divide"); + const nnvm::Op* op_scalar = Op::Get("_npi_true_divide_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rtrue_divide_scalar"); + UFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); + +MXNET_REGISTER_API("_npi.mod").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_mod"); + const nnvm::Op* op_scalar = Op::Get("_npi_mod_scalar"); const nnvm::Op* op_rscalar = Op::Get("_npi_rmod_scalar"); UFuncHelper(args, ret, op, op_scalar, op_rscalar); }); -MXNET_REGISTER_API("_npi.power") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.power").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_power"); - const nnvm::Op* op_scalar = Op::Get("_npi_power_scalar"); + const nnvm::Op* op = Op::Get("_npi_power"); + const nnvm::Op* op_scalar = Op::Get("_npi_power_scalar"); const nnvm::Op* op_rscalar = Op::Get("_npi_rpower_scalar"); UFuncHelper(args, ret, op, op_scalar, op_rscalar); }); -MXNET_REGISTER_API("_npi.lcm") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.lcm").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_lcm"); + const nnvm::Op* op = Op::Get("_npi_lcm"); const nnvm::Op* op_scalar = Op::Get("_npi_lcm_scalar"); UFuncHelper(args, ret, op, op_scalar, nullptr); }); -MXNET_REGISTER_API("_npi.gcd") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.gcd").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_gcd"); + const nnvm::Op* op = Op::Get("_npi_gcd"); const nnvm::Op* op_scalar = Op::Get("_npi_gcd_scalar"); UFuncHelper(args, ret, op, op_scalar, nullptr); }); MXNET_REGISTER_API("_npi.logical_and") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_logical_and"); - const nnvm::Op* op_scalar = Op::Get("_npi_logical_and_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_logical_and"); + const nnvm::Op* op_scalar = Op::Get("_npi_logical_and_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.logical_or") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using 
namespace runtime; - const nnvm::Op* op = Op::Get("_npi_logical_or"); - const nnvm::Op* op_scalar = Op::Get("_npi_logical_or_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_logical_or"); + const nnvm::Op* op_scalar = Op::Get("_npi_logical_or_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.logical_xor") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_logical_xor"); - const nnvm::Op* op_scalar = Op::Get("_npi_logical_xor_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_logical_xor"); + const nnvm::Op* op_scalar = Op::Get("_npi_logical_xor_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.bitwise_or") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_bitwise_or"); - const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_or_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_bitwise_or"); + const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_or_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.bitwise_xor") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_bitwise_xor"); - const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_xor_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_bitwise_xor"); + const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_xor_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.bitwise_and") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_bitwise_and"); - const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_and_scalar"); - UFuncHelper(args, ret, op, op_scalar, nullptr); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_bitwise_and"); + const nnvm::Op* op_scalar = Op::Get("_npi_bitwise_and_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); + }); MXNET_REGISTER_API("_npi.copysign") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_copysign"); - const nnvm::Op* op_scalar = Op::Get("_npi_copysign_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_rcopysign_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_copysign"); + const nnvm::Op* op_scalar = Op::Get("_npi_copysign_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_rcopysign_scalar"); + UFuncHelper(args, ret, op, op_scalar, op_rscalar); + }); MXNET_REGISTER_API("_npi.arctan2") -.set_body([](runtime::MXNetArgs args, 
runtime::MXNetRetValue* ret) {
- using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_arctan2");
- const nnvm::Op* op_scalar = Op::Get("_npi_arctan2_scalar");
- const nnvm::Op* op_rscalar = Op::Get("_npi_rarctan2_scalar");
- UFuncHelper(args, ret, op, op_scalar, op_rscalar);
-});
+ .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+ using namespace runtime;
+ const nnvm::Op* op = Op::Get("_npi_arctan2");
+ const nnvm::Op* op_scalar = Op::Get("_npi_arctan2_scalar");
+ const nnvm::Op* op_rscalar = Op::Get("_npi_rarctan2_scalar");
+ UFuncHelper(args, ret, op, op_scalar, op_rscalar);
+ });
-MXNET_REGISTER_API("_npi.hypot")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.hypot").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_hypot");
+ const nnvm::Op* op = Op::Get("_npi_hypot");
const nnvm::Op* op_scalar = Op::Get("_npi_hypot_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});
-MXNET_REGISTER_API("_npi.ldexp")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.ldexp").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_ldexp");
- const nnvm::Op* op_scalar = Op::Get("_npi_ldexp_scalar");
+ const nnvm::Op* op = Op::Get("_npi_ldexp");
+ const nnvm::Op* op_scalar = Op::Get("_npi_ldexp_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rldexp_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc b/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
index 248af4dd6e3e..1367e0d28830 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,7 +19,8 @@
/*!
* \file np_elemwise_broadcast_op_extended_sec.cc
- * \brief Implementation of the API of functions in src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
+ * \brief Implementation of the API of functions in
+ * src/operator/numpy/np_elemwise_broadcast_op_extended_sec.cc
*/
#include
#include
@@ -28,27 +29,24 @@ namespace mxnet {
-MXNET_REGISTER_API("_npi.fmax")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.fmax").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_fmax");
+ const nnvm::Op* op = Op::Get("_npi_fmax");
const nnvm::Op* op_scalar = Op::Get("_npi_fmax_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});
-MXNET_REGISTER_API("_npi.fmin")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.fmin").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_fmin");
+ const nnvm::Op* op = Op::Get("_npi_fmin");
const nnvm::Op* op_scalar = Op::Get("_npi_fmin_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});
-MXNET_REGISTER_API("_npi.fmod")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.fmod").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
- const nnvm::Op* op = Op::Get("_npi_fmod");
- const nnvm::Op* op_scalar = Op::Get("_npi_fmod_scalar");
+ const nnvm::Op* op = Op::Get("_npi_fmod");
+ const nnvm::Op* op_scalar = Op::Get("_npi_fmod_scalar");
const nnvm::Op* op_rscalar = Op::Get("_npi_rfmod_scalar");
UFuncHelper(args, ret, op, op_scalar, op_rscalar);
});
diff --git a/src/api/operator/numpy/np_elemwise_unary_op_basic.cc b/src/api/operator/numpy/np_elemwise_unary_op_basic.cc
index 7d2b4bb66eb0..be5afcfed2c0 100644
--- a/src/api/operator/numpy/np_elemwise_unary_op_basic.cc
+++ b/src/api/operator/numpy/np_elemwise_unary_op_basic.cc
@@ -29,12 +29,12 @@ namespace mxnet {
-#define MXNET_REGISTER_UNARY_API(op_name) \
-MXNET_REGISTER_API("_npi." #op_name) \
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { \
- const nnvm::Op* op = Op::Get("_npi_" #op_name); \
- UFuncHelper(args, ret, op); \
-})
+#define MXNET_REGISTER_UNARY_API(op_name) \
+ MXNET_REGISTER_API("_npi."
#op_name) \ + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { \ + const nnvm::Op* op = Op::Get("_npi_" #op_name); \ + UFuncHelper(args, ret, op); \ + }) MXNET_REGISTER_UNARY_API(negative); MXNET_REGISTER_UNARY_API(reciprocal); @@ -72,18 +72,18 @@ MXNET_REGISTER_UNARY_API(radians); #if MXNET_USE_TVM_OP MXNET_REGISTER_UNARY_API(rad2deg); // from src/operator/contrib/tvmop/ufunc.cc MXNET_REGISTER_UNARY_API(deg2rad); // from src/operator/contrib/tvmop/ufunc.cc -#else // MXNET_USE_TVM_OP +#else // MXNET_USE_TVM_OP MXNET_REGISTER_API("_npi.rad2deg") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - const nnvm::Op* op = Op::Get("_npi_degrees"); - UFuncHelper(args, ret, op); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + const nnvm::Op* op = Op::Get("_npi_degrees"); + UFuncHelper(args, ret, op); + }); MXNET_REGISTER_API("_npi.deg2rad") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - const nnvm::Op* op = Op::Get("_npi_radians"); - UFuncHelper(args, ret, op); -}); -#endif // MXNET_USE_TVM_OP + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + const nnvm::Op* op = Op::Get("_npi_radians"); + UFuncHelper(args, ret, op); + }); +#endif // MXNET_USE_TVM_OP MXNET_REGISTER_UNARY_API(sinh); MXNET_REGISTER_UNARY_API(cosh); MXNET_REGISTER_UNARY_API(tanh); @@ -92,39 +92,38 @@ MXNET_REGISTER_UNARY_API(arccosh); MXNET_REGISTER_UNARY_API(arctanh); MXNET_REGISTER_API("_npi.around") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_around"); - nnvm::NodeAttrs attrs; - op::AroundParam param; - param.decimals = args[1].operator int64_t(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - NDArray* out = args[2].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(2); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_around"); + nnvm::NodeAttrs attrs; + op::AroundParam param; + param.decimals = args[1].operator int64_t(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + NDArray* out = args[2].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(2); + } else { + *ret = ndoutputs[0]; + } + }); -MXNET_REGISTER_API("_npi.copy") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.copy").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_copy"); nnvm::NodeAttrs attrs; - attrs.op = op; + attrs.op = op; NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_fill_diagonal_op.cc b/src/api/operator/numpy/np_fill_diagonal_op.cc index f087e7d2e608..089d7cd95903 100644 --- a/src/api/operator/numpy/np_fill_diagonal_op.cc +++ b/src/api/operator/numpy/np_fill_diagonal_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,37 +27,37 @@ namespace mxnet { MXNET_REGISTER_API("_npi.fill_diagonal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_fill_diagonal"); - nnvm::NodeAttrs attrs; - - op::NumpyFillDiagonalParam param; - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - - if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt - || args[1].type_code() == kDLFloat || args[1].type_code() == kDLBfloat) { - param.val = Tuple(1, args[1].operator double()); - } else { - param.val = Obj2Tuple(args[1].operator ObjectRef()); - } - param.wrap = args[2].operator bool(); - - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - // set the number of outputs provided by the `out` arugment - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_fill_diagonal"); + nnvm::NodeAttrs attrs; + + op::NumpyFillDiagonalParam param; + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt || + args[1].type_code() == kDLFloat || args[1].type_code() == kDLBfloat) { + param.val = Tuple(1, args[1].operator double()); + } else { + param.val = Obj2Tuple(args[1].operator ObjectRef()); + } + param.wrap = args[2].operator bool(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + // set the number of outputs provided by the `out` arugment + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_histogram_op.cc b/src/api/operator/numpy/np_histogram_op.cc index fa911268e39b..daeb3c730ca6 100644 --- a/src/api/operator/numpy/np_histogram_op.cc +++ b/src/api/operator/numpy/np_histogram_op.cc @@ -30,52 +30,50 @@ namespace mxnet { MXNET_REGISTER_API("_npi.histogram") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npi_histogram"); - op::HistogramParam param; - // parse bin_cnt - if (args[2].type_code() == kNull) { - param.bin_cnt = dmlc::nullopt; - } else { - param.bin_cnt = args[2].operator int(); - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npi_histogram"); + op::HistogramParam param; + // parse bin_cnt + if (args[2].type_code() == kNull) { + param.bin_cnt = dmlc::nullopt; + } else { + param.bin_cnt = args[2].operator int(); + } - // parse range - if (args[3].type_code() == kNull) { - param.range = dmlc::nullopt; - } else { - param.range = Obj2Tuple(args[3].operator ObjectRef()); - } + // parse range + if (args[3].type_code() == kNull) { + param.range = dmlc::nullopt; + } else { + param.range = Obj2Tuple(args[3].operator ObjectRef()); + } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - std::vector inputs_vec; - int num_inputs = 0; + std::vector inputs_vec; + int num_inputs = 0; - if (args[2].type_code() != kNull) { - CHECK_EQ(args[1].type_code(), kNull) - << "bins should be None when bin_cnt is provided"; - inputs_vec.push_back((args[0].operator NDArray*())); - num_inputs = 1; - } else { - CHECK_NE(args[1].type_code(), kNull) - << "bins should not be None when bin_cnt is not provided"; - // inputs - inputs_vec.push_back((args[0].operator NDArray*())); - inputs_vec.push_back((args[1].operator NDArray*())); - num_inputs = 2; - } + if (args[2].type_code() != kNull) { + CHECK_EQ(args[1].type_code(), kNull) << "bins should be None when bin_cnt is provided"; + inputs_vec.push_back((args[0].operator NDArray*())); + num_inputs = 1; + } else { + CHECK_NE(args[1].type_code(), kNull) + << "bins should not be None when bin_cnt is not provided"; + // inputs + inputs_vec.push_back((args[0].operator NDArray*())); + inputs_vec.push_back((args[1].operator NDArray*())); + num_inputs = 2; + } - // outputs - NDArray** out = nullptr; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs_vec.data(), &num_outputs, out); - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); -}); + // outputs + NDArray** out = nullptr; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs_vec.data(), &num_outputs, out); + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_init_op.cc b/src/api/operator/numpy/np_init_op.cc index 0c617725fa47..46c41a142d53 100644 --- a/src/api/operator/numpy/np_init_op.cc +++ b/src/api/operator/numpy/np_init_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may 
not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,8 +31,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.zeros") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.zeros").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_zeros"); nnvm::NodeAttrs attrs; @@ -48,194 +47,192 @@ MXNET_REGISTER_API("_npi.zeros") param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); if (args[2].type_code() != kNull) { attrs.dict["ctx"] = args[2].operator std::string(); } int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); MXNET_REGISTER_API("_npi.full_like") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_full_like"); - nnvm::NodeAttrs attrs; - op::FullLikeOpParam param; - param.fill_value = args[1].operator double(); - if (args[2].type_code() == kNull) { - param.dtype = dmlc::nullopt; - } else { - param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - if (args[3].type_code() != kNull) { - attrs.dict["ctx"] = args[3].operator std::string(); - } - SetAttrDict(&attrs); - NDArray* out = args[4].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(4); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_full_like"); + nnvm::NodeAttrs attrs; + op::FullLikeOpParam param; + param.fill_value = args[1].operator double(); + if (args[2].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[2].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[3].type_code() != kNull) { + attrs.dict["ctx"] = args[3].operator std::string(); + } + SetAttrDict(&attrs); + NDArray* out = args[4].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(4); + } else { + *ret = ndoutputs[0]; + } + }); MXNET_REGISTER_API("_npi.indices") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_indices"); - nnvm::NodeAttrs attrs; - op::IndicesOpParam param; - // param.dimensions - if (args[0].type_code() == kDLInt) { - param.dimensions = TShape(1, args[0].operator int64_t()); - } else { - param.dimensions = TShape(args[0].operator ObjectRef()); - } - // param.dtype - if (args[1].type_code() == kNull) { - param.dtype = -1; - } else { - param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // param.ctx - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - int num_inputs = 0; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_indices"); + nnvm::NodeAttrs attrs; + op::IndicesOpParam param; + // param.dimensions + if (args[0].type_code() == kDLInt) { + param.dimensions = TShape(1, args[0].operator int64_t()); + } else { + param.dimensions = TShape(args[0].operator ObjectRef()); + } + // param.dtype + if (args[1].type_code() == kNull) { + param.dtype = -1; + } else { + param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // param.ctx + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + int num_inputs = 0; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.atleast_1d") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_atleast_1d"); - nnvm::NodeAttrs attrs; - op::AtleastNDParam param; - int args_size = args.size(); - param.num_args = args_size; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = args_size; - std::vector inputs_vec(args_size, nullptr); - for (int i = 0; i < args_size; ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_atleast_1d"); + nnvm::NodeAttrs attrs; + op::AtleastNDParam param; + int args_size = args.size(); + param.num_args = args_size; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = args_size; + std::vector inputs_vec(args_size, nullptr); + for (int i = 0; i < args_size; ++i) { + inputs_vec[i] = args[i].operator mxnet::NDArray*(); + } + NDArray** inputs = 
inputs_vec.data(); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector<NDArrayHandle> ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.atleast_2d") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_atleast_2d"); - nnvm::NodeAttrs attrs; - op::AtleastNDParam param; - int args_size = args.size(); - param.num_args = args_size; - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::AtleastNDParam>(&attrs); - int num_inputs = args_size; - std::vector<NDArray*> inputs_vec(args_size, nullptr); - for (int i = 0; i < args_size; ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector<NDArrayHandle> ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_atleast_2d"); + nnvm::NodeAttrs attrs; + op::AtleastNDParam param; + int args_size = args.size(); + param.num_args = args_size; + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::AtleastNDParam>(&attrs); + int num_inputs = args_size; + std::vector<NDArray*> inputs_vec(args_size, nullptr); + for (int i = 0; i < args_size; ++i) { + inputs_vec[i] = args[i].operator mxnet::NDArray*(); + } + NDArray** inputs = inputs_vec.data(); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector<NDArrayHandle> ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.atleast_3d") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_atleast_3d"); - nnvm::NodeAttrs attrs; - op::AtleastNDParam param; - int args_size = args.size(); - param.num_args = args_size; - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::AtleastNDParam>(&attrs); - int num_inputs = args_size; - std::vector<NDArray*> inputs_vec(args_size, nullptr); - for (int i = 0; i < args_size; ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector<NDArrayHandle> ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_atleast_3d"); + nnvm::NodeAttrs attrs; + op::AtleastNDParam param; + int args_size = args.size(); + param.num_args = args_size; + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::AtleastNDParam>(&attrs); + int num_inputs = args_size; + std::vector<NDArray*> inputs_vec(args_size, nullptr); + for (int i = 0; i < args_size; ++i) { + inputs_vec[i] = args[i].operator 
mxnet::NDArray*(); + } + NDArray** inputs = inputs_vec.data(); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.arange") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_arange"); - nnvm::NodeAttrs attrs; - op::RangeParam param; - param.start = args[0].operator double(); - if (args[1].type_code() == kNull) { - param.stop = dmlc::nullopt; - } else { - param.stop = args[1].operator double(); - } - param.step = args[2].operator double(); - param.repeat = 1; - param.infer_range = false; - if (args[3].type_code() == kNull) { - param.dtype = Imperative::Get()->is_np_default_dtype() ? - mshadow::kInt64 : - mshadow::kFloat32; - } else { - param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_arange"); + nnvm::NodeAttrs attrs; + op::RangeParam param; + param.start = args[0].operator double(); + if (args[1].type_code() == kNull) { + param.stop = dmlc::nullopt; + } else { + param.stop = args[1].operator double(); + } + param.step = args[2].operator double(); + param.repeat = 1; + param.infer_range = false; + if (args[3].type_code() == kNull) { + param.dtype = + Imperative::Get()->is_np_default_dtype() ? 
mshadow::kInt64 : mshadow::kFloat32; + } else { + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); -MXNET_REGISTER_API("_npi.eye") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.eye").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_eye"); nnvm::NodeAttrs attrs; @@ -253,89 +250,88 @@ MXNET_REGISTER_API("_npi.eye") param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); if (args[3].type_code() != kNull) { attrs.dict["ctx"] = args[3].operator std::string(); } int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); MXNET_REGISTER_API("_npi.linspace") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_linspace"); - nnvm::NodeAttrs attrs; - op::LinspaceParam param; - param.start = args[0].operator double(); - param.stop = args[1].operator double(); - if (features::is_enabled(features::INT64_TENSOR_SIZE)) - param.num = args[2].operator int64_t(); - else - param.num = args[2].operator int(); - if (args[3].type_code() == kNull) { - param.endpoint = true; - } else { - param.endpoint = args[3].operator bool(); - } - if (args[5].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[5].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_linspace"); + nnvm::NodeAttrs attrs; + op::LinspaceParam param; + param.start = args[0].operator double(); + param.stop = args[1].operator double(); + if (features::is_enabled(features::INT64_TENSOR_SIZE)) + param.num = args[2].operator int64_t(); + else + param.num = args[2].operator int(); + if (args[3].type_code() == kNull) { + param.endpoint = true; + } else { + param.endpoint = args[3].operator bool(); + } + if (args[5].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[5].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.logspace") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_logspace"); - nnvm::NodeAttrs attrs; - op::LogspaceParam param; - param.start = args[0].operator double(); 
- param.stop = args[1].operator double(); - if (features::is_enabled(features::INT64_TENSOR_SIZE)) - param.num = args[2].operator int64_t(); - else - param.num = args[2].operator int(); - if (args[3].type_code() == kNull) { - param.endpoint = true; - } else { - param.endpoint = args[3].operator bool(); - } - if (args[4].type_code() == kNull) { - param.base = 10.0; - } else { - param.base = args[4].operator double(); - } - if (args[6].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[6].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - if (args[5].type_code() != kNull) { - attrs.dict["ctx"] = args[5].operator std::string(); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_logspace"); + nnvm::NodeAttrs attrs; + op::LogspaceParam param; + param.start = args[0].operator double(); + param.stop = args[1].operator double(); + if (features::is_enabled(features::INT64_TENSOR_SIZE)) + param.num = args[2].operator int64_t(); + else + param.num = args[2].operator int(); + if (args[3].type_code() == kNull) { + param.endpoint = true; + } else { + param.endpoint = args[3].operator bool(); + } + if (args[4].type_code() == kNull) { + param.base = 10.0; + } else { + param.base = args[4].operator double(); + } + if (args[6].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[6].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + if (args[5].type_code() != kNull) { + attrs.dict["ctx"] = args[5].operator std::string(); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); -MXNET_REGISTER_API("_npi.ones") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.ones").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_ones"); nnvm::NodeAttrs attrs; @@ -351,18 +347,17 @@ MXNET_REGISTER_API("_npi.ones") param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; if (args[2].type_code() != kNull) { attrs.dict["ctx"] = args[2].operator std::string(); } int num_outputs = 0; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; + *ret = ndoutputs[0]; }); -MXNET_REGISTER_API("_npi.full") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.full").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_full"); nnvm::NodeAttrs attrs; @@ -377,17 +372,17 @@ MXNET_REGISTER_API("_npi.full") } else { param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); } - param.value = args[2].operator double(); + param.value = args[2].operator double(); attrs.parsed = param; - attrs.op = op; + attrs.op = op; if (args[3].type_code() != kNull) { attrs.dict["ctx"] = args[3].operator std::string(); } SetAttrDict(&attrs); - NDArray* out = args[4].operator mxnet::NDArray*(); + NDArray* out = args[4].operator mxnet::NDArray*(); NDArray** 
outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, outputs); if (out) { *ret = PythonArg(4); } else { @@ -396,26 +391,26 @@ MXNET_REGISTER_API("_npi.full") }); MXNET_REGISTER_API("_npi.identity") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_identity"); - nnvm::NodeAttrs attrs; - op::InitOpParam param; - param.shape = TShape(args[0].operator ObjectRef()); - if (args[1].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - int num_outputs = 0; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_identity"); + nnvm::NodeAttrs attrs; + op::InitOpParam param; + param.shape = TShape(args[0].operator ObjectRef()); + if (args[1].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + int num_outputs = 0; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_insert_op.cc b/src/api/operator/numpy/np_insert_op.cc index 4a5610fe46dd..2d6b7574ecb9 100644 --- a/src/api/operator/numpy/np_insert_op.cc +++ b/src/api/operator/numpy/np_insert_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -33,131 +33,128 @@ namespace mxnet { MXNET_REGISTER_API("_npi.insert_scalar") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_insert_scalar"); - nnvm::NodeAttrs attrs; - op::NumpyInsertParam param; - int num_inputs = 0; - param.start = dmlc::nullopt; - param.step = dmlc::nullopt; - param.stop = dmlc::nullopt; - if (args[1].type_code() == kDLInt || - args[1].type_code() == kDLUInt || - args[1].type_code() == kDLFloat) { - param.val = args[1].operator double(); - num_inputs = 1; - } else { - param.val = dmlc::nullopt; - num_inputs = 2; - } - if (features::is_enabled(features::INT64_TENSOR_SIZE)) { - param.int_ind = args[2].operator int64_t(); - } else { - param.int_ind = args[2].operator int(); - } - if (args[3].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[3].operator int(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_insert_scalar"); + nnvm::NodeAttrs attrs; + op::NumpyInsertParam param; + int num_inputs = 0; + param.start = dmlc::nullopt; + param.step = dmlc::nullopt; + param.stop = dmlc::nullopt; + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt || + args[1].type_code() == kDLFloat) { + param.val = args[1].operator double(); + num_inputs = 1; + } else { + param.val = dmlc::nullopt; + num_inputs = 2; + } + if (features::is_enabled(features::INT64_TENSOR_SIZE)) { + param.int_ind = args[2].operator int64_t(); + } else { + param.int_ind = args[2].operator int(); + } + if (args[3].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[3].operator int(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.insert_slice") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_insert_slice"); - nnvm::NodeAttrs attrs; - op::NumpyInsertParam param; - int num_inputs = 0; - if (args[1].type_code() == kDLInt || - args[1].type_code() == kDLUInt || - args[1].type_code() == kDLFloat) { - param.val = args[1].operator double(); - num_inputs = 1; - } else { - param.val = dmlc::nullopt; - num_inputs = 2; - } - if (args[2].type_code() == kNull) { - param.start = dmlc::nullopt; - } else { - param.start = args[2].operator int64_t(); - } - if (args[3].type_code() == kNull) { - param.stop = dmlc::nullopt; - } else { - param.stop = args[3].operator int64_t(); - } - if (args[4].type_code() == 
kNull) { - param.step = dmlc::nullopt; - } else { - param.step = args[4].operator int64_t(); - } - if (args[5].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[5].operator int(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_insert_slice"); + nnvm::NodeAttrs attrs; + op::NumpyInsertParam param; + int num_inputs = 0; + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt || + args[1].type_code() == kDLFloat) { + param.val = args[1].operator double(); + num_inputs = 1; + } else { + param.val = dmlc::nullopt; + num_inputs = 2; + } + if (args[2].type_code() == kNull) { + param.start = dmlc::nullopt; + } else { + param.start = args[2].operator int64_t(); + } + if (args[3].type_code() == kNull) { + param.stop = dmlc::nullopt; + } else { + param.stop = args[3].operator int64_t(); + } + if (args[4].type_code() == kNull) { + param.step = dmlc::nullopt; + } else { + param.step = args[4].operator int64_t(); + } + if (args[5].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[5].operator int(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.insert_tensor") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_insert_tensor"); - nnvm::NodeAttrs attrs; - op::NumpyInsertParam param; - param.start = dmlc::nullopt; - param.step = dmlc::nullopt; - param.stop = dmlc::nullopt; - int num_inputs = 0; - if (args[2].type_code() == kDLInt || - args[2].type_code() == kDLUInt || - args[2].type_code() == kDLFloat) { - param.val = args[2].operator double(); - num_inputs = 2; - } else { - param.val = dmlc::nullopt; - num_inputs = 3; - } - if (args[3].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[3].operator int(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_insert_tensor"); + nnvm::NodeAttrs attrs; + op::NumpyInsertParam param; + param.start = dmlc::nullopt; + param.step = dmlc::nullopt; + param.stop = dmlc::nullopt; + int num_inputs = 0; + if (args[2].type_code() == kDLInt || args[2].type_code() == kDLUInt || + args[2].type_code() == kDLFloat) { + param.val = args[2].operator double(); + num_inputs = 2; + } else { + param.val = dmlc::nullopt; + num_inputs = 3; + } + if 
(args[3].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[3].operator int(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_interp_op.cc b/src/api/operator/numpy/np_interp_op.cc index 0b89373b1a88..c3682ded7314 100644 --- a/src/api/operator/numpy/np_interp_op.cc +++ b/src/api/operator/numpy/np_interp_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,50 +29,52 @@ namespace mxnet { MXNET_REGISTER_API("_npi.interp") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_interp"); - nnvm::NodeAttrs attrs; - op::NumpyInterpParam param; - if (args[3].type_code() == kNull) { - param.left = dmlc::nullopt; - } else { - param.left = args[3].operator double(); - } - if (args[4].type_code() == kNull) { - param.right = dmlc::nullopt; - } else { - param.right = args[4].operator double(); - } - if (args[5].type_code() == kNull) { - param.period = dmlc::nullopt; - } else { - param.period = args[5].operator double(); - } - if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) { - param.x_scalar = args[2].operator double(); - param.x_is_scalar = true; - attrs.op = op; - attrs.parsed = param; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - int num_inputs = 2; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); - } else { - param.x_scalar = 0.0; - param.x_is_scalar = false; - attrs.op = op; - attrs.parsed = param; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*(), - args[2].operator mxnet::NDArray*()}; - int num_inputs = 3; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_interp"); + nnvm::NodeAttrs attrs; + op::NumpyInterpParam param; + if (args[3].type_code() == kNull) { + param.left = dmlc::nullopt; + } else { + param.left = args[3].operator double(); + } + if (args[4].type_code() == kNull) { + param.right = dmlc::nullopt; + } else { + param.right = args[4].operator double(); + } + if (args[5].type_code() == kNull) { + param.period = dmlc::nullopt; + } else { + param.period = args[5].operator double(); + } + if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) { + param.x_scalar = args[2].operator double(); + param.x_is_scalar = true; + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); + NDArray* inputs[] = 
{args[0].operator mxnet::NDArray*(), + args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } else { + param.x_scalar = 0.0; + param.x_is_scalar = false; + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), + args[1].operator mxnet::NDArray*(), + args[2].operator mxnet::NDArray*()}; + int num_inputs = 3; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_kron.cc b/src/api/operator/numpy/np_kron.cc index 753798208b4f..0bc59eecebc3 100644 --- a/src/api/operator/numpy/np_kron.cc +++ b/src/api/operator/numpy/np_kron.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,17 +28,16 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.kron") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.kron").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npi_kron"); - attrs.op = op; - NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; - int num_inputs = 2; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + attrs.op = op; + NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_matmul_op.cc b/src/api/operator/numpy/np_matmul_op.cc index 48f4ec06fe83..d8b5250ed61a 100644 --- a/src/api/operator/numpy/np_matmul_op.cc +++ b/src/api/operator/numpy/np_matmul_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,23 +29,22 @@ namespace mxnet { MXNET_REGISTER_API("_npi.matmul") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_matmul"); - nnvm::NodeAttrs attrs; - int num_inputs = 2; - NDArray* inputs[2] = {args[0].operator mxnet::NDArray*(), - args[1].operator mxnet::NDArray*()}; - attrs.op = op; - NDArray* out = args[2].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(2); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_matmul"); + nnvm::NodeAttrs attrs; + int num_inputs = 2; + NDArray* inputs[2] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + attrs.op = op; + NDArray* out = args[2].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(2); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index 96f481db56ae..921dd5fbbc3d 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -32,56 +32,55 @@ namespace mxnet { MXNET_REGISTER_API("_npi.transpose") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_transpose"); - nnvm::NodeAttrs attrs; - op::NumpyTransposeParam param; - if (args[1].type_code() == kNull) { - param.axes = TShape(-1, 0); - } else if (args[1].type_code() == kDLInt) { - param.axes = TShape(1, args[1].operator int64_t()); - } else { - param.axes = TShape(args[1].operator ObjectRef()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_transpose"); + nnvm::NodeAttrs attrs; + op::NumpyTransposeParam param; + if (args[1].type_code() == kNull) { + param.axes = TShape(-1, 0); + } else if (args[1].type_code() == kDLInt) { + param.axes = TShape(1, args[1].operator int64_t()); + } else { + param.axes = TShape(args[1].operator ObjectRef()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.expand_dims") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_expand_dims"); - nnvm::NodeAttrs attrs; - op::ExpandDimParam param; - param.axis = args[1].operator int(); - - // we directly copy ExpandDimParam, which is trivially-copyable - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); - -MXNET_REGISTER_API("_npi.stack") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_expand_dims"); + nnvm::NodeAttrs attrs; + op::ExpandDimParam param; + 
param.axis = args[1].operator int(); + + // we directly copy ExpandDimParam, which is trivially-copyable + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); + +MXNET_REGISTER_API("_npi.stack").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_stack"); nnvm::NodeAttrs attrs; op::StackParam param; - int i = 0; + int i = 0; int num_inputs = 0; std::vector inputs; while (args[i].type_code() != kDLInt) { @@ -91,32 +90,30 @@ MXNET_REGISTER_API("_npi.stack") } param.num_args = i; - param.axis = args[i].operator int64_t(); - attrs.parsed = param; - attrs.op = op; + param.axis = args[i].operator int64_t(); + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); - NDArray* out = args[i+1].operator mxnet::NDArray*(); + NDArray* out = args[i + 1].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), - &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); if (out) { - *ret = PythonArg(i+1); + *ret = PythonArg(i + 1); } else { *ret = ndoutputs[0]; } }); -MXNET_REGISTER_API("_npi.flip") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.flip").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_flip"); nnvm::NodeAttrs attrs; op::FlipParam param; - NDArray* out = args[2].operator mxnet::NDArray*(); + NDArray* out = args[2].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; + int num_outputs = out != nullptr; if (args[1].type_code() == kNull) { param.axis = mxnet::Tuple(-1, dim_t(0)); } else if (args[1].type_code() == kDLInt) { @@ -125,9 +122,9 @@ MXNET_REGISTER_API("_npi.flip") param.axis = Tuple(args[1].operator ObjectRef()); } NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { @@ -138,85 +135,82 @@ MXNET_REGISTER_API("_npi.flip") }); MXNET_REGISTER_API("_npi.concatenate") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_concatenate"); - nnvm::NodeAttrs attrs; - op::NumpyConcatenateParam param; - int arg_size = args.num_args; - param.num_args = arg_size - 2; - if (args[arg_size - 2].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[arg_size - 2].operator int(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_inputs = arg_size - 2; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - NDArray* out = args[arg_size - 1].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); - if (out) { - *ret = PythonArg(arg_size - 1); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_concatenate"); + nnvm::NodeAttrs attrs; + op::NumpyConcatenateParam param; + int arg_size = args.num_args; + param.num_args = arg_size - 2; + if (args[arg_size - 2].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[arg_size - 2].operator int(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = arg_size - 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + NDArray* out = args[arg_size - 1].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); + if (out) { + *ret = PythonArg(arg_size - 1); + } else { + *ret = ndoutputs[0]; + } + }); MXNET_REGISTER_API("_npi.dstack") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_dstack"); - nnvm::NodeAttrs attrs; - op::ConcatParam param; - int args_size = args.size(); - // param.num_args - param.num_args = args_size; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - int num_inputs = args_size; - std::vector inputs_vec(args_size, nullptr); - for (int i = 0; i < args_size; ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); - // outputs - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); - -MXNET_REGISTER_API("_npi.split") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_dstack"); + nnvm::NodeAttrs attrs; + op::ConcatParam param; + int args_size = args.size(); + // param.num_args + param.num_args = args_size; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + int num_inputs = args_size; + std::vector inputs_vec(args_size, nullptr); + for (int i = 0; i < args_size; ++i) { + inputs_vec[i] = args[i].operator mxnet::NDArray*(); + } + NDArray** inputs = inputs_vec.data(); + // outputs + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); + +MXNET_REGISTER_API("_npi.split").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_split"); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; nnvm::NodeAttrs attrs; op::SplitParam param; - param.axis = args[2].operator int(); + param.axis = args[2].operator int(); param.squeeze_axis = false; if (args[1].type_code() == kDLInt) { - param.indices = TShape(0, 0); + param.indices = TShape(0, 0); param.sections = args[1].operator int(); - int index = param.axis >= 0 ? param.axis : - param.axis + inputs[0]->shape().ndim(); + int index = param.axis >= 0 ? 
param.axis : param.axis + inputs[0]->shape().ndim(); CHECK_GE(index, 0) << "IndexError: tuple index out of range"; - CHECK_GT(param.sections, 0) - << "ValueError: number sections must be larger than 0"; + CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0"; CHECK_EQ(inputs[0]->shape()[index] % param.sections, 0) - << "ValueError: array split does not result in an equal division"; + << "ValueError: array split does not result in an equal division"; } else { - TShape t = TShape(args[1].operator ObjectRef()); + TShape t = TShape(args[1].operator ObjectRef()); param.indices = TShape(t.ndim() + 1, 0); for (int i = 0; i < t.ndim(); ++i) { param.indices[i + 1] = t[i]; @@ -224,11 +218,11 @@ MXNET_REGISTER_API("_npi.split") param.sections = 0; } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); std::vector ndarray_handles; ndarray_handles.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { @@ -237,8 +231,7 @@ MXNET_REGISTER_API("_npi.split") *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); }); -MXNET_REGISTER_API("_npi.roll") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.roll").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_roll"); nnvm::NodeAttrs attrs; @@ -258,17 +251,16 @@ MXNET_REGISTER_API("_npi.roll") param.axis = TShape(args[2].operator ObjectRef()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); -MXNET_REGISTER_API("_npi.rot90") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.rot90").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; static const nnvm::Op* op = Op::Get("_npi_rot90"); nnvm::NodeAttrs attrs; @@ -282,213 +274,208 @@ MXNET_REGISTER_API("_npi.rot90") param.axes = TShape(args[2].operator ObjectRef()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); MXNET_REGISTER_API("_npi.column_stack") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_column_stack"); - nnvm::NodeAttrs attrs; - op::NumpyColumnStackParam param; - param.num_args = args.size(); - - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - std::vector inputs; - inputs.reserve(param.num_args); - for (int i = 0; i < param.num_args; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - auto ndoutputs = Invoke(op, &attrs, param.num_args, 
&inputs[0], &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_column_stack"); + nnvm::NodeAttrs attrs; + op::NumpyColumnStackParam param; + param.num_args = args.size(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::NumpyColumnStackParam>(&attrs); + int num_outputs = 0; + std::vector<NDArray*> inputs; + inputs.reserve(param.num_args); + for (int i = 0; i < param.num_args; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.hstack") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_hstack"); - nnvm::NodeAttrs attrs; - op::ConcatParam param; - param.num_args = args.size(); - - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::ConcatParam>(&attrs); - int num_outputs = 0; - std::vector<NDArray*> inputs; - inputs.reserve(param.num_args); - for (int i = 0; i < param.num_args; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_hstack"); + nnvm::NodeAttrs attrs; + op::ConcatParam param; + param.num_args = args.size(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::ConcatParam>(&attrs); + int num_outputs = 0; + std::vector<NDArray*> inputs; + inputs.reserve(param.num_args); + for (int i = 0; i < param.num_args; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + auto ndoutputs = Invoke(op, &attrs, param.num_args, &inputs[0], &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.array_split") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_array_split"); - nnvm::NodeAttrs attrs; - op::SplitParam param; - param.axis = args[2].operator int(); - param.squeeze_axis = false; - if (args[1].type_code() == kDLInt) { - param.indices = TShape(0, 0); - param.sections = args[1].operator int(); - CHECK_GT(param.sections, 0) - << "ValueError: number sections must be larger than 0"; - } else { - TShape t = TShape(args[1].operator ObjectRef()); - param.indices = TShape(t.ndim() + 1, 0); - for (int i = 0; i < t.ndim(); ++i) { - param.indices[i + 1] = t[i]; - } - param.sections = 0; - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::SplitParam>(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector<NDArrayHandle> ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_array_split"); + nnvm::NodeAttrs attrs; + op::SplitParam param; + param.axis = args[2].operator int(); + param.squeeze_axis = false; + if (args[1].type_code() == kDLInt) { + param.indices = TShape(0, 0); + param.sections = args[1].operator int(); + CHECK_GT(param.sections, 0) << "ValueError: number sections 
must be larger than 0"; + } else { + TShape t = TShape(args[1].operator ObjectRef()); + param.indices = TShape(t.ndim() + 1, 0); + for (int i = 0; i < t.ndim(); ++i) { + param.indices[i + 1] = t[i]; + } + param.sections = 0; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.dsplit") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_split"); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - CHECK_GE(inputs[0]->shape().ndim(), 3) - << "ValueError: dsplit only works on arrays of 3 or more dimensions"; - nnvm::NodeAttrs attrs; - op::SplitParam param; - param.axis = 2; - param.squeeze_axis = false; - if (args[1].type_code() == kDLInt) { - param.indices = TShape(0, 0); - param.sections = args[1].operator int(); - CHECK_EQ(inputs[0]->shape()[2] % param.sections, 0) - << "ValueError: array split does not result in an equal division"; - CHECK_GT(param.sections, 0) - << "ValueError: number sections must be larger than 0"; - } else { - TShape t = TShape(args[1].operator ObjectRef()); - param.indices = TShape(t.ndim() + 1, 0); - for (int i = 0; i < t.ndim(); ++i) { - param.indices[i + 1] = t[i]; - } - param.sections = 0; - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_split"); + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + CHECK_GE(inputs[0]->shape().ndim(), 3) + << "ValueError: dsplit only works on arrays of 3 or more dimensions"; + nnvm::NodeAttrs attrs; + op::SplitParam param; + param.axis = 2; + param.squeeze_axis = false; + if (args[1].type_code() == kDLInt) { + param.indices = TShape(0, 0); + param.sections = args[1].operator int(); + CHECK_EQ(inputs[0]->shape()[2] % param.sections, 0) + << "ValueError: array split does not result in an equal division"; + CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0"; + } else { + TShape t = TShape(args[1].operator ObjectRef()); + param.indices = TShape(t.ndim() + 1, 0); + for (int i = 0; i < t.ndim(); ++i) { + param.indices[i + 1] = t[i]; + } + param.sections = 0; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.hsplit") 
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_hsplit"); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - CHECK_GE(inputs[0]->shape().ndim(), 1) - << "ValueError: hsplit only works on arrays of 1 or more dimensions"; - nnvm::NodeAttrs attrs; - op::SplitParam param; - param.axis = 0; - param.squeeze_axis = false; - if (args[1].type_code() == kDLInt) { - param.indices = TShape(0, 0); - param.sections = args[1].operator int(); - CHECK_GT(param.sections, 0) - << "ValueError: number sections must be larger than 0"; - } else { - TShape t = TShape(args[1].operator ObjectRef()); - param.indices = TShape(t.ndim() + 1, 0); - for (int i = 0; i < t.ndim(); ++i) { - param.indices[i + 1] = t[i]; - } - param.sections = 0; - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_hsplit"); + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + CHECK_GE(inputs[0]->shape().ndim(), 1) + << "ValueError: hsplit only works on arrays of 1 or more dimensions"; + nnvm::NodeAttrs attrs; + op::SplitParam param; + param.axis = 0; + param.squeeze_axis = false; + if (args[1].type_code() == kDLInt) { + param.indices = TShape(0, 0); + param.sections = args[1].operator int(); + CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0"; + } else { + TShape t = TShape(args[1].operator ObjectRef()); + param.indices = TShape(t.ndim() + 1, 0); + for (int i = 0; i < t.ndim(); ++i) { + param.indices[i + 1] = t[i]; + } + param.sections = 0; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.vsplit") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - static const nnvm::Op* op = Op::Get("_npi_split"); - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - CHECK_GE(inputs[0]->shape().ndim(), 2) - << "ValueError: vsplit only works on arrays of 2 or more dimensions"; - nnvm::NodeAttrs attrs; - op::SplitParam param; - param.axis = 0; - param.squeeze_axis = false; - if (args[1].type_code() == kDLInt) { - param.indices = TShape(0, 0); - param.sections = args[1].operator int(); - CHECK_EQ(inputs[0]->shape()[0] % param.sections, 0) - << "ValueError: array split does not result in an equal division"; - CHECK_GT(param.sections, 0) - << "ValueError: number sections must be larger than 0"; - } else { - TShape t = TShape(args[1].operator ObjectRef()); - param.indices = TShape(t.ndim() + 1, 0); - for (int i = 0; i < t.ndim(); ++i) { - param.indices[i + 1] = t[i]; - } - param.sections = 0; - } - attrs.parsed = param; - attrs.op = op; - 
SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); - -MXNET_REGISTER_API("_npi.diag") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + static const nnvm::Op* op = Op::Get("_npi_split"); + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + CHECK_GE(inputs[0]->shape().ndim(), 2) + << "ValueError: vsplit only works on arrays of 2 or more dimensions"; + nnvm::NodeAttrs attrs; + op::SplitParam param; + param.axis = 0; + param.squeeze_axis = false; + if (args[1].type_code() == kDLInt) { + param.indices = TShape(0, 0); + param.sections = args[1].operator int(); + CHECK_EQ(inputs[0]->shape()[0] % param.sections, 0) + << "ValueError: array split does not result in an equal division"; + CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0"; + } else { + TShape t = TShape(args[1].operator ObjectRef()); + param.indices = TShape(t.ndim() + 1, 0); + for (int i = 0; i < t.ndim(); ++i) { + param.indices[i + 1] = t[i]; + } + param.sections = 0; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); + +MXNET_REGISTER_API("_npi.diag").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_diag"); nnvm::NodeAttrs attrs; @@ -498,211 +485,211 @@ MXNET_REGISTER_API("_npi.diag") else param.k = args[1].operator int(); attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); MXNET_REGISTER_API("_npi.rollaxis") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_rollaxis"); - nnvm::NodeAttrs attrs; - op::NumpyRollaxisParam param; - param.axis = args[1].operator int(); - param.start = args[2].operator int(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_rollaxis"); + nnvm::NodeAttrs attrs; + op::NumpyRollaxisParam param; + param.axis = args[1].operator int(); + param.start = args[2].operator int(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator 
mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.reshape") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_reshape"); - nnvm::NodeAttrs attrs; - op::NumpyXReshapeParam param; - if (args[1].type_code() == kNull) { - param.newshape = TShape(-1, 0); - } else if (args[1].type_code() == kDLInt) { - param.newshape = TShape(1, args[1].operator int64_t()); - } else { - param.newshape = TShape(args[1].operator ObjectRef()); - } - param.reverse = args[2].operator bool(); - param.order = args[3].operator std::string(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_reshape"); + nnvm::NodeAttrs attrs; + op::NumpyXReshapeParam param; + if (args[1].type_code() == kNull) { + param.newshape = TShape(-1, 0); + } else if (args[1].type_code() == kDLInt) { + param.newshape = TShape(1, args[1].operator int64_t()); + } else { + param.newshape = TShape(args[1].operator ObjectRef()); + } + param.reverse = args[2].operator bool(); + param.order = args[3].operator std::string(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.moveaxis") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_moveaxis"); - nnvm::NodeAttrs attrs; - op::NumpyMoveaxisParam param; - if (args[1].type_code() == kNull) { - param.source = TShape(-1, 0); - } else if (args[1].type_code() == kDLInt) { - param.source = TShape(1, args[1].operator int64_t()); - } else { - param.source = TShape(args[1].operator ObjectRef()); - } - if (args[2].type_code() == kNull) { - param.destination = TShape(-1, 0); - } else if (args[2].type_code() == kDLInt) { - param.destination = TShape(1, args[2].operator int64_t()); - } else { - param.destination = TShape(args[2].operator ObjectRef()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_moveaxis"); + nnvm::NodeAttrs attrs; + op::NumpyMoveaxisParam param; + if (args[1].type_code() == kNull) { + param.source = TShape(-1, 0); + } else if (args[1].type_code() == kDLInt) { + param.source = TShape(1, args[1].operator int64_t()); + } else { + param.source = TShape(args[1].operator ObjectRef()); + } + if (args[2].type_code() == kNull) { + param.destination = TShape(-1, 0); + } else if (args[2].type_code() == kDLInt) { + param.destination = TShape(1, args[2].operator int64_t()); + } else { + param.destination = TShape(args[2].operator 
ObjectRef()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.diagonal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_diagonal"); - nnvm::NodeAttrs attrs; - op::NumpyDiagonalParam param; - if (features::is_enabled(features::INT64_TENSOR_SIZE)) - param.offset = args[1].operator int64_t(); - else - param.offset = args[1].operator int(); - param.axis1 = args[2].operator int(); - param.axis2 = args[3].operator int(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_diagonal"); + nnvm::NodeAttrs attrs; + op::NumpyDiagonalParam param; + if (features::is_enabled(features::INT64_TENSOR_SIZE)) + param.offset = args[1].operator int64_t(); + else + param.offset = args[1].operator int(); + param.axis1 = args[2].operator int(); + param.axis2 = args[3].operator int(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.diag_indices_from") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_diag_indices_from"); - nnvm::NodeAttrs attrs; - attrs.op = op; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_diag_indices_from"); + nnvm::NodeAttrs attrs; + attrs.op = op; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.diagflat") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_diagflat"); - nnvm::NodeAttrs attrs; - op::NumpyDiagflatParam param; - param.k = args[1].operator int(); - int num_inputs = 1; - int num_outputs = 0; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_diagflat"); + nnvm::NodeAttrs attrs; + op::NumpyDiagflatParam param; + param.k = args[1].operator int(); + int num_inputs = 1; + int num_outputs = 0; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] 
= {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.squeeze") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_squeeze"); - nnvm::NodeAttrs attrs; - op::SqueezeParam param; - if (args[1].type_code() == kNull) { - param.axis = dmlc::optional>(); - } else if (args[1].type_code() == kDLInt) { - param.axis = Tuple(1, args[1].operator int64_t()); - } else { - param.axis = Tuple(args[1].operator ObjectRef()); - } - int num_inputs = 1; - int num_outputs = 0; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_squeeze"); + nnvm::NodeAttrs attrs; + op::SqueezeParam param; + if (args[1].type_code() == kNull) { + param.axis = dmlc::optional>(); + } else if (args[1].type_code() == kDLInt) { + param.axis = Tuple(1, args[1].operator int64_t()); + } else { + param.axis = Tuple(args[1].operator ObjectRef()); + } + int num_inputs = 1; + int num_outputs = 0; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npi.tril_indices") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_tril_indices"); - nnvm::NodeAttrs attrs; - op::NumpyTrilindicesParam param; - if (features::is_enabled(features::INT64_TENSOR_SIZE)) { - param.n = args[0].operator int64_t(); - param.k = args[1].operator int64_t(); - param.m = args[2].operator int64_t(); - } else { - param.n = args[0].operator int(); - param.k = args[1].operator int(); - param.m = args[2].operator int(); - } - - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tril_indices"); + nnvm::NodeAttrs attrs; + op::NumpyTrilindicesParam param; + if (features::is_enabled(features::INT64_TENSOR_SIZE)) { + param.n = args[0].operator int64_t(); + param.k = args[1].operator int64_t(); + param.m = args[2].operator int64_t(); + } else { + param.n = args[0].operator int(); + param.k = args[1].operator int(); + param.m = args[2].operator int(); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); MXNET_REGISTER_API("_npi.vstack") -.set_body([](runtime::MXNetArgs args, 
runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_vstack"); - nnvm::NodeAttrs attrs; - op::NumpyVstackParam param; - param.num_args = args.size(); - - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - std::vector inputs_vec(args.size(), nullptr); - for (int i = 0; i < args.size(); ++i) { - inputs_vec[i] = args[i].operator mxnet::NDArray*(); - } - NDArray** inputs = inputs_vec.data(); - auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_vstack"); + nnvm::NodeAttrs attrs; + op::NumpyVstackParam param; + param.num_args = args.size(); + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + std::vector inputs_vec(args.size(), nullptr); + for (int i = 0; i < args.size(); ++i) { + inputs_vec[i] = args[i].operator mxnet::NDArray*(); + } + NDArray** inputs = inputs_vec.data(); + auto ndoutputs = Invoke(op, &attrs, param.num_args, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_memory_op.cc b/src/api/operator/numpy/np_memory_op.cc index 33e5d4cfb7d8..d13cb3d38980 100644 --- a/src/api/operator/numpy/np_memory_op.cc +++ b/src/api/operator/numpy/np_memory_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,16 +28,16 @@ namespace mxnet { MXNET_REGISTER_API("_npi.share_memory") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_share_memory"); - nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 2; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_share_memory"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_moments_op.cc b/src/api/operator/numpy/np_moments_op.cc index 45dd45e8f4c9..5cb0cfaf6531 100644 --- a/src/api/operator/numpy/np_moments_op.cc +++ b/src/api/operator/numpy/np_moments_op.cc @@ -29,8 +29,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.std") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.std").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_std"); op::NumpyMomentsParam param; @@ -70,12 +69,12 @@ MXNET_REGISTER_API("_npi.std") SetAttrDict(&attrs); NDArray* inputs[] 
= {args[0].operator NDArray*()}; - int num_inputs = 1; + int num_inputs = 1; NDArray* outputs[] = {args[5].operator NDArray*()}; - NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; - int num_outputs = (outputs[0] != nullptr); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); if (out) { *ret = PythonArg(5); @@ -84,8 +83,7 @@ MXNET_REGISTER_API("_npi.std") } }); -MXNET_REGISTER_API("_npi.var") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.var").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_var"); op::NumpyMomentsParam param; @@ -125,12 +123,12 @@ MXNET_REGISTER_API("_npi.var") SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; + int num_inputs = 1; NDArray* outputs[] = {args[5].operator NDArray*()}; - NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; - int num_outputs = (outputs[0] != nullptr); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); if (out) { *ret = PythonArg(5); @@ -140,70 +138,66 @@ MXNET_REGISTER_API("_npi.var") }); MXNET_REGISTER_API("_npi.average") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_average"); - op::NumpyWeightedAverageParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - - // parse axis - if (args[2].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - if (args[2].type_code() == kDLInt) { - param.axis = Tuple(1, args[2].operator int64_t()); - } else { - param.axis = Tuple(args[2].operator ObjectRef()); - } - } - - // parse returned - CHECK_NE(args[3].type_code(), kNull) - << "returned cannot be None"; - param.returned = args[3].operator bool(); - - // parse weighted - CHECK_NE(args[4].type_code(), kNull) - << "weighted cannot be None"; - param.weighted = args[4].operator bool(); - - attrs.parsed = param; - - SetAttrDict(&attrs); - - int num_inputs = param.weighted ? 2 : 1; - NDArray* outputs[] = {args[5].operator NDArray*()}; - NDArray** out = (outputs[0] == nullptr) ? 
nullptr : outputs; - int num_outputs = (outputs[0] != nullptr); - - if (param.weighted) { - NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); - if (out) { - *ret = PythonArg(5); - } else { - if (param.returned) { - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_average"); + op::NumpyWeightedAverageParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + + // parse axis + if (args[2].type_code() == kNull) { + param.axis = dmlc::nullopt; } else { - *ret = reinterpret_cast(ndoutputs[0]); + if (args[2].type_code() == kDLInt) { + param.axis = Tuple(1, args[2].operator int64_t()); + } else { + param.axis = Tuple(args[2].operator ObjectRef()); + } } - } - } else { - NDArray* inputs[] = {args[0].operator NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); - if (out) { - *ret = PythonArg(5); - } else { - if (param.returned) { - *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), - NDArrayHandle(ndoutputs[1])}); + + // parse returned + CHECK_NE(args[3].type_code(), kNull) << "returned cannot be None"; + param.returned = args[3].operator bool(); + + // parse weighted + CHECK_NE(args[4].type_code(), kNull) << "weighted cannot be None"; + param.weighted = args[4].operator bool(); + + attrs.parsed = param; + + SetAttrDict(&attrs); + + int num_inputs = param.weighted ? 2 : 1; + NDArray* outputs[] = {args[5].operator NDArray*()}; + NDArray** out = (outputs[0] == nullptr) ? nullptr : outputs; + int num_outputs = (outputs[0] != nullptr); + + if (param.weighted) { + NDArray* inputs[] = {args[0].operator NDArray*(), args[1].operator NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(5); + } else { + if (param.returned) { + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } } else { - *ret = reinterpret_cast(ndoutputs[0]); + NDArray* inputs[] = {args[0].operator NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, out); + if (out) { + *ret = PythonArg(5); + } else { + if (param.returned) { + *ret = ADT(0, {NDArrayHandle(ndoutputs[0]), NDArrayHandle(ndoutputs[1])}); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } } - } - } -}); + }); }; // namespace mxnet diff --git a/src/api/operator/numpy/np_nan_to_num_op.cc b/src/api/operator/numpy/np_nan_to_num_op.cc index 65fd26e5432e..804d757a035b 100644 --- a/src/api/operator/numpy/np_nan_to_num_op.cc +++ b/src/api/operator/numpy/np_nan_to_num_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at
- * 
+ *
  * http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -29,44 +29,44 @@ namespace mxnet {

 MXNET_REGISTER_API("_npi.nan_to_num")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
-  using namespace runtime;
-  const nnvm::Op* op = Op::Get("_npi_nan_to_num");
-  nnvm::NodeAttrs attrs;
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op = Op::Get("_npi_nan_to_num");
+      nnvm::NodeAttrs attrs;

-  op::NumpyNanToNumParam param;
-  int num_inputs = 1;
-  NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
+      op::NumpyNanToNumParam param;
+      int num_inputs = 1;
+      NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};

-  param.copy = args[1].operator bool();
-  param.nan = args[2].operator double();
+      param.copy = args[1].operator bool();
+      param.nan = args[2].operator double();

-  if (args[3].type_code() == kNull) {
-    param.posinf = dmlc::nullopt;
-  } else {
-    param.posinf = args[3].operator double();
-  }
+      if (args[3].type_code() == kNull) {
+        param.posinf = dmlc::nullopt;
+      } else {
+        param.posinf = args[3].operator double();
+      }

-  if (args[4].type_code() == kNull) {
-    param.neginf = dmlc::nullopt;
-  } else {
-    param.neginf = args[4].operator double();
-  }
+      if (args[4].type_code() == kNull) {
+        param.neginf = dmlc::nullopt;
+      } else {
+        param.neginf = args[4].operator double();
+      }

-  attrs.parsed = param;
-  attrs.op = op;
-  SetAttrDict(&attrs);
+      attrs.parsed = param;
+      attrs.op = op;
+      SetAttrDict(&attrs);

-  NDArray* out = args[5].operator mxnet::NDArray*();
-  NDArray** outputs = out == nullptr ? nullptr : &out;
-  // set the number of outputs provided by the `out` arugment
-  int num_outputs = out != nullptr;
-  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
-  if (out) {
-    *ret = PythonArg(5);
-  } else {
-    *ret = ndoutputs[0];
-  }
-});
+      NDArray* out = args[5].operator mxnet::NDArray*();
+      NDArray** outputs = out == nullptr ? nullptr : &out;
+      // set the number of outputs provided by the `out` argument
+      int num_outputs = out != nullptr;
+      auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
+      if (out) {
+        *ret = PythonArg(5);
+      } else {
+        *ret = ndoutputs[0];
+      }
+    });

 }  // namespace mxnet
diff --git a/src/api/operator/numpy/np_nonzero_op.cc b/src/api/operator/numpy/np_nonzero_op.cc
index 85510633c054..7558b4fadad4 100644
--- a/src/api/operator/numpy/np_nonzero_op.cc
+++ b/src/api/operator/numpy/np_nonzero_op.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,18 +28,18 @@ namespace mxnet { MXNET_REGISTER_API("_npi.nonzero") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_nonzero"); - nnvm::NodeAttrs attrs; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_nonzero"); + nnvm::NodeAttrs attrs; - attrs.op = op; + attrs.op = op; - int num_inputs = 1; - int num_outputs = 0; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_ordering_op.cc b/src/api/operator/numpy/np_ordering_op.cc index ec0db28b4f9a..11c00fbfb71e 100644 --- a/src/api/operator/numpy/np_ordering_op.cc +++ b/src/api/operator/numpy/np_ordering_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,8 +28,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.sort") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.sort").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_sort"); nnvm::NodeAttrs attrs; @@ -43,46 +42,46 @@ MXNET_REGISTER_API("_npi.sort") param.is_ascend = true; attrs.parsed = std::move(param); - attrs.op = op; + attrs.op = op; - int num_inputs = 1; + int num_inputs = 1; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; int num_outputs = 0; SetAttrDict(&attrs); auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + *ret = reinterpret_cast(ndoutputs[0]); }); MXNET_REGISTER_API("_npi.argsort") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_argsort"); - nnvm::NodeAttrs attrs; - op::ArgSortParam param; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_argsort"); + nnvm::NodeAttrs attrs; + op::ArgSortParam param; - if (args[1].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[1].operator int(); - } - param.is_ascend = true; - if (args[3].type_code() == kNull) { - param.dtype = mshadow::kFloat32; - } else { - param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); - } + if (args[1].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[1].operator int(); + } + param.is_ascend = true; + if (args[3].type_code() == kNull) { + 
param.dtype = mshadow::kFloat32; + } else { + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); + } - attrs.parsed = std::move(param); - attrs.op = op; + attrs.parsed = std::move(param); + attrs.op = op; - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_outputs = 0; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + int num_outputs = 0; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_pad_op.cc b/src/api/operator/numpy/np_pad_op.cc index 1351c26dd5ba..4f3b46cf0a28 100644 --- a/src/api/operator/numpy/np_pad_op.cc +++ b/src/api/operator/numpy/np_pad_op.cc @@ -7,9 +7,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -68,7 +68,7 @@ inline Tuple> BroadcastPadWidth(int ndim, runtime::ADT adt) { } else { CHECK_EQ(adt_size, 2) << "Invalid Input pad_width"; int pad_before = static_cast(pad->value); - int pad_after = static_cast(Downcast(adt[1])->value); + int pad_after = static_cast(Downcast(adt[1])->value); if (ndim == 1) { temp.emplace_back(mxnet::Tuple({pad_before})); temp.emplace_back(mxnet::Tuple({pad_after})); @@ -82,10 +82,8 @@ inline Tuple> BroadcastPadWidth(int ndim, runtime::ADT adt) { if (adt_size == 1) { if (ndim == 1) { runtime::ADT pad_adt = Downcast(adt[0]); - int pad_before = - static_cast(Downcast(pad_adt[0])->value); - int pad_after = - static_cast(Downcast(pad_adt[1])->value); + int pad_before = static_cast(Downcast(pad_adt[0])->value); + int pad_after = static_cast(Downcast(pad_adt[1])->value); temp.emplace_back(mxnet::Tuple({pad_before})); temp.emplace_back(mxnet::Tuple({pad_after})); } else { @@ -103,32 +101,31 @@ inline Tuple> BroadcastPadWidth(int ndim, runtime::ADT adt) { return Tuple>(temp.begin(), temp.end()); } -MXNET_REGISTER_API("_npi.pad") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.pad").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_pad"); nnvm::NodeAttrs attrs; op::NumpyPadParam param; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; mxnet::TShape ashape = inputs[0]->shape(); - int ndim = ashape.ndim(); - ADT adt = Downcast(args[1].operator ObjectRef()); + int ndim = ashape.ndim(); + ADT adt = Downcast(args[1].operator ObjectRef()); // broadcast pad_width to (ndim, 2) param.pad_width = BroadcastPadWidth(ndim, adt); - param.mode = String2MXNetPadType(args[2].operator std::string()); + param.mode = String2MXNetPadType(args[2].operator std::string()); if (args[3].type_code() != kNull) { param.constant_values = args[3].operator double(); } if (args[4].type_code() != kNull) { param.reflect_type = args[4].operator std::string(); } - attrs.op = op; + attrs.op = op; attrs.parsed = param; 
SetAttrDict(&attrs); - int num_inputs = 1; + int num_inputs = 1; int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_percentile_op.cc b/src/api/operator/numpy/np_percentile_op.cc index 196cca9baaf9..fd311c73aeb3 100644 --- a/src/api/operator/numpy/np_percentile_op.cc +++ b/src/api/operator/numpy/np_percentile_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,7 +38,7 @@ inline int String2MXNetPercentileType(const std::string& s) { return percentile_enum::kHigher; } else if (s == "midpoint") { return percentile_enum::kMidpoint; - } else if (s== "nearest") { + } else if (s == "nearest") { return percentile_enum::kNearest; } else { LOG(FATAL) << "unknown type " << s; @@ -48,51 +48,52 @@ inline int String2MXNetPercentileType(const std::string& s) { } MXNET_REGISTER_API("_npi.percentile") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_percentile"); - nnvm::NodeAttrs attrs; - op::NumpyPercentileParam param; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_percentile"); + nnvm::NodeAttrs attrs; + op::NumpyPercentileParam param; - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - if (args[2].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else if (args[2].type_code() == kDLInt) { - param.axis = Tuple(1, args[2].operator int64_t()); - } else { - param.axis = Tuple(args[2].operator ObjectRef()); - } - param.interpolation = String2MXNetPercentileType(args[3].operator std::string()); - param.keepdims = args[4].operator bool(); - if (args[1].type_code() == kDLInt || args[1].type_code() == kDLFloat) { - param.q_scalar = args[1].operator double(); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } - } else { - param.q_scalar = dmlc::nullopt; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - int num_inputs = 2; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } - } -}); + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + if (args[2].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else if (args[2].type_code() == kDLInt) { + param.axis = Tuple(1, args[2].operator int64_t()); + } else { + param.axis = Tuple(args[2].operator ObjectRef()); + } + param.interpolation = String2MXNetPercentileType(args[3].operator std::string()); + param.keepdims = args[4].operator bool(); + if (args[1].type_code() == kDLInt || args[1].type_code() == kDLFloat) { + param.q_scalar = args[1].operator double(); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + int num_inputs = 1; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } else { + param.q_scalar = dmlc::nullopt; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), + args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_polynomial_op.cc b/src/api/operator/numpy/np_polynomial_op.cc index 87081d2952ca..749cf6b859d2 100644 --- a/src/api/operator/numpy/np_polynomial_op.cc +++ b/src/api/operator/numpy/np_polynomial_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,17 +28,17 @@ namespace mxnet { MXNET_REGISTER_API("_npi.polyval") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_polyval"); - nnvm::NodeAttrs attrs; - attrs.op = op; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_polyval"); + nnvm::NodeAttrs attrs; + attrs.op = op; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - int num_inputs = 2; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_repeat_op.cc b/src/api/operator/numpy/np_repeat_op.cc index c98a1711050a..c7bed2b3ec69 100644 --- a/src/api/operator/numpy/np_repeat_op.cc +++ b/src/api/operator/numpy/np_repeat_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,25 +28,25 @@ namespace mxnet { MXNET_REGISTER_API("_npi.repeats") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_repeats"); - nnvm::NodeAttrs attrs; - op::RepeatsParam param; - param.repeats = Tuple(args[1].operator ObjectRef());; - if (args[2].type_code() == kNull) { - param.axis = dmlc::optional(); - } else { - param.axis = args[2].operator int64_t(); - } - int num_inputs = 1; - int num_outputs = 0; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_repeats"); + nnvm::NodeAttrs attrs; + op::RepeatsParam param; + param.repeats = Tuple(args[1].operator ObjectRef()); + if (args[2].type_code() == kNull) { + param.axis = dmlc::optional(); + } else { + param.axis = args[2].operator int64_t(); + } + int num_inputs = 1; + int num_outputs = 0; + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_tensordot_op.cc b/src/api/operator/numpy/np_tensordot_op.cc index 0cc74d9355e1..cf1c0fc0fefb 100644 --- a/src/api/operator/numpy/np_tensordot_op.cc +++ b/src/api/operator/numpy/np_tensordot_op.cc @@ -27,26 +27,24 @@ namespace mxnet { -inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_tensordot_int_axes(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_tensordot_int_axes"); op::TensordotIntAxesParam param; nnvm::NodeAttrs attrs; param.axes = args[2].operator int(); - attrs.op = op; + attrs.op = op; // we directly copy TensordotIntAxesParam, which is trivially-copyable attrs.parsed = param; SetAttrDict(&attrs); - int num_outputs = 0; - int num_inputs = 2; + int num_outputs = 0; + int num_inputs = 2; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } -inline static void _npi_tensordot(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_tensordot(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_tensordot"); op::TensordotParam param; @@ -61,23 +59,23 @@ inline static void _npi_tensordot(runtime::MXNetArgs args, param.a_axes_summed = Tuple(adt[0]); param.b_axes_summed = Tuple(adt[1]); } - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - int num_outputs = 0; - int num_inputs = 2; + int num_outputs = 0; + int num_inputs = 2; NDArray* inputs[] = 
{args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } MXNET_REGISTER_API("_npi.tensordot") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - if (args[2].type_code() == kDLInt) { - _npi_tensordot_int_axes(args, ret); - } else { - _npi_tensordot(args, ret); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + if (args[2].type_code() == kDLInt) { + _npi_tensordot_int_axes(args, ret); + } else { + _npi_tensordot(args, ret); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_trace_op.cc b/src/api/operator/numpy/np_trace_op.cc index 2979d21dbdc9..125f96d2d01e 100644 --- a/src/api/operator/numpy/np_trace_op.cc +++ b/src/api/operator/numpy/np_trace_op.cc @@ -28,24 +28,23 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.trace") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.trace").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_trace"); nnvm::NodeAttrs attrs; op::NumpyTraceParam param; param.offset = args[1].operator int64_t(); - param.axis1 = args[2].operator int64_t(); - param.axis2 = args[3].operator int64_t(); + param.axis1 = args[2].operator int64_t(); + param.axis2 = args[3].operator int64_t(); attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - NDArray* out = args[4].operator mxnet::NDArray*(); + int num_inputs = 1; + NDArray* out = args[4].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (out) { *ret = PythonArg(4); } else { diff --git a/src/api/operator/numpy/np_tri_op.cc b/src/api/operator/numpy/np_tri_op.cc index 972d7864493a..915c68ca4eb0 100644 --- a/src/api/operator/numpy/np_tri_op.cc +++ b/src/api/operator/numpy/np_tri_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,8 +27,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.tri") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.tri").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_tri"); nnvm::NodeAttrs attrs; @@ -39,20 +38,21 @@ MXNET_REGISTER_API("_npi.tri") } else { param.M = args[1].operator nnvm::dim_t(); } - param.k = args[2].operator int(); - param.dtype = args[3].type_code() == kNull ? mshadow::kFloat32 : - String2MXNetTypeWithBool(args[3].operator std::string()); + param.k = args[2].operator int(); + param.dtype = args[3].type_code() == kNull + ? 
mshadow::kFloat32 + : String2MXNetTypeWithBool(args[3].operator std::string()); if (args[4].type_code() != kNull) { attrs.dict["ctx"] = args[4].operator std::string(); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_tril_op.cc b/src/api/operator/numpy/np_tril_op.cc index 1acb1b8e4b10..8388797ad24a 100644 --- a/src/api/operator/numpy/np_tril_op.cc +++ b/src/api/operator/numpy/np_tril_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,8 +27,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.tril") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.tril").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_tril"); nnvm::NodeAttrs attrs; @@ -37,14 +36,14 @@ MXNET_REGISTER_API("_npi.tril") // we directly copy TrilParam, which is trivially-copyable attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); - int num_outputs = 0; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_triu_op.cc b/src/api/operator/numpy/np_triu_op.cc index e42169aca43b..8bad12e018a9 100644 --- a/src/api/operator/numpy/np_triu_op.cc +++ b/src/api/operator/numpy/np_triu_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,23 +28,22 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.triu") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.triu").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; op::TriuParam param; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npi_triu"); // inputs - param.k = args[1].operator int(); + param.k = args[1].operator int(); NDArray* inputs[] = {args[0].operator NDArray*()}; - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_unique_op.cc b/src/api/operator/numpy/np_unique_op.cc index a669025e108f..19f64d714b97 100644 --- a/src/api/operator/numpy/np_unique_op.cc +++ b/src/api/operator/numpy/np_unique_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,35 +30,35 @@ namespace mxnet { MXNET_REGISTER_API("_npi.unique") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_unique"); - nnvm::NodeAttrs attrs; - op::NumpyUniqueParam param; - // param - param.return_index = args[1].operator bool(); - param.return_inverse = args[2].operator bool(); - param.return_counts = args[3].operator bool(); - if (args[4].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[4].operator int(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - // outputs - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_unique"); + nnvm::NodeAttrs attrs; + op::NumpyUniqueParam param; + // param + param.return_index = args[1].operator bool(); + param.return_inverse = args[2].operator bool(); + param.return_counts = args[3].operator bool(); + if (args[4].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[4].operator int(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // outputs + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, 
num_inputs, inputs, &num_outputs, nullptr); + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/np_where_op.cc b/src/api/operator/numpy/np_where_op.cc index aca4e075d62d..8b458a274f6d 100644 --- a/src/api/operator/numpy/np_where_op.cc +++ b/src/api/operator/numpy/np_where_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,24 +29,21 @@ namespace mxnet { inline static bool isScalar(const runtime::MXNetArgValue& arg) { - return arg.type_code() == kDLInt || - arg.type_code() == kDLUInt || - arg.type_code() == kDLFloat; + return arg.type_code() == kDLInt || arg.type_code() == kDLUInt || arg.type_code() == kDLFloat; } -inline static void _npi_where(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_where(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_where"); nnvm::NodeAttrs attrs; - attrs.op = op; - int num_inputs = 3; - int num_outputs = 0; + attrs.op = op; + int num_inputs = 3; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*(), args[2].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } inline static void _npi_where_scalar1(runtime::MXNetArgs args, @@ -57,38 +54,36 @@ inline static void _npi_where_scalar1(runtime::MXNetArgs args, const nnvm::Op* op = isl ? Op::Get("_npi_where_lscalar") : Op::Get("_npi_where_rscalar"); op::NumpyWhereScalarParam param; param.scalar = isl ? args[1].operator double() : args[2].operator double(); - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - int num_inputs = 2; - int num_outputs = 0; - NDArray* inputs[] = - {args[0].operator mxnet::NDArray*(), - isl ? args[2].operator mxnet::NDArray*() : args[1].operator mxnet::NDArray*()}; + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = { + args[0].operator mxnet::NDArray*(), + isl ? 
args[2].operator mxnet::NDArray*() : args[1].operator mxnet::NDArray*()}; auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + *ret = reinterpret_cast(ndoutputs[0]); } -inline static void _npi_where_scalar2(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret) { +inline static void _npi_where_scalar2(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_where_scalar2"); op::NumpyWhereScalar2Param param; nnvm::NodeAttrs attrs; - param.x = args[1].operator double(); - param.y = args[2].operator double(); - attrs.op = op; + param.x = args[1].operator double(); + param.y = args[2].operator double(); + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - int num_inputs = 1; - int num_outputs = 0; + int num_inputs = 1; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); } -MXNET_REGISTER_API("_npi.where") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.where").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { if (isScalar(args[1]) && isScalar(args[2])) { _npi_where_scalar2(args, ret); } else if (!isScalar(args[1]) && !isScalar(args[2])) { diff --git a/src/api/operator/numpy/np_window_op.cc b/src/api/operator/numpy/np_window_op.cc index 41c78cb16b6d..848f5c64cbe5 100644 --- a/src/api/operator/numpy/np_window_op.cc +++ b/src/api/operator/numpy/np_window_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -46,35 +46,35 @@ inline static void SetNumpyWindowsParam(runtime::MXNetArgs args, param.dtype = String2MXNetTypeWithBool(args[1].operator std::string()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); if (args[2].type_code() != kNull) { attrs.dict["ctx"] = args[2].operator std::string(); } int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); - *ret = ndoutputs[0]; + auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr); + *ret = ndoutputs[0]; } MXNET_REGISTER_API("_npi.blackman") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_blackman"); - SetNumpyWindowsParam(args, ret, op); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_blackman"); + SetNumpyWindowsParam(args, ret, op); + }); MXNET_REGISTER_API("_npi.hamming") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_hamming"); - SetNumpyWindowsParam(args, ret, op); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_hamming"); + SetNumpyWindowsParam(args, ret, op); + }); MXNET_REGISTER_API("_npi.hanning") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_hanning"); - SetNumpyWindowsParam(args, ret, op); -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_hanning"); + SetNumpyWindowsParam(args, ret, op); + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_choice_op.cc b/src/api/operator/numpy/random/np_choice_op.cc index bc5ebbcffa58..7f64a697ecaf 100644 --- a/src/api/operator/numpy/random/np_choice_op.cc +++ b/src/api/operator/numpy/random/np_choice_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,61 +29,61 @@ namespace mxnet { MXNET_REGISTER_API("_npi.choice") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_choice"); - nnvm::NodeAttrs attrs; - op::NumpyChoiceParam param; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_choice"); + nnvm::NodeAttrs attrs; + op::NumpyChoiceParam param; - NDArray* inputs[2]; - int num_inputs = 0; + NDArray* inputs[2]; + int num_inputs = 0; - if (args[0].type_code() == kDLInt) { - param.a = args[0].operator int(); - } else if (args[0].type_code() == kNDArrayHandle) { - param.a = dmlc::nullopt; - inputs[num_inputs] = args[0].operator mxnet::NDArray*(); - num_inputs++; - } + if (args[0].type_code() == kDLInt) { + param.a = args[0].operator int(); + } else if (args[0].type_code() == kNDArrayHandle) { + param.a = dmlc::nullopt; + inputs[num_inputs] = args[0].operator mxnet::NDArray*(); + num_inputs++; + } - if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - if (args[1].type_code() == kDLInt) { - param.size = mxnet::Tuple(1, args[1].operator int64_t()); - } else { - param.size = mxnet::Tuple(args[1].operator ObjectRef()); - } - } + if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + if (args[1].type_code() == kDLInt) { + param.size = mxnet::Tuple(1, args[1].operator int64_t()); + } else { + param.size = mxnet::Tuple(args[1].operator ObjectRef()); + } + } - if (args[2].type_code() == kNull) { - param.replace = true; - } else { - param.replace = args[2].operator bool(); - } + if (args[2].type_code() == kNull) { + param.replace = true; + } else { + param.replace = args[2].operator bool(); + } - if (args[3].type_code() == kNull) { - param.weighted = false; - } else if (args[0].type_code() == kNDArrayHandle) { - param.weighted = true; - inputs[num_inputs] = args[3].operator mxnet::NDArray*(); - num_inputs++; - } + if (args[3].type_code() == kNull) { + param.weighted = false; + } else if (args[0].type_code() == kNDArrayHandle) { + param.weighted = true; + inputs[num_inputs] = args[3].operator mxnet::NDArray*(); + num_inputs++; + } - attrs.parsed = param; - attrs.op = op; - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = ndoutputs[0]; - } -}); + attrs.parsed = param; + attrs.op = op; + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_exponential_op.cc b/src/api/operator/numpy/random/np_exponential_op.cc index e95ebc8ed136..15347a0893d2 100644 --- a/src/api/operator/numpy/random/np_exponential_op.cc +++ b/src/api/operator/numpy/random/np_exponential_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,43 +29,42 @@ namespace mxnet { MXNET_REGISTER_API("_npi.exponential") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_exponential"); - op::NumpyExponentialParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[1].type_code() == kDLInt) { - param.size = Tuple(1, args[1].operator int64_t()); - } else if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[1].operator ObjectRef()); - } - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[1]; - int num_inputs = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.scale = args[0].operator double(); - num_inputs = 0; - } else { - param.scale = dmlc::nullopt; - inputs[0] = args[0].operator mxnet::NDArray*(); - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, - &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_exponential"); + op::NumpyExponentialParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[1].type_code() == kDLInt) { + param.size = Tuple(1, args[1].operator int64_t()); + } else if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[1].operator ObjectRef()); + } + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? 
nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[1]; + int num_inputs = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.scale = args[0].operator double(); + num_inputs = 0; + } else { + param.scale = dmlc::nullopt; + inputs[0] = args[0].operator mxnet::NDArray*(); + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_laplace_op.cc b/src/api/operator/numpy/random/np_laplace_op.cc index 6d8158384aa1..594b4b79413b 100644 --- a/src/api/operator/numpy/random/np_laplace_op.cc +++ b/src/api/operator/numpy/random/np_laplace_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,68 +29,68 @@ namespace mxnet { MXNET_REGISTER_API("_npi.laplace") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_laplace"); - nnvm::NodeAttrs attrs; - op::NumpyLaplaceParam param; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_laplace"); + nnvm::NodeAttrs attrs; + op::NumpyLaplaceParam param; - NDArray** inputs = new NDArray*[2](); - int num_inputs = 0; + NDArray** inputs = new NDArray*[2](); + int num_inputs = 0; - if (args[0].type_code() == kNull) { - param.loc = dmlc::nullopt; - } else if (args[0].type_code() == kNDArrayHandle) { - param.loc = dmlc::nullopt; - inputs[num_inputs] = args[0].operator mxnet::NDArray *(); - num_inputs++; - } else { - param.loc = args[0].operator double(); // convert arg to T - } + if (args[0].type_code() == kNull) { + param.loc = dmlc::nullopt; + } else if (args[0].type_code() == kNDArrayHandle) { + param.loc = dmlc::nullopt; + inputs[num_inputs] = args[0].operator mxnet::NDArray*(); + num_inputs++; + } else { + param.loc = args[0].operator double(); // convert arg to T + } - if (args[1].type_code() == kNull) { - param.scale = dmlc::nullopt; - } else if (args[1].type_code() == kNDArrayHandle) { - param.scale = dmlc::nullopt; - inputs[num_inputs] = args[1].operator mxnet::NDArray *(); - num_inputs++; - } else { - param.scale = args[1].operator double(); // convert arg to T - } + if (args[1].type_code() == kNull) { + param.scale = dmlc::nullopt; + } else if (args[1].type_code() == kNDArrayHandle) { + param.scale = dmlc::nullopt; + inputs[num_inputs] = args[1].operator mxnet::NDArray*(); + num_inputs++; + } else { + param.scale = args[1].operator double(); // convert arg to T + } - if (args[2].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - if (args[2].type_code() == kDLInt) { - param.size = mxnet::Tuple(1, args[2].operator int64_t()); - } else { - param.size = mxnet::Tuple(args[2].operator ObjectRef()); - } - } + if (args[2].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + if (args[2].type_code() == kDLInt) { + param.size = mxnet::Tuple(1, args[2].operator int64_t()); + } else { + param.size = 
mxnet::Tuple(args[2].operator ObjectRef()); + } + } - if (args[3].type_code() == kNull) { - param.dtype = mshadow::kFloat32; - } else { - param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } + if (args[3].type_code() == kNull) { + param.dtype = mshadow::kFloat32; + } else { + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } - inputs = inputs == nullptr ? nullptr : inputs; + inputs = inputs == nullptr ? nullptr : inputs; - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = ndoutputs[0]; - } -}); + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_location_scale_op.cc b/src/api/operator/numpy/random/np_location_scale_op.cc index 3f0fbb3c8f91..30785352369c 100644 --- a/src/api/operator/numpy/random/np_location_scale_op.cc +++ b/src/api/operator/numpy/random/np_location_scale_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,122 +30,120 @@ namespace mxnet { int scalar_number(const runtime::MXNetArgs& args) { - int result = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) - result++; - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) - result++; - return result; + int result = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) + result++; + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) + result++; + return result; } MXNET_REGISTER_API("_npi.gumbel") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_gumbel"); - op::NumpyLocationScaleParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[2].type_code() == kDLInt) { - param.size = Tuple(1, args[2].operator int64_t()); - } else if (args[2].type_code() == kNull) { - param.size = dmlc::optional>(); - } else { - param.size = Tuple(args[2].operator ObjectRef()); - } - if (args[3].type_code() != kNull) { - attrs.dict["ctx"] = args[3].operator std::string(); - } - NDArray* out = args[4].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - int scalar = scalar_number(args); - std::vector inputs; - int num_inputs = 0; - if (scalar == 2) { - param.loc = args[0].operator double(); - param.scale = args[1].operator double(); - } else if (scalar == 0) { - param.loc = dmlc::nullopt; - param.scale = dmlc::nullopt; - inputs.push_back(args[0].operator mxnet::NDArray*()); - inputs.push_back(args[1].operator mxnet::NDArray*()); - num_inputs = 2; - } else { - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.loc = args[0].operator double(); - param.scale = dmlc::nullopt; - inputs.push_back(args[1].operator mxnet::NDArray*()); - } else { - param.loc = dmlc::nullopt; - param.scale = args[1].operator double(); - inputs.push_back(args[0].operator mxnet::NDArray*()); - } - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), - &num_outputs, outputs); - if (out) { - *ret = PythonArg(4); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_gumbel"); + op::NumpyLocationScaleParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[2].type_code() == kDLInt) { + param.size = Tuple(1, args[2].operator int64_t()); + } else if (args[2].type_code() == kNull) { + param.size = dmlc::optional>(); + } else { + param.size = Tuple(args[2].operator ObjectRef()); + } + if (args[3].type_code() != kNull) { + attrs.dict["ctx"] = args[3].operator std::string(); + } + NDArray* out = args[4].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + int scalar = scalar_number(args); + std::vector inputs; + int num_inputs = 0; + if (scalar == 2) { + param.loc = args[0].operator double(); + param.scale = args[1].operator double(); + } else if (scalar == 0) { + param.loc = dmlc::nullopt; + param.scale = dmlc::nullopt; + inputs.push_back(args[0].operator mxnet::NDArray*()); + inputs.push_back(args[1].operator mxnet::NDArray*()); + num_inputs = 2; + } else { + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.loc = args[0].operator double(); + param.scale = dmlc::nullopt; + inputs.push_back(args[1].operator mxnet::NDArray*()); + } else { + param.loc = dmlc::nullopt; + param.scale = args[1].operator double(); + inputs.push_back(args[0].operator mxnet::NDArray*()); + } + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); + if (out) { + *ret = PythonArg(4); + } else { + *ret = ndoutputs[0]; + } + }); MXNET_REGISTER_API("_npi.logistic") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_logistic"); - op::NumpyLocationScaleParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[2].type_code() == kDLInt) { - param.size = Tuple(1, args[2].operator int64_t()); - } else if (args[2].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[2].operator ObjectRef()); - } - if (args[3].type_code() != kNull) { - attrs.dict["ctx"] = args[3].operator std::string(); - } - NDArray* out = args[4].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - int scalar = scalar_number(args); - std::vector inputs; - int num_inputs = 0; - if (scalar == 2) { - param.loc = args[0].operator double(); - param.scale = args[1].operator double(); - } else if (scalar == 0) { - param.loc = dmlc::nullopt; - param.scale = dmlc::nullopt; - inputs.push_back(args[0].operator mxnet::NDArray*()); - inputs.push_back(args[1].operator mxnet::NDArray*()); - num_inputs = 2; - } else { - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.loc = args[0].operator double(); - param.scale = dmlc::nullopt; - inputs.push_back(args[1].operator mxnet::NDArray*()); - } else { - param.loc = dmlc::nullopt; - param.scale = args[1].operator double(); - inputs.push_back(args[0].operator mxnet::NDArray*()); - } - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), - &num_outputs, outputs); - if (out) { - *ret = PythonArg(4); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_logistic"); + op::NumpyLocationScaleParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[2].type_code() == kDLInt) { + param.size = Tuple(1, args[2].operator int64_t()); + } else if (args[2].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[2].operator ObjectRef()); + } + if (args[3].type_code() != kNull) { + attrs.dict["ctx"] = args[3].operator std::string(); + } + NDArray* out = args[4].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + int scalar = scalar_number(args); + std::vector inputs; + int num_inputs = 0; + if (scalar == 2) { + param.loc = args[0].operator double(); + param.scale = args[1].operator double(); + } else if (scalar == 0) { + param.loc = dmlc::nullopt; + param.scale = dmlc::nullopt; + inputs.push_back(args[0].operator mxnet::NDArray*()); + inputs.push_back(args[1].operator mxnet::NDArray*()); + num_inputs = 2; + } else { + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.loc = args[0].operator double(); + param.scale = dmlc::nullopt; + inputs.push_back(args[1].operator mxnet::NDArray*()); + } else { + param.loc = dmlc::nullopt; + param.scale = args[1].operator double(); + inputs.push_back(args[0].operator mxnet::NDArray*()); + } + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); + if (out) { + *ret = PythonArg(4); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_multinomial_op.cc b/src/api/operator/numpy/random/np_multinomial_op.cc index 13f18bea23ff..ad4d80838b45 100644 --- a/src/api/operator/numpy/random/np_multinomial_op.cc +++ b/src/api/operator/numpy/random/np_multinomial_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,46 +30,46 @@ namespace mxnet { MXNET_REGISTER_API("_npi.multinomial") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_multinomial"); - nnvm::NodeAttrs attrs; - op::NumpyMultinomialParam param; - NDArray** inputs = new NDArray*[1](); - int num_inputs = 0; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_multinomial"); + nnvm::NodeAttrs attrs; + op::NumpyMultinomialParam param; + NDArray** inputs = new NDArray*[1](); + int num_inputs = 0; - // parse int - param.n = args[0].operator int(); + // parse int + param.n = args[0].operator int(); - // parse pvals - if (args[1].type_code() == kNull) { - param.pvals = dmlc::nullopt; - } else if (args[1].type_code() == kNDArrayHandle) { - param.pvals = dmlc::nullopt; - inputs[0] = args[1].operator mxnet::NDArray*(); - num_inputs = 1; - } else { - param.pvals = Obj2Tuple(args[1].operator ObjectRef()); - } + // parse pvals + if (args[1].type_code() == kNull) { + param.pvals = dmlc::nullopt; + } else if (args[1].type_code() == kNDArrayHandle) { + param.pvals = dmlc::nullopt; + inputs[0] = args[1].operator mxnet::NDArray*(); + num_inputs = 1; + } else { + param.pvals = Obj2Tuple(args[1].operator ObjectRef()); + } - // parse size - if (args[2].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - if (args[2].type_code() == kDLInt) { - param.size = mxnet::Tuple(1, args[2].operator int64_t()); - } else { - param.size = mxnet::Tuple(args[2].operator ObjectRef()); - } - } + // parse size + if (args[2].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + if (args[2].type_code() == kDLInt) { + param.size = mxnet::Tuple(1, args[2].operator int64_t()); + } else { + param.size = mxnet::Tuple(args[2].operator ObjectRef()); + } + } - attrs.parsed = std::move(param); - attrs.op = op; - SetAttrDict(&attrs); - inputs = num_inputs == 0 ? nullptr : inputs; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + attrs.parsed = std::move(param); + attrs.op = op; + SetAttrDict(&attrs); + inputs = num_inputs == 0 ? nullptr : inputs; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_pareto_op.cc b/src/api/operator/numpy/random/np_pareto_op.cc index 6d85a2810adf..079b4810adbf 100644 --- a/src/api/operator/numpy/random/np_pareto_op.cc +++ b/src/api/operator/numpy/random/np_pareto_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,44 +29,43 @@ namespace mxnet { MXNET_REGISTER_API("_npi.pareto") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_pareto"); - op::NumpyParetoParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[1].type_code() == kDLInt) { - param.size = Tuple(1, args[1].operator int64_t()); - } else if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[1].operator ObjectRef()); - } - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[1]; - int num_inputs = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.a = args[0].operator double(); - num_inputs = 0; - } else { - param.a = dmlc::nullopt; - inputs[0] = args[0].operator mxnet::NDArray*(); - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, - &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_pareto"); + op::NumpyParetoParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[1].type_code() == kDLInt) { + param.size = Tuple(1, args[1].operator int64_t()); + } else if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[1].operator ObjectRef()); + } + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[1]; + int num_inputs = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.a = args[0].operator double(); + num_inputs = 0; + } else { + param.a = dmlc::nullopt; + inputs[0] = args[0].operator mxnet::NDArray*(); + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_power_op.cc b/src/api/operator/numpy/random/np_power_op.cc index d532c6d4703b..8543c613e46d 100644 --- a/src/api/operator/numpy/random/np_power_op.cc +++ b/src/api/operator/numpy/random/np_power_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,44 +29,43 @@ namespace mxnet { MXNET_REGISTER_API("_npi.powerd") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_powerd"); - op::NumpyPowerParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[1].type_code() == kDLInt) { - param.size = Tuple(1, args[1].operator int64_t()); - } else if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[1].operator ObjectRef()); - } - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[1]; - int num_inputs = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.a = args[0].operator double(); - num_inputs = 0; - } else { - param.a = dmlc::nullopt; - inputs[0] = args[0].operator mxnet::NDArray*(); - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, - &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_powerd"); + op::NumpyPowerParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[1].type_code() == kDLInt) { + param.size = Tuple(1, args[1].operator int64_t()); + } else if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[1].operator ObjectRef()); + } + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[1]; + int num_inputs = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.a = args[0].operator double(); + num_inputs = 0; + } else { + param.a = dmlc::nullopt; + inputs[0] = args[0].operator mxnet::NDArray*(); + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_rayleigh_op.cc b/src/api/operator/numpy/random/np_rayleigh_op.cc index 387e2cf61d20..df1d61c40dba 100644 --- a/src/api/operator/numpy/random/np_rayleigh_op.cc +++ b/src/api/operator/numpy/random/np_rayleigh_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,44 +29,43 @@ namespace mxnet { MXNET_REGISTER_API("_npi.rayleigh") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_rayleigh"); - op::NumpyRayleighParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[1].type_code() == kDLInt) { - param.size = Tuple(1, args[1].operator int64_t()); - } else if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[1].operator ObjectRef()); - } - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[1]; - int num_inputs = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.scale = args[0].operator double(); - num_inputs = 0; - } else { - param.scale = dmlc::nullopt; - inputs[0] = args[0].operator mxnet::NDArray*(); - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, - &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_rayleigh"); + op::NumpyRayleighParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[1].type_code() == kDLInt) { + param.size = Tuple(1, args[1].operator int64_t()); + } else if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[1].operator ObjectRef()); + } + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[1]; + int num_inputs = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.scale = args[0].operator double(); + num_inputs = 0; + } else { + param.scale = dmlc::nullopt; + inputs[0] = args[0].operator mxnet::NDArray*(); + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy/random/np_weibull_op.cc b/src/api/operator/numpy/random/np_weibull_op.cc index a91b72f8cf74..3504f569f92f 100644 --- a/src/api/operator/numpy/random/np_weibull_op.cc +++ b/src/api/operator/numpy/random/np_weibull_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,44 +29,43 @@ namespace mxnet { MXNET_REGISTER_API("_npi.weibull") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_weibull"); - op::NumpyWeibullParam param; - nnvm::NodeAttrs attrs; - attrs.op = op; - if (args[1].type_code() == kDLInt) { - param.size = Tuple(1, args[1].operator int64_t()); - } else if (args[1].type_code() == kNull) { - param.size = dmlc::nullopt; - } else { - param.size = Tuple(args[1].operator ObjectRef()); - } - if (args[2].type_code() != kNull) { - attrs.dict["ctx"] = args[2].operator std::string(); - } - NDArray* out = args[3].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - NDArray* inputs[1]; - int num_inputs = 0; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - param.a = args[0].operator double(); - num_inputs = 0; - } else { - param.a = dmlc::nullopt; - inputs[0] = args[0].operator mxnet::NDArray*(); - num_inputs = 1; - } - attrs.parsed = param; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, - &num_outputs, outputs); - if (out) { - *ret = PythonArg(3); - } else { - *ret = ndoutputs[0]; - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_weibull"); + op::NumpyWeibullParam param; + nnvm::NodeAttrs attrs; + attrs.op = op; + if (args[1].type_code() == kDLInt) { + param.size = Tuple(1, args[1].operator int64_t()); + } else if (args[1].type_code() == kNull) { + param.size = dmlc::nullopt; + } else { + param.size = Tuple(args[1].operator ObjectRef()); + } + if (args[2].type_code() != kNull) { + attrs.dict["ctx"] = args[2].operator std::string(); + } + NDArray* out = args[3].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + NDArray* inputs[1]; + int num_inputs = 0; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + param.a = args[0].operator double(); + num_inputs = 0; + } else { + param.a = dmlc::nullopt; + inputs[0] = args[0].operator mxnet::NDArray*(); + num_inputs = 1; + } + attrs.parsed = param; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + if (out) { + *ret = PythonArg(3); + } else { + *ret = ndoutputs[0]; + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc index 27810fbc8ca6..32a0d6661d28 100644 --- a/src/api/operator/numpy_extension/npx_activation_op.cc +++ b/src/api/operator/numpy_extension/npx_activation_op.cc @@ -19,7 +19,8 @@ /*! 
* \file npx_activation_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_activation_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_activation_op.cc */ #include #include @@ -52,22 +53,22 @@ inline int String2MXNetActType(const std::string& s) { } MXNET_REGISTER_API("_npx.activation") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_activation"); - op::ActivationParam param; - // act_type - param.act_type = String2MXNetActType(args[1].operator std::string()); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - NDArray* inputs[] = {args[0].operator NDArray*()}; - int num_inputs = 1; - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_activation"); + op::ActivationParam param; + // act_type + param.act_type = String2MXNetActType(args[1].operator std::string()); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + NDArray* inputs[] = {args[0].operator NDArray*()}; + int num_inputs = 1; + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_arange_like_op.cc b/src/api/operator/numpy_extension/npx_arange_like_op.cc index 859e96373c86..07e37efe8145 100644 --- a/src/api/operator/numpy_extension/npx_arange_like_op.cc +++ b/src/api/operator/numpy_extension/npx_arange_like_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,7 +19,8 @@ /*! 
* \file npx_arange_like_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_arange_like_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_arange_like_op.cc */ #include #include @@ -29,50 +30,50 @@ namespace mxnet { MXNET_REGISTER_API("_npx.arange_like") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_arange_like"); - op::RangeLikeParam param; - // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - // start - if (args[1].type_code() == kNull) { - param.start = 0.0; - } else { - param.start = args[1].operator double(); - } - // step - if (args[2].type_code() == kNull) { - param.step = 1.0; - } else { - param.step = args[2].operator double(); - } - // repeat - if (args[3].type_code() == kNull) { - param.repeat = 1; - } else { - param.repeat = args[3].operator int(); - } - // ctx - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } - // axis - if (args[5].type_code() == kNull) { - param.axis = dmlc::nullopt; - } else { - param.axis = args[5].operator int(); - } - attrs.op = op; - attrs.parsed = param; - SetAttrDict(&attrs); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_arange_like"); + op::RangeLikeParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // start + if (args[1].type_code() == kNull) { + param.start = 0.0; + } else { + param.start = args[1].operator double(); + } + // step + if (args[2].type_code() == kNull) { + param.step = 1.0; + } else { + param.step = args[2].operator double(); + } + // repeat + if (args[3].type_code() == kNull) { + param.repeat = 1; + } else { + param.repeat = args[3].operator int(); + } + // ctx + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + // axis + if (args[5].type_code() == kNull) { + param.axis = dmlc::nullopt; + } else { + param.axis = args[5].operator int(); + } + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); - // outputs - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + // outputs + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_batch_dot_op.cc b/src/api/operator/numpy_extension/npx_batch_dot_op.cc index 314e8e528908..d764801859c5 100644 --- a/src/api/operator/numpy_extension/npx_batch_dot_op.cc +++ b/src/api/operator/numpy_extension/npx_batch_dot_op.cc @@ -44,42 +44,42 @@ inline int String2ForwardStype(const std::string& s) { } MXNET_REGISTER_API("_npx.batch_dot") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_batch_dot"); - op::DotParam param; - // inputs - int num_inputs = 2; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // transpose_a - if (args[2].type_code() == kNull) { - param.transpose_a = false; - } else { - param.transpose_a = args[2].operator 
bool(); - } - // transpose_b - if (args[3].type_code() == kNull) { - param.transpose_b = false; - } else { - param.transpose_b = args[3].operator bool(); - } - // forward_stype - if (args[4].type_code() == kNull) { - param.forward_stype = dmlc::nullopt; - } else { - param.forward_stype = String2ForwardStype(args[4].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_batch_dot"); + op::DotParam param; + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // transpose_a + if (args[2].type_code() == kNull) { + param.transpose_a = false; + } else { + param.transpose_a = args[2].operator bool(); + } + // transpose_b + if (args[3].type_code() == kNull) { + param.transpose_b = false; + } else { + param.transpose_b = args[3].operator bool(); + } + // forward_stype + if (args[4].type_code() == kNull) { + param.forward_stype = dmlc::nullopt; + } else { + param.forward_stype = String2ForwardStype(args[4].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_batch_norm_op.cc b/src/api/operator/numpy_extension/npx_batch_norm_op.cc index dcf3ac4f0df7..a82703d9212e 100644 --- a/src/api/operator/numpy_extension/npx_batch_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_batch_norm_op.cc @@ -19,7 +19,8 @@ /*! 
* \file npx_batch_norm_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_batch_norm_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_batch_norm_op.cc */ #include #include @@ -29,59 +30,59 @@ namespace mxnet { MXNET_REGISTER_API("_npx.batch_norm") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_batch_norm"); - op::BatchNormParam param; - // eps - param.eps = args[5].operator double(); - // momentum - param.momentum = args[6].operator double(); - // fix_gamma - param.fix_gamma = args[7].operator bool(); - // use_global_stats - param.use_global_stats = args[8].operator bool(); - // output_mean_var - param.output_mean_var = args[9].operator bool(); - // axis - param.axis = args[10].operator int(); - // cudnn_off - param.cudnn_off = args[11].operator bool(); - // min_calib_range - if (args[12].type_code() == kDLFloat || args[12].type_code() == kDLInt) { - param.min_calib_range = args[12].operator double(); - } else { - param.min_calib_range = dmlc::nullopt; - } - // max_calib_range - if (args[13].type_code() == kDLFloat || args[13].type_code() == kDLInt) { - param.max_calib_range = args[13].operator double(); - } else { - param.max_calib_range = dmlc::nullopt; - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - int num_inputs = 5; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_batch_norm"); + op::BatchNormParam param; + // eps + param.eps = args[5].operator double(); + // momentum + param.momentum = args[6].operator double(); + // fix_gamma + param.fix_gamma = args[7].operator bool(); + // use_global_stats + param.use_global_stats = args[8].operator bool(); + // output_mean_var + param.output_mean_var = args[9].operator bool(); + // axis + param.axis = args[10].operator int(); + // cudnn_off + param.cudnn_off = args[11].operator bool(); + // min_calib_range + if (args[12].type_code() == kDLFloat || args[12].type_code() == kDLInt) { + param.min_calib_range = args[12].operator double(); + } else { + param.min_calib_range = dmlc::nullopt; + } + // max_calib_range + if (args[13].type_code() == kDLFloat || args[13].type_code() == kDLInt) { + param.max_calib_range = args[13].operator double(); + } else { + param.max_calib_range = dmlc::nullopt; + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + int num_inputs = 5; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + 
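+      // When the op yields more than one output (batch_norm typically returns the
+      // normalized data plus the saved mean and variance when output_mean_var is set),
+      // the raw NDArray* results are wrapped into handles and packed into a runtime ADT
+      // below, so the frontend receives them as a single tuple-like object rather than
+      // a lone array.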
ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_broadcast_like_op.cc b/src/api/operator/numpy_extension/npx_broadcast_like_op.cc index bd882665208e..3929a516f116 100644 --- a/src/api/operator/numpy_extension/npx_broadcast_like_op.cc +++ b/src/api/operator/numpy_extension/npx_broadcast_like_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,7 +19,8 @@ /*! * \file npx_broadcast_like_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_broadcast_like_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_broadcast_like_op.cc */ #include #include @@ -29,43 +30,43 @@ namespace mxnet { MXNET_REGISTER_API("_npx.broadcast_like") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_broadcast_like"); - op::BroadcastLikeParam param; - // inputs - int num_inputs = 2; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // lhs_axes - if (args[2].type_code() == kNull) { - param.lhs_axes = dmlc::optional(); - } else if (args[2].type_code() == kDLInt) { - param.lhs_axes = TShape(1, args[2].operator int64_t()); - } else { - param.lhs_axes = mxnet::TShape(args[2].operator ObjectRef()); - } - // rhs_axes - if (args[3].type_code() == kNull) { - param.rhs_axes = dmlc::optional(); - } else if (args[3].type_code() == kDLInt) { - param.rhs_axes = TShape(1, args[3].operator int64_t()); - } else { - param.rhs_axes = mxnet::TShape(args[3].operator ObjectRef()); - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_broadcast_like"); + op::BroadcastLikeParam param; + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // lhs_axes + if (args[2].type_code() == kNull) { + param.lhs_axes = dmlc::optional(); + } else if (args[2].type_code() == kDLInt) { + param.lhs_axes = TShape(1, args[2].operator int64_t()); + } else { + param.lhs_axes = mxnet::TShape(args[2].operator ObjectRef()); + } + // rhs_axes + if (args[3].type_code() == kNull) { + param.rhs_axes = dmlc::optional(); + } else if (args[3].type_code() == kDLInt) { + param.rhs_axes = TShape(1, args[3].operator int64_t()); + } else { + param.rhs_axes = mxnet::TShape(args[3].operator ObjectRef()); + } - attrs.op = op; - attrs.parsed = param; - SetAttrDict(&attrs); + attrs.op = op; + attrs.parsed = param; + SetAttrDict(&attrs); - // outputs - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = reinterpret_cast(ndoutputs[0]); -}); + // outputs + int num_outputs = 0; + auto 
ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_control_flow_op.cc b/src/api/operator/numpy_extension/npx_control_flow_op.cc index 52001d8f7bd1..5e422381e1e1 100644 --- a/src/api/operator/numpy_extension/npx_control_flow_op.cc +++ b/src/api/operator/numpy_extension/npx_control_flow_op.cc @@ -19,7 +19,8 @@ /*! * \file npx_control_flow_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_control_flow_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_control_flow_op.cc */ #include #include @@ -30,130 +31,128 @@ namespace mxnet { MXNET_REGISTER_API("_npx.foreach") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_foreach"); - op::NPXForeachParam param; - int args_size = args.size(); - int num_inputs = args_size - 7; - // inputs - nnvm::Symbol* sym = static_cast(args[0].value().v_handle); - std::vector > subgraphs; - subgraphs.push_back(std::make_shared(*sym)); - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 1; i < num_inputs + 1; ++i) { - inputs.push_back(static_cast(args[i])); - } - - param.num_args = num_inputs; - param.num_outputs = args[1+num_inputs].operator int(); - param.num_out_data = args[2+num_inputs].operator int(); - if (args[3+num_inputs].type_code() == kDLInt) { - param.in_state_locs = mxnet::Tuple(1, args[3+num_inputs].operator int64_t()); - } else { - param.in_state_locs = mxnet::Tuple(args[3+num_inputs].operator ObjectRef()); - } - if (args[4+num_inputs].type_code() == kDLInt) { - param.in_data_locs = mxnet::Tuple(1, args[4+num_inputs].operator int64_t()); - } else { - param.in_data_locs = mxnet::Tuple(args[4+num_inputs].operator ObjectRef()); - } - if (args[5+num_inputs].type_code() == kDLInt) { - param.remain_locs = mxnet::Tuple(1, args[5+num_inputs].operator int64_t()); - } else { - param.remain_locs = mxnet::Tuple(args[5+num_inputs].operator ObjectRef()); - } - if (args[6+num_inputs].type_code() == kDLInt) { - param.in_state_index = mxnet::Tuple(1, args[6+num_inputs].operator int64_t()); - } else { - param.in_state_index = mxnet::Tuple(args[6+num_inputs].operator ObjectRef()); - } - attrs.parsed = param; - attrs.op = op; - attrs.subgraphs = subgraphs; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_foreach"); + op::NPXForeachParam param; + int args_size = args.size(); + int num_inputs = args_size - 7; + // inputs + nnvm::Symbol* sym = static_cast(args[0].value().v_handle); + std::vector > subgraphs; + subgraphs.push_back(std::make_shared(*sym)); + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 1; i < num_inputs + 1; ++i) { + inputs.push_back(static_cast(args[i])); + } + param.num_args = num_inputs; + param.num_outputs = args[1 + num_inputs].operator int(); + 
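+      // After the subgraph handle and the num_inputs data arrays, the remaining
+      // arguments carry the foreach metadata: num_outputs (parsed just above) and
+      // num_out_data, followed by the location tuples (in_state_locs, in_data_locs,
+      // remain_locs, in_state_index). Each location argument is accepted either as a
+      // single integer (kDLInt), wrapped into a one-element Tuple, or as a tuple/list
+      // ObjectRef that is converted directly.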
param.num_out_data = args[2 + num_inputs].operator int(); + if (args[3 + num_inputs].type_code() == kDLInt) { + param.in_state_locs = mxnet::Tuple(1, args[3 + num_inputs].operator int64_t()); + } else { + param.in_state_locs = mxnet::Tuple(args[3 + num_inputs].operator ObjectRef()); + } + if (args[4 + num_inputs].type_code() == kDLInt) { + param.in_data_locs = mxnet::Tuple(1, args[4 + num_inputs].operator int64_t()); + } else { + param.in_data_locs = mxnet::Tuple(args[4 + num_inputs].operator ObjectRef()); + } + if (args[5 + num_inputs].type_code() == kDLInt) { + param.remain_locs = mxnet::Tuple(1, args[5 + num_inputs].operator int64_t()); + } else { + param.remain_locs = mxnet::Tuple(args[5 + num_inputs].operator ObjectRef()); + } + if (args[6 + num_inputs].type_code() == kDLInt) { + param.in_state_index = mxnet::Tuple(1, args[6 + num_inputs].operator int64_t()); + } else { + param.in_state_index = mxnet::Tuple(args[6 + num_inputs].operator ObjectRef()); + } + attrs.parsed = param; + attrs.op = op; + attrs.subgraphs = subgraphs; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); MXNET_REGISTER_API("_npx.while_loop") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_while_loop"); - op::NPXWhileLoopParam param; - int args_size = args.size(); - int num_inputs = args_size - 8; - // inputs - std::vector > subgraphs; - subgraphs.reserve(2); - for (int i = 0; i < 2; i++) { - nnvm::Symbol* sym = static_cast(args[i].value().v_handle); - subgraphs.push_back(std::make_shared(*sym)); - } - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 2; i < num_inputs + 2; ++i) { - inputs.push_back(static_cast(args[i])); - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_while_loop"); + op::NPXWhileLoopParam param; + int args_size = args.size(); + int num_inputs = args_size - 8; + // inputs + std::vector > subgraphs; + subgraphs.reserve(2); + for (int i = 0; i < 2; i++) { + nnvm::Symbol* sym = static_cast(args[i].value().v_handle); + subgraphs.push_back(std::make_shared(*sym)); + } + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 2; i < num_inputs + 2; ++i) { + inputs.push_back(static_cast(args[i])); + } - param.num_args = num_inputs; - param.max_iterations = args[2+num_inputs].operator int(); - if (args[3+num_inputs].type_code() == kDLInt) { - param.cond_input_locs = mxnet::Tuple(1, args[3+num_inputs].operator int64_t()); - } else { - param.cond_input_locs = mxnet::Tuple(args[3+num_inputs].operator ObjectRef()); - } - if (args[4+num_inputs].type_code() == kDLInt) { - param.func_input_locs = mxnet::Tuple(1, args[4+num_inputs].operator int64_t()); - } else { - param.func_input_locs = mxnet::Tuple(args[4+num_inputs].operator ObjectRef()); - } - if (args[5+num_inputs].type_code() == kDLInt) { - param.func_var_locs = mxnet::Tuple(1, args[5+num_inputs].operator int64_t()); - } else { - param.func_var_locs = mxnet::Tuple(args[5+num_inputs].operator ObjectRef()); - } - param.num_out_data = 
args[6+num_inputs].operator int(); - param.num_outputs = args[7+num_inputs].operator int(); - attrs.parsed = param; - attrs.op = op; - attrs.subgraphs = subgraphs; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + param.num_args = num_inputs; + param.max_iterations = args[2 + num_inputs].operator int(); + if (args[3 + num_inputs].type_code() == kDLInt) { + param.cond_input_locs = mxnet::Tuple(1, args[3 + num_inputs].operator int64_t()); + } else { + param.cond_input_locs = mxnet::Tuple(args[3 + num_inputs].operator ObjectRef()); + } + if (args[4 + num_inputs].type_code() == kDLInt) { + param.func_input_locs = mxnet::Tuple(1, args[4 + num_inputs].operator int64_t()); + } else { + param.func_input_locs = mxnet::Tuple(args[4 + num_inputs].operator ObjectRef()); + } + if (args[5 + num_inputs].type_code() == kDLInt) { + param.func_var_locs = mxnet::Tuple(1, args[5 + num_inputs].operator int64_t()); + } else { + param.func_var_locs = mxnet::Tuple(args[5 + num_inputs].operator ObjectRef()); + } + param.num_out_data = args[6 + num_inputs].operator int(); + param.num_outputs = args[7 + num_inputs].operator int(); + attrs.parsed = param; + attrs.op = op; + attrs.subgraphs = subgraphs; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); -MXNET_REGISTER_API("_npx.cond") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npx.cond").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npx_cond"); op::NPXCondParam param; - int args_size = args.size(); + int args_size = args.size(); int num_inputs = args_size - 7; // inputs std::vector > subgraphs; @@ -169,28 +168,28 @@ MXNET_REGISTER_API("_npx.cond") } param.num_args = num_inputs; - if (args[3+num_inputs].type_code() == kDLInt) { - param.cond_input_locs = mxnet::Tuple(1, args[3+num_inputs].operator int64_t()); + if (args[3 + num_inputs].type_code() == kDLInt) { + param.cond_input_locs = mxnet::Tuple(1, args[3 + num_inputs].operator int64_t()); } else { - param.cond_input_locs = mxnet::Tuple(args[3+num_inputs].operator ObjectRef()); + param.cond_input_locs = mxnet::Tuple(args[3 + num_inputs].operator ObjectRef()); } - if (args[4+num_inputs].type_code() == kDLInt) { - param.then_input_locs = mxnet::Tuple(1, args[4+num_inputs].operator int64_t()); + if (args[4 + num_inputs].type_code() == kDLInt) { + param.then_input_locs = mxnet::Tuple(1, args[4 + num_inputs].operator int64_t()); } else { - param.then_input_locs = mxnet::Tuple(args[4+num_inputs].operator ObjectRef()); + param.then_input_locs = mxnet::Tuple(args[4 + num_inputs].operator ObjectRef()); } - if (args[5+num_inputs].type_code() == kDLInt) { - param.else_input_locs = mxnet::Tuple(1, args[5+num_inputs].operator 
int64_t()); + if (args[5 + num_inputs].type_code() == kDLInt) { + param.else_input_locs = mxnet::Tuple(1, args[5 + num_inputs].operator int64_t()); } else { - param.else_input_locs = mxnet::Tuple(args[5+num_inputs].operator ObjectRef()); + param.else_input_locs = mxnet::Tuple(args[5 + num_inputs].operator ObjectRef()); } - param.num_outputs = args[6+num_inputs].operator int(); - attrs.parsed = param; - attrs.op = op; - attrs.subgraphs = subgraphs; + param.num_outputs = args[6 + num_inputs].operator int(); + attrs.parsed = param; + attrs.op = op; + attrs.subgraphs = subgraphs; SetAttrDict(&attrs); int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); if (num_outputs == 1) { *ret = ndoutputs[0]; } else { diff --git a/src/api/operator/numpy_extension/npx_convolution_op.cc b/src/api/operator/numpy_extension/npx_convolution_op.cc index adb1ec379283..9174d9190032 100644 --- a/src/api/operator/numpy_extension/npx_convolution_op.cc +++ b/src/api/operator/numpy_extension/npx_convolution_op.cc @@ -19,7 +19,8 @@ /*! * \file npx_convolution_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_convolution_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_convolution_op.cc */ #include #include @@ -63,126 +64,123 @@ inline int String2CudnnTune(const std::string& s) { } MXNET_REGISTER_API("_npx.convolution") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_convolution"); - op::ConvolutionParam param; - int args_size = args.size(); - // no_bias - if (args[args_size - 4].type_code() == kNull) { - param.no_bias = false; - } else { - param.no_bias = args[args_size - 4].operator bool(); - } - // inputs - int num_inputs = param.no_bias ? 2 : 3; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // kernel - if (args[num_inputs].type_code() == kDLInt) { - param.kernel = TShape(1, args[num_inputs].operator int64_t()); - } else { - param.kernel = TShape(args[num_inputs].operator ObjectRef()); - } - // layout - if (args[num_inputs + 10].type_code() == kNull) { - param.layout = dmlc::nullopt; - } else { - param.layout = String2Layout(args[num_inputs + 10]); - } - // Check - if (param.kernel.ndim() == 1) { - param.layout = param.layout? param.layout.value() : mshadow::kNCW; - } else if (param.kernel.ndim() == 2) { - param.layout = param.layout ? param.layout.value() : mshadow::kNCHW; - } else { - CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported"; - param.layout = param.layout ? 
param.layout.value(): mshadow::kNCDHW; - } - // stride - if (args[num_inputs + 1].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.stride = Shape1(1); - } else if (param.kernel.ndim() == 2) { - param.stride = Shape2(1, 1); - } else { - param.stride = Shape3(1, 1, 1); - } - } else if (args[num_inputs + 1].type_code() == kDLInt) { - param.stride = TShape(1, args[num_inputs + 1].operator int64_t()); - } else { - param.stride = TShape(args[num_inputs + 1].operator ObjectRef()); - } - // dilate - if (args[num_inputs + 2].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.dilate = Shape1(1); - } else if (param.kernel.ndim() == 2) { - param.dilate = Shape2(1, 1); - } else { - param.dilate = Shape3(1, 1, 1); - } - } else if (args[num_inputs + 2].type_code() == kDLInt) { - param.dilate = TShape(1, args[num_inputs + 2].operator int64_t()); - } else { - param.dilate = TShape(args[num_inputs + 2].operator ObjectRef()); - } - // pad - if (args[num_inputs + 3].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.pad = Shape1(0); - } else if (param.kernel.ndim() == 2) { - param.pad = Shape2(0, 0); - } else { - param.pad = Shape3(0, 0, 0); - } - } else if (args[num_inputs + 3].type_code() == kDLInt) { - param.pad = TShape(1, args[num_inputs + 3].operator int64_t()); - } else { - param.pad = TShape(args[num_inputs + 3].operator ObjectRef()); - } - // num_filter - param.num_filter = (uint32_t) (args[num_inputs + 4].operator int()); - // num_group - param.num_group = (uint32_t) (args[num_inputs + 5].operator int()); - // workspace - param.workspace = args[num_inputs + 6].operator uint64_t(); - // cudnn_tune - if (args[num_inputs + 8].type_code() == kNull) { - param.cudnn_tune = dmlc::nullopt; - } else { - param.cudnn_tune = String2CudnnTune(args[num_inputs + 8]); - } - // cudnn_off - if (args[num_inputs + 9].type_code() == kNull) { - param.cudnn_off = false; - } else { - param.cudnn_off = args[num_inputs + 9].operator bool(); - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_convolution"); + op::ConvolutionParam param; + int args_size = args.size(); + // no_bias + if (args[args_size - 4].type_code() == kNull) { + param.no_bias = false; + } else { + param.no_bias = args[args_size - 4].operator bool(); + } + // inputs + int num_inputs = param.no_bias ? 2 : 3; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // kernel + if (args[num_inputs].type_code() == kDLInt) { + param.kernel = TShape(1, args[num_inputs].operator int64_t()); + } else { + param.kernel = TShape(args[num_inputs].operator ObjectRef()); + } + // layout + if (args[num_inputs + 10].type_code() == kNull) { + param.layout = dmlc::nullopt; + } else { + param.layout = String2Layout(args[num_inputs + 10]); + } + // Check + if (param.kernel.ndim() == 1) { + param.layout = param.layout ? param.layout.value() : mshadow::kNCW; + } else if (param.kernel.ndim() == 2) { + param.layout = param.layout ? param.layout.value() : mshadow::kNCHW; + } else { + CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported"; + param.layout = param.layout ? 
param.layout.value() : mshadow::kNCDHW; + } + // stride + if (args[num_inputs + 1].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.stride = Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.stride = Shape2(1, 1); + } else { + param.stride = Shape3(1, 1, 1); + } + } else if (args[num_inputs + 1].type_code() == kDLInt) { + param.stride = TShape(1, args[num_inputs + 1].operator int64_t()); + } else { + param.stride = TShape(args[num_inputs + 1].operator ObjectRef()); + } + // dilate + if (args[num_inputs + 2].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.dilate = Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.dilate = Shape2(1, 1); + } else { + param.dilate = Shape3(1, 1, 1); + } + } else if (args[num_inputs + 2].type_code() == kDLInt) { + param.dilate = TShape(1, args[num_inputs + 2].operator int64_t()); + } else { + param.dilate = TShape(args[num_inputs + 2].operator ObjectRef()); + } + // pad + if (args[num_inputs + 3].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.pad = Shape1(0); + } else if (param.kernel.ndim() == 2) { + param.pad = Shape2(0, 0); + } else { + param.pad = Shape3(0, 0, 0); + } + } else if (args[num_inputs + 3].type_code() == kDLInt) { + param.pad = TShape(1, args[num_inputs + 3].operator int64_t()); + } else { + param.pad = TShape(args[num_inputs + 3].operator ObjectRef()); + } + // num_filter + param.num_filter = (uint32_t)(args[num_inputs + 4].operator int()); + // num_group + param.num_group = (uint32_t)(args[num_inputs + 5].operator int()); + // workspace + param.workspace = args[num_inputs + 6].operator uint64_t(); + // cudnn_tune + if (args[num_inputs + 8].type_code() == kNull) { + param.cudnn_tune = dmlc::nullopt; + } else { + param.cudnn_tune = String2CudnnTune(args[num_inputs + 8]); + } + // cudnn_off + if (args[num_inputs + 9].type_code() == kNull) { + param.cudnn_off = false; + } else { + param.cudnn_off = args[num_inputs + 9].operator bool(); + } - CHECK_EQ(param.kernel.ndim(), param.stride.ndim()) - << "Stride must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while stride is " - << param.stride; - CHECK_EQ(param.kernel.ndim(), param.dilate.ndim()) - << "Dilate must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while dilate is " - << param.dilate; - CHECK_EQ(param.kernel.ndim(), param.pad.ndim()) - << "Padding must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while padding is " - << param.pad; + CHECK_EQ(param.kernel.ndim(), param.stride.ndim()) + << "Stride must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while stride is " << param.stride; + CHECK_EQ(param.kernel.ndim(), param.dilate.ndim()) + << "Dilate must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while dilate is " << param.dilate; + CHECK_EQ(param.kernel.ndim(), param.pad.ndim()) + << "Padding must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while padding is " << param.pad; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int 
num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_deconvolution_op.cc b/src/api/operator/numpy_extension/npx_deconvolution_op.cc index 838f4408bfa1..763d7402cfa4 100644 --- a/src/api/operator/numpy_extension/npx_deconvolution_op.cc +++ b/src/api/operator/numpy_extension/npx_deconvolution_op.cc @@ -19,7 +19,8 @@ /*! * \file npx_deconvolution_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_deconvolution_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_deconvolution_op.cc */ #include #include @@ -63,152 +64,148 @@ inline int String2CudnnTune(const std::string& s) { } MXNET_REGISTER_API("_npx.deconvolution") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_deconvolution"); - op::DeconvolutionParam param; - int args_size = args.size(); - // no_bias - if (args[args_size - 4].type_code() == kNull) { - param.no_bias = false; - } else { - param.no_bias = args[args_size - 4].operator bool(); - } - // inputs - int num_inputs = param.no_bias ? 2 : 3; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // kernel - if (args[num_inputs].type_code() == kDLInt) { - param.kernel = TShape(1, args[num_inputs].operator int64_t()); - } else { - param.kernel = TShape(args[num_inputs].operator ObjectRef()); - } - // layout - if (args[num_inputs + 12].type_code() == kNull) { - param.layout = dmlc::nullopt; - } else { - param.layout = String2Layout(args[num_inputs + 12]); - } - // Check - if (param.kernel.ndim() == 1) { - param.layout = param.layout? param.layout.value() : mshadow::kNCW; - } else if (param.kernel.ndim() == 2) { - param.layout = param.layout ? param.layout.value() : mshadow::kNCHW; - } else { - CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported"; - param.layout = param.layout ? 
param.layout.value(): mshadow::kNCDHW; - } - // stride - if (args[num_inputs + 1].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.stride = Shape1(1); - } else if (param.kernel.ndim() == 2) { - param.stride = Shape2(1, 1); - } else { - param.stride = Shape3(1, 1, 1); - } - } else if (args[num_inputs + 1].type_code() == kDLInt) { - param.stride = TShape(1, args[num_inputs + 1].operator int64_t()); - } else { - param.stride = TShape(args[num_inputs + 1].operator ObjectRef()); - } - // dilate - if (args[num_inputs + 2].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.dilate = Shape1(1); - } else if (param.kernel.ndim() == 2) { - param.dilate = Shape2(1, 1); - } else { - param.dilate = Shape3(1, 1, 1); - } - } else if (args[num_inputs + 2].type_code() == kDLInt) { - param.dilate = TShape(1, args[num_inputs + 2].operator int64_t()); - } else { - param.dilate = TShape(args[num_inputs + 2].operator ObjectRef()); - } - // pad - if (args[num_inputs + 3].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.pad = Shape1(0); - } else if (param.kernel.ndim() == 2) { - param.pad = Shape2(0, 0); - } else { - param.pad = Shape3(0, 0, 0); - } - } else if (args[num_inputs + 3].type_code() == kDLInt) { - param.pad = TShape(1, args[num_inputs + 3].operator int64_t()); - } else { - param.pad = TShape(args[num_inputs + 3].operator ObjectRef()); - } - // adj - if (args[num_inputs + 4].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.adj = Shape1(0); - } else if (param.kernel.ndim() == 2) { - param.adj = Shape2(0, 0); - } else { - param.adj = Shape3(0, 0, 0); - } - } else if (args[num_inputs + 4].type_code() == kDLInt) { - param.adj = TShape(1, args[num_inputs + 4].operator int64_t()); - } else { - param.adj = TShape(args[num_inputs + 4].operator ObjectRef()); - } - // target_shape - if (args[num_inputs + 5].type_code() != kNull) { - if (args[num_inputs + 5].type_code() == kDLInt) { - param.target_shape = TShape(1, args[num_inputs + 5].operator int64_t()); - } else { - param.target_shape = TShape(args[num_inputs + 5].operator ObjectRef()); - } - } - // num_filter - param.num_filter = (uint32_t) (args[num_inputs + 6].operator int()); - // num_group - param.num_group = (uint32_t) (args[num_inputs + 7].operator int()); - // workspace - param.workspace = args[num_inputs + 8].operator uint64_t(); - // cudnn_tune - if (args[num_inputs + 10].type_code() == kNull) { - param.cudnn_tune = dmlc::nullopt; - } else { - param.cudnn_tune = String2CudnnTune(args[num_inputs + 10]); - } - // cudnn_off - if (args[num_inputs + 11].type_code() == kNull) { - param.cudnn_off = false; - } else { - param.cudnn_off = args[num_inputs + 11].operator bool(); - } + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_deconvolution"); + op::DeconvolutionParam param; + int args_size = args.size(); + // no_bias + if (args[args_size - 4].type_code() == kNull) { + param.no_bias = false; + } else { + param.no_bias = args[args_size - 4].operator bool(); + } + // inputs + int num_inputs = param.no_bias ? 
2 : 3; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // kernel + if (args[num_inputs].type_code() == kDLInt) { + param.kernel = TShape(1, args[num_inputs].operator int64_t()); + } else { + param.kernel = TShape(args[num_inputs].operator ObjectRef()); + } + // layout + if (args[num_inputs + 12].type_code() == kNull) { + param.layout = dmlc::nullopt; + } else { + param.layout = String2Layout(args[num_inputs + 12]); + } + // Check + if (param.kernel.ndim() == 1) { + param.layout = param.layout ? param.layout.value() : mshadow::kNCW; + } else if (param.kernel.ndim() == 2) { + param.layout = param.layout ? param.layout.value() : mshadow::kNCHW; + } else { + CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() << "D convolution not supported"; + param.layout = param.layout ? param.layout.value() : mshadow::kNCDHW; + } + // stride + if (args[num_inputs + 1].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.stride = Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.stride = Shape2(1, 1); + } else { + param.stride = Shape3(1, 1, 1); + } + } else if (args[num_inputs + 1].type_code() == kDLInt) { + param.stride = TShape(1, args[num_inputs + 1].operator int64_t()); + } else { + param.stride = TShape(args[num_inputs + 1].operator ObjectRef()); + } + // dilate + if (args[num_inputs + 2].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.dilate = Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.dilate = Shape2(1, 1); + } else { + param.dilate = Shape3(1, 1, 1); + } + } else if (args[num_inputs + 2].type_code() == kDLInt) { + param.dilate = TShape(1, args[num_inputs + 2].operator int64_t()); + } else { + param.dilate = TShape(args[num_inputs + 2].operator ObjectRef()); + } + // pad + if (args[num_inputs + 3].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.pad = Shape1(0); + } else if (param.kernel.ndim() == 2) { + param.pad = Shape2(0, 0); + } else { + param.pad = Shape3(0, 0, 0); + } + } else if (args[num_inputs + 3].type_code() == kDLInt) { + param.pad = TShape(1, args[num_inputs + 3].operator int64_t()); + } else { + param.pad = TShape(args[num_inputs + 3].operator ObjectRef()); + } + // adj + if (args[num_inputs + 4].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.adj = Shape1(0); + } else if (param.kernel.ndim() == 2) { + param.adj = Shape2(0, 0); + } else { + param.adj = Shape3(0, 0, 0); + } + } else if (args[num_inputs + 4].type_code() == kDLInt) { + param.adj = TShape(1, args[num_inputs + 4].operator int64_t()); + } else { + param.adj = TShape(args[num_inputs + 4].operator ObjectRef()); + } + // target_shape + if (args[num_inputs + 5].type_code() != kNull) { + if (args[num_inputs + 5].type_code() == kDLInt) { + param.target_shape = TShape(1, args[num_inputs + 5].operator int64_t()); + } else { + param.target_shape = TShape(args[num_inputs + 5].operator ObjectRef()); + } + } + // num_filter + param.num_filter = (uint32_t)(args[num_inputs + 6].operator int()); + // num_group + param.num_group = (uint32_t)(args[num_inputs + 7].operator int()); + // workspace + param.workspace = args[num_inputs + 8].operator uint64_t(); + // cudnn_tune + if (args[num_inputs + 10].type_code() == kNull) { + param.cudnn_tune = dmlc::nullopt; + } else { + param.cudnn_tune = String2CudnnTune(args[num_inputs + 10]); + } + // cudnn_off + if (args[num_inputs + 11].type_code() == kNull) { + param.cudnn_off = false; + } else { + 
param.cudnn_off = args[num_inputs + 11].operator bool(); + } - CHECK_EQ(param.kernel.ndim(), param.stride.ndim()) - << "Stride must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while stride is " - << param.stride; - CHECK_EQ(param.kernel.ndim(), param.dilate.ndim()) - << "Dilate must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while dilate is " - << param.dilate; - CHECK_EQ(param.kernel.ndim(), param.pad.ndim()) - << "Padding must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while padding is " - << param.pad; - CHECK_EQ(param.kernel.ndim(), param.adj.ndim()) - << "Adjustment must have the same number of dimensions with kernel_size," - << "but kernel_size is set to " << param.kernel << " while adjustment is " - << param.adj; + CHECK_EQ(param.kernel.ndim(), param.stride.ndim()) + << "Stride must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while stride is " << param.stride; + CHECK_EQ(param.kernel.ndim(), param.dilate.ndim()) + << "Dilate must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while dilate is " << param.dilate; + CHECK_EQ(param.kernel.ndim(), param.pad.ndim()) + << "Padding must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while padding is " << param.pad; + CHECK_EQ(param.kernel.ndim(), param.adj.ndim()) + << "Adjustment must have the same number of dimensions with kernel_size," + << "but kernel_size is set to " << param.kernel << " while adjustment is " << param.adj; - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_dropout_op.cc b/src/api/operator/numpy_extension/npx_dropout_op.cc index e17320f30a2e..3ccc7f62fe9b 100644 --- a/src/api/operator/numpy_extension/npx_dropout_op.cc +++ b/src/api/operator/numpy_extension/npx_dropout_op.cc @@ -42,38 +42,38 @@ inline int String2Mode(const std::string& s) { } MXNET_REGISTER_API("_npx.dropout") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_dropout"); - op::DropoutParam param; - // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - // p - param.p = args[1].operator double(); - // mode - param.mode = String2Mode(args[2].operator std::string()); - // axes - if (args[3].type_code() == kNull) { - param.axes = TShape(0, 0); - } else if (args[3].type_code() == kDLInt) { - param.axes = TShape(1, args[3].operator int64_t()); - } else { - param.axes = TShape(args[3].operator ObjectRef()); - } - // cudnn_off - if (args[4].type_code() == kNull) { - param.cudnn_off = false; - } else { - param.cudnn_off = args[4].operator bool(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - 
*ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_dropout"); + op::DropoutParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // p + param.p = args[1].operator double(); + // mode + param.mode = String2Mode(args[2].operator std::string()); + // axes + if (args[3].type_code() == kNull) { + param.axes = TShape(0, 0); + } else if (args[3].type_code() == kDLInt) { + param.axes = TShape(1, args[3].operator int64_t()); + } else { + param.axes = TShape(args[3].operator ObjectRef()); + } + // cudnn_off + if (args[4].type_code() == kNull) { + param.cudnn_off = false; + } else { + param.cudnn_off = args[4].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_embedding_op.cc b/src/api/operator/numpy_extension/npx_embedding_op.cc index 58b5e3ff740f..73d47c83c441 100644 --- a/src/api/operator/numpy_extension/npx_embedding_op.cc +++ b/src/api/operator/numpy_extension/npx_embedding_op.cc @@ -29,36 +29,36 @@ namespace mxnet { MXNET_REGISTER_API("_npx.embedding") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_embedding"); - op::EmbeddingParam param; - // inputs - int num_inputs = 2; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // input_dim - param.input_dim = args[2].operator int64_t(); - // output_dim - param.output_dim = args[3].operator int64_t(); - // dtype - param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); - // sparse_grad; - if (args[5].type_code() == kNull) { - param.sparse_grad = false; - } else { - param.sparse_grad = args[5].operator bool(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_embedding"); + op::EmbeddingParam param; + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // input_dim + param.input_dim = args[2].operator int64_t(); + // output_dim + param.output_dim = args[3].operator int64_t(); + // dtype + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + // sparse_grad; + if (args[5].type_code() == kNull) { + param.sparse_grad = false; + } else { + param.sparse_grad = args[5].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_fully_connected_op.cc b/src/api/operator/numpy_extension/npx_fully_connected_op.cc index d9ab3c02c61b..892c3e0037c9 100644 --- a/src/api/operator/numpy_extension/npx_fully_connected_op.cc 
+++ b/src/api/operator/numpy_extension/npx_fully_connected_op.cc @@ -19,7 +19,8 @@ /*! * \file npx_fully_connected_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_fully_connected_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_fully_connected_op.cc */ #include #include @@ -29,38 +30,38 @@ namespace mxnet { MXNET_REGISTER_API("_npx.fully_connected") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - int args_size = args.size(); - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_fully_connected"); - op::FullyConnectedParam param; - // no_bias - param.no_bias = args[args_size - 2].operator bool(); - // inputs - int num_inputs = 2; - if (param.no_bias) { - num_inputs = 2; - } else { - num_inputs = 3; - } - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // num_hidden - param.num_hidden = args[args_size - 3].operator int(); - // flatten - param.flatten = args[args_size - 1].operator bool(); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + int args_size = args.size(); + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_fully_connected"); + op::FullyConnectedParam param; + // no_bias + param.no_bias = args[args_size - 2].operator bool(); + // inputs + int num_inputs = 2; + if (param.no_bias) { + num_inputs = 2; + } else { + num_inputs = 3; + } + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // num_hidden + param.num_hidden = args[args_size - 3].operator int(); + // flatten + param.flatten = args[args_size - 1].operator bool(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_group_norm_op.cc b/src/api/operator/numpy_extension/npx_group_norm_op.cc index aff66c999b72..473e43e20616 100644 --- a/src/api/operator/numpy_extension/npx_group_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_group_norm_op.cc @@ -19,7 +19,8 @@ /*! 
* \file npx_group_norm_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_group_norm_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_group_norm_op.cc */ #include #include @@ -29,39 +30,39 @@ namespace mxnet { MXNET_REGISTER_API("_npx.group_norm") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_group_norm"); - op::GroupNormParam param; - // num_groups - param.num_groups = args[3]; - // eps - param.eps = args[4].operator double(); - // output_mean_var - param.output_mean_var = args[5].operator bool(); - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - // inputs - int num_inputs = 3; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_group_norm"); + op::GroupNormParam param; + // num_groups + param.num_groups = args[3]; + // eps + param.eps = args[4].operator double(); + // output_mean_var + param.output_mean_var = args[5].operator bool(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + // inputs + int num_inputs = 3; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_layer_norm_op.cc b/src/api/operator/numpy_extension/npx_layer_norm_op.cc index b638088d328d..6b79a95f7237 100644 --- a/src/api/operator/numpy_extension/npx_layer_norm_op.cc +++ b/src/api/operator/numpy_extension/npx_layer_norm_op.cc @@ -19,7 +19,8 @@ /*! 
* \file npx_layer_norm_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_layer_norm_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_layer_norm_op.cc */ #include #include @@ -29,51 +30,51 @@ namespace mxnet { MXNET_REGISTER_API("_npx.layer_norm") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_layer_norm"); - op::LayerNormParam param; - // inputs - int num_inputs = 3; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // axis - if (args[3].type_code() == kNull) { - param.axis = -1; - } else { - param.axis = args[3].operator int(); - } - // eps - if (args[4].type_code() == kNull) { - param.eps = 1e-5f; - } else { - param.eps = args[4].operator double(); - } - // output_mean_var - if (args[5].type_code() == kNull) { - param.output_mean_var = false; - } else { - param.output_mean_var = args[5].operator bool(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 3; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_layer_norm"); + op::LayerNormParam param; + // inputs + int num_inputs = 3; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // axis + if (args[3].type_code() == kNull) { + param.axis = -1; + } else { + param.axis = args[3].operator int(); + } + // eps + if (args[4].type_code() == kNull) { + param.eps = 1e-5f; + } else { + param.eps = args[4].operator double(); + } + // output_mean_var + if (args[5].type_code() == kNull) { + param.output_mean_var = false; + } else { + param.output_mean_var = args[5].operator bool(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 3; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_leaky_relu_op.cc b/src/api/operator/numpy_extension/npx_leaky_relu_op.cc index 7717cf79c8ab..d4723bf46852 100644 --- a/src/api/operator/numpy_extension/npx_leaky_relu_op.cc +++ b/src/api/operator/numpy_extension/npx_leaky_relu_op.cc @@ -19,7 +19,8 @@ /*! 
* \file npx_leaky_relu_op.cc - * \brief Implementation of the API of functions in src/operator/numpy_extension/npx_leaky_relu_op.cc + * \brief Implementation of the API of functions in + * src/operator/numpy_extension/npx_leaky_relu_op.cc */ #include #include @@ -50,55 +51,55 @@ inline int String2ActType(const std::string& s) { } MXNET_REGISTER_API("_npx.leaky_relu") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_leaky_relu"); - op::LeakyReLUParam param; - int args_size = args.size(); - // act_type - param.act_type = String2ActType(args[args_size - 4].operator std::string()); - // inputs - int num_inputs = param.act_type == op::leakyrelu::kPReLU ? 2 : 1; - int num_outputs = param.act_type == op::leakyrelu::kPReLU ? 2 : 1; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // slope - if (args[args_size - 3].type_code() == kNull) { - param.slope = 0.25f; - } else { - param.slope = args[args_size - 3].operator double(); - } - // lower_bound - if (args[args_size - 2].type_code() == kNull) { - param.lower_bound = 0.125f; - } else { - param.lower_bound = args[args_size - 2].operator double(); - } - // upper_bound - if (args[args_size - 1].type_code() == kNull) { - param.upper_bound = 0.334f; - } else { - param.upper_bound = args[args_size - 1].operator double(); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_leaky_relu"); + op::LeakyReLUParam param; + int args_size = args.size(); + // act_type + param.act_type = String2ActType(args[args_size - 4].operator std::string()); + // inputs + int num_inputs = param.act_type == op::leakyrelu::kPReLU ? 2 : 1; + int num_outputs = param.act_type == op::leakyrelu::kPReLU ? 
2 : 1; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // slope + if (args[args_size - 3].type_code() == kNull) { + param.slope = 0.25f; + } else { + param.slope = args[args_size - 3].operator double(); + } + // lower_bound + if (args[args_size - 2].type_code() == kNull) { + param.lower_bound = 0.125f; + } else { + param.lower_bound = args[args_size - 2].operator double(); + } + // upper_bound + if (args[args_size - 1].type_code() == kNull) { + param.upper_bound = 0.334f; + } else { + param.upper_bound = args[args_size - 1].operator double(); + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_one_hot_op.cc b/src/api/operator/numpy_extension/npx_one_hot_op.cc index 090d56e3b22e..e8d66af0d4de 100644 --- a/src/api/operator/numpy_extension/npx_one_hot_op.cc +++ b/src/api/operator/numpy_extension/npx_one_hot_op.cc @@ -29,38 +29,38 @@ namespace mxnet { MXNET_REGISTER_API("_npx.one_hot") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_one_hot"); - op::OneHotParam param; - // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - // depth - param.depth = args[1].operator int64_t(); - // on_value - if (args[2].type_code() == kNull) { - param.on_value = 1.0; - } else { - param.on_value = args[2].operator double(); - } - // off_value - if (args[3].type_code() == kNull) { - param.off_value = 0.0; - } else { - param.off_value = args[3].operator double(); - } - // dtype - if (args[4].type_code() != kNull) { - param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - int num_outputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_one_hot"); + op::OneHotParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + // depth + param.depth = args[1].operator int64_t(); + // on_value + if (args[2].type_code() == kNull) { + param.on_value = 1.0; + } else { + param.on_value = args[2].operator double(); + } + // off_value + if (args[3].type_code() == kNull) { + param.off_value = 0.0; + } else { + param.off_value = args[3].operator double(); + } + // dtype + if (args[4].type_code() != kNull) { + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + 
} + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_pick_op.cc b/src/api/operator/numpy_extension/npx_pick_op.cc index 423a91f41cfe..22cbc84ec44a 100644 --- a/src/api/operator/numpy_extension/npx_pick_op.cc +++ b/src/api/operator/numpy_extension/npx_pick_op.cc @@ -41,8 +41,7 @@ inline int String2PickMode(const std::string& s) { return 0; } -MXNET_REGISTER_API("_npx.pick") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npx.pick").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npx_pick"); @@ -62,7 +61,7 @@ MXNET_REGISTER_API("_npx.pick") param.keepdims = args[4].operator bool(); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); // inputs int num_inputs = 2; @@ -72,8 +71,8 @@ MXNET_REGISTER_API("_npx.pick") inputs.push_back(args[i].operator mxnet::NDArray*()); } int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_pooling_op.cc b/src/api/operator/numpy_extension/npx_pooling_op.cc index 923e116f2a0f..0b743bda9909 100644 --- a/src/api/operator/numpy_extension/npx_pooling_op.cc +++ b/src/api/operator/numpy_extension/npx_pooling_op.cc @@ -82,99 +82,100 @@ inline int String2Convention(const std::string& s) { } MXNET_REGISTER_API("_npx.pooling") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - const nnvm::Op* op = Op::Get("_npx_pooling"); - op::PoolingParam param; - // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + const nnvm::Op* op = Op::Get("_npx_pooling"); + op::PoolingParam param; + // inputs + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - // kernel - if (args[1].type_code() == kDLInt) { - param.kernel = TShape(1, args[1].operator int64_t()); - } else { - param.kernel = TShape(args[1].operator ObjectRef()); - } - // global pool - param.global_pool = args[6].operator bool(); - // stride - if (args[2].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.stride = mshadow::Shape1(1); - } else if (param.kernel.ndim() == 2) { - param.stride = mshadow::Shape2(1, 1); - } else { - if (param.global_pool == false) { - CHECK_EQ(param.kernel.ndim(), 3U) << param.kernel.ndim() - << "D pooling not supported. 
Only 1D, 2D, and 3D pooling are supported."; + // kernel + if (args[1].type_code() == kDLInt) { + param.kernel = TShape(1, args[1].operator int64_t()); + } else { + param.kernel = TShape(args[1].operator ObjectRef()); + } + // global pool + param.global_pool = args[6].operator bool(); + // stride + if (args[2].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.stride = mshadow::Shape1(1); + } else if (param.kernel.ndim() == 2) { + param.stride = mshadow::Shape2(1, 1); + } else { + if (param.global_pool == false) { + CHECK_EQ(param.kernel.ndim(), 3U) + << param.kernel.ndim() + << "D pooling not supported. Only 1D, 2D, and 3D pooling are supported."; + } + param.stride = mshadow::Shape3(1, 1, 1); + } + } else if (args[2].type_code() == kDLInt) { + param.stride = TShape(1, args[2].operator int64_t()); + } else { + param.stride = TShape(args[2].operator ObjectRef()); + } + // pad + if (args[3].type_code() == kNull) { + if (param.kernel.ndim() == 1) { + param.pad = mshadow::Shape1(0); + } else if (param.kernel.ndim() == 2) { + param.pad = mshadow::Shape2(0, 0); + } else { + param.pad = mshadow::Shape3(0, 0, 0); + } + } else if (args[3].type_code() == kDLInt) { + param.pad = TShape(1, args[3].operator int64_t()); + } else { + param.pad = TShape(args[3].operator ObjectRef()); + } + // pool type + param.pool_type = String2PoolType(args[4].operator std::string()); + // pooling convention + param.pooling_convention = String2Convention(args[5].operator std::string()); + // cudnn_off + if (args[7].type_code() == kNull) { + param.cudnn_off = false; + } else { + param.cudnn_off = args[7].operator bool(); + } + // p_value + if (args[8].type_code() == kNull) { + param.p_value = dmlc::nullopt; + } else { + param.p_value = args[8].operator int(); + } + // count_include_pad + if (args[9].type_code() == kNull) { + param.count_include_pad = dmlc::nullopt; + } else { + param.count_include_pad = args[9].operator bool(); + } + // layout + if (args[10].type_code() == kNull) { + param.layout = dmlc::nullopt; + } else { + param.layout = String2PoolingLayout(args[10]); } - param.stride = mshadow::Shape3(1, 1, 1); - } - } else if (args[2].type_code() == kDLInt) { - param.stride = TShape(1, args[2].operator int64_t()); - } else { - param.stride = TShape(args[2].operator ObjectRef()); - } - // pad - if (args[3].type_code() == kNull) { - if (param.kernel.ndim() == 1) { - param.pad = mshadow::Shape1(0); - } else if (param.kernel.ndim() == 2) { - param.pad = mshadow::Shape2(0, 0); - } else { - param.pad = mshadow::Shape3(0, 0, 0); - } - } else if (args[3].type_code() == kDLInt) { - param.pad = TShape(1, args[3].operator int64_t()); - } else { - param.pad = TShape(args[3].operator ObjectRef()); - } - // pool type - param.pool_type = String2PoolType(args[4].operator std::string()); - // pooling convention - param.pooling_convention = String2Convention(args[5].operator std::string()); - // cudnn_off - if (args[7].type_code() == kNull) { - param.cudnn_off = false; - } else { - param.cudnn_off = args[7].operator bool(); - } - // p_value - if (args[8].type_code() == kNull) { - param.p_value = dmlc::nullopt; - } else { - param.p_value = args[8].operator int(); - } - // count_include_pad - if (args[9].type_code() == kNull) { - param.count_include_pad = dmlc::nullopt; - } else { - param.count_include_pad = args[9].operator bool(); - } - // layout - if (args[10].type_code() == kNull) { - param.layout = dmlc::nullopt; - } else { - param.layout = String2PoolingLayout(args[10]); - } - attrs.parsed = param; - attrs.op = 
op; - SetAttrDict(&attrs); - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - if (num_outputs == 1) { - *ret = ndoutputs[0]; - } else { - std::vector ndarray_handles; - ndarray_handles.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - ndarray_handles.emplace_back(ndoutputs[i]); - } - *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); - } -}); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + if (num_outputs == 1) { + *ret = ndoutputs[0]; + } else { + std::vector ndarray_handles; + ndarray_handles.reserve(num_outputs); + for (int i = 0; i < num_outputs; ++i) { + ndarray_handles.emplace_back(ndoutputs[i]); + } + *ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end()); + } + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_rnn_op.cc b/src/api/operator/numpy_extension/npx_rnn_op.cc index 6d94b390c4d2..7d75e13dfb5e 100644 --- a/src/api/operator/numpy_extension/npx_rnn_op.cc +++ b/src/api/operator/numpy_extension/npx_rnn_op.cc @@ -45,13 +45,12 @@ inline int String2ComputeMode(const std::string& s) { return 0; } -MXNET_REGISTER_API("_npx.rnn") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npx.rnn").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npx_rnn"); op::RNNParam param; - int args_size = args.size(); + int args_size = args.size(); int num_inputs = 0; // mode @@ -63,7 +62,8 @@ MXNET_REGISTER_API("_npx.rnn") } else { param.use_sequence_length = args[args_size - 5].operator bool(); } - if (param.use_sequence_length) num_inputs += 1; + if (param.use_sequence_length) + num_inputs += 1; // inputs std::vector inputs; inputs.reserve(num_inputs); @@ -71,9 +71,9 @@ MXNET_REGISTER_API("_npx.rnn") inputs.push_back(args[i].operator mxnet::NDArray*()); } // state_size - param.state_size = (uint32_t) (args[args_size - 11].operator int()); + param.state_size = (uint32_t)(args[args_size - 11].operator int()); // num_layers - param.num_layers = (uint32_t) (args[args_size - 10].operator int()); + param.num_layers = (uint32_t)(args[args_size - 10].operator int()); // bidirectional if (args[args_size - 9].type_code() == kNull) { param.bidirectional = false; @@ -120,11 +120,11 @@ MXNET_REGISTER_API("_npx.rnn") param.seq_length_ = 0; param.batch_size_ = 0; param.input_size_ = 0; - attrs.parsed = param; - attrs.op = op; + attrs.parsed = param; + attrs.op = op; SetAttrDict(&attrs); int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); if (num_outputs == 1) { *ret = ndoutputs[0]; } else { diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc index 6e934ed4a64f..6c8f9f438499 100644 --- a/src/api/operator/numpy_extension/npx_softmax_op.cc +++ b/src/api/operator/numpy_extension/npx_softmax_op.cc @@ -29,200 +29,200 @@ namespace mxnet { MXNET_REGISTER_API("_npx.softmax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - static const nnvm::Op* op = Op::Get("_npx_softmax"); - op::SoftmaxParam param; - int args_size = args.size(); - // inputs - int num_inputs = args_size - 
4; - std::vector<NDArray*> inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - - // parse use_length - if (args[args_size - 2].type_code() == kNull) { - param.use_length = false; - } else { - param.use_length = args[args_size - 2].operator bool(); - } - - // parse axis - if (args[args_size - 4].type_code() == kDLInt) { - param.axis = args[args_size - 4].operator int(); - } else if (args[args_size - 4].type_code() == kDLFloat) { - param.axis = static_cast<int>(args[args_size - 4].operator double()); - } else { - param.axis = -1; - } - - // parse temperature - if (args[args_size - 3].type_code() == kNull) { - param.temperature = dmlc::nullopt; - } else { - param.temperature = args[args_size - 3].operator double(); - } - - // parse dtype - if (args[args_size - 1].type_code() == kNull) { - param.dtype = dmlc::nullopt; - } else { - param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); - } - - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::SoftmaxParam>(&attrs); - - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_softmax"); + op::SoftmaxParam param; + int args_size = args.size(); + // inputs + int num_inputs = args_size - 4; + std::vector<NDArray*> inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + + // parse use_length + if (args[args_size - 2].type_code() == kNull) { + param.use_length = false; + } else { + param.use_length = args[args_size - 2].operator bool(); + } + + // parse axis + if (args[args_size - 4].type_code() == kDLInt) { + param.axis = args[args_size - 4].operator int(); + } else if (args[args_size - 4].type_code() == kDLFloat) { + param.axis = static_cast<int>(args[args_size - 4].operator double()); + } else { + param.axis = -1; + } + + // parse temperature + if (args[args_size - 3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[args_size - 3].operator double(); + } + + // parse dtype + if (args[args_size - 1].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::SoftmaxParam>(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npx.log_softmax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - static const nnvm::Op* op = Op::Get("_npx_log_softmax"); - op::SoftmaxParam param; - - int args_size = args.size(); - // inputs - int num_inputs = args_size - 4; - std::vector<NDArray*> inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - - // parse use_length - if (args[args_size - 2].type_code() == kNull) { - param.use_length = false; - } else { - param.use_length = args[args_size - 2].operator bool(); - } - - // parse axis - if (args[args_size - 4].type_code() == kDLInt) { - param.axis = args[args_size - 4].operator int(); - } else if (args[args_size - 4].type_code() == kDLFloat) { - param.axis = 
static_cast<int>(args[args_size - 4].operator double()); - } else { - param.axis = -1; - } - - // parse temperature - if (args[args_size - 3].type_code() == kNull) { - param.temperature = dmlc::nullopt; - } else { - param.temperature = args[args_size - 3].operator double(); - } - - // parse dtype - if (args[args_size - 1].type_code() == kNull) { - param.dtype = dmlc::nullopt; - } else { - param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); - } - - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::SoftmaxParam>(&attrs); - - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_log_softmax"); + op::SoftmaxParam param; + + int args_size = args.size(); + // inputs + int num_inputs = args_size - 4; + std::vector<NDArray*> inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + + // parse use_length + if (args[args_size - 2].type_code() == kNull) { + param.use_length = false; + } else { + param.use_length = args[args_size - 2].operator bool(); + } + + // parse axis + if (args[args_size - 4].type_code() == kDLInt) { + param.axis = args[args_size - 4].operator int(); + } else if (args[args_size - 4].type_code() == kDLFloat) { + param.axis = static_cast<int>(args[args_size - 4].operator double()); + } else { + param.axis = -1; + } + + // parse temperature + if (args[args_size - 3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[args_size - 3].operator double(); + } + + // parse dtype + if (args[args_size - 1].type_code() == kNull) { + param.dtype = dmlc::nullopt; + } else { + param.dtype = String2MXNetTypeWithBool(args[args_size - 1].operator std::string()); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict<op::SoftmaxParam>(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npx.masked_softmax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - static const nnvm::Op* op = Op::Get("_npx_masked_softmax"); - op::MaskedSoftmaxParam param; - - // inputs - int num_inputs = 2; - std::vector<NDArray*> inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // parse axis - if (args[2].type_code() == kDLInt) { - param.axis = args[2].operator int(); - } else if (args[2].type_code() == kDLFloat) { - param.axis = static_cast<int>(args[2].operator double()); - } else { - param.axis = -1; - } - // parse temperature - if (args[3].type_code() == kNull) { - param.temperature = dmlc::nullopt; - } else { - param.temperature = args[3].operator double(); - } - // parse normalize - if (args[4].type_code() == kNull) { - param.normalize = true; - } else { - param.normalize = args[4].operator bool(); - } - - attrs.parsed = param; - attrs.op = op; - SetAttrDict<op::MaskedSoftmaxParam>(&attrs); - - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_masked_softmax"); 
+ op::MaskedSoftmaxParam param; + + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // parse axis + if (args[2].type_code() == kDLInt) { + param.axis = args[2].operator int(); + } else if (args[2].type_code() == kDLFloat) { + param.axis = static_cast(args[2].operator double()); + } else { + param.axis = -1; + } + // parse temperature + if (args[3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[3].operator double(); + } + // parse normalize + if (args[4].type_code() == kNull) { + param.normalize = true; + } else { + param.normalize = args[4].operator bool(); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); MXNET_REGISTER_API("_npx.masked_log_softmax") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - nnvm::NodeAttrs attrs; - static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax"); - op::MaskedSoftmaxParam param; - - // inputs - int num_inputs = 2; - std::vector inputs; - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - // parse axis - if (args[2].type_code() == kDLInt) { - param.axis = args[2].operator int(); - } else if (args[2].type_code() == kDLFloat) { - param.axis = static_cast(args[2].operator double()); - } else { - param.axis = -1; - } - // parse temperature - if (args[3].type_code() == kNull) { - param.temperature = dmlc::nullopt; - } else { - param.temperature = args[3].operator double(); - } - // parse normalize - if (args[4].type_code() == kNull) { - param.normalize = true; - } else { - param.normalize = args[4].operator bool(); - } - - attrs.parsed = param; - attrs.op = op; - SetAttrDict(&attrs); - - int num_outputs = 0; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + nnvm::NodeAttrs attrs; + static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax"); + op::MaskedSoftmaxParam param; + + // inputs + int num_inputs = 2; + std::vector inputs; + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + // parse axis + if (args[2].type_code() == kDLInt) { + param.axis = args[2].operator int(); + } else if (args[2].type_code() == kDLFloat) { + param.axis = static_cast(args[2].operator double()); + } else { + param.axis = -1; + } + // parse temperature + if (args[3].type_code() == kNull) { + param.temperature = dmlc::nullopt; + } else { + param.temperature = args[3].operator double(); + } + // parse normalize + if (args[4].type_code() == kNull) { + param.normalize = true; + } else { + param.normalize = args[4].operator bool(); + } + + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + + int num_outputs = 0; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr); + *ret = ndoutputs[0]; + }); } // namespace mxnet diff --git a/src/api/operator/numpy_extension/npx_topk_op.cc b/src/api/operator/numpy_extension/npx_topk_op.cc index 6fcea5ae5591..af200f59e5f8 100644 --- a/src/api/operator/numpy_extension/npx_topk_op.cc 
+++ b/src/api/operator/numpy_extension/npx_topk_op.cc @@ -45,15 +45,14 @@ inline int String2ReturnType(const std::string& s) { return 0; } -MXNET_REGISTER_API("_npx.topk") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npx.topk").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; nnvm::NodeAttrs attrs; const nnvm::Op* op = Op::Get("_npx_topk"); op::TopKParam param; // inputs - int num_inputs = 1; - NDArray* inputs[] = {args[0].operator mxnet::NDArray *()}; + int num_inputs = 1; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; // axis if (args[1].type_code() == kNull) { param.axis = dmlc::nullopt; @@ -75,12 +74,12 @@ MXNET_REGISTER_API("_npx.topk") param.is_ascend = args[4].operator bool(); } // dtype - param.dtype = String2MXNetTypeWithBool(args[5].operator std::string()); + param.dtype = String2MXNetTypeWithBool(args[5].operator std::string()); attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); int num_outputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); if (num_outputs == 1) { *ret = ndoutputs[0]; } else { diff --git a/src/api/operator/op_utils.cc b/src/api/operator/op_utils.cc index 1cf813eb8688..2424ef67d7ab 100644 --- a/src/api/operator/op_utils.cc +++ b/src/api/operator/op_utils.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/api/operator/op_utils.h b/src/api/operator/op_utils.h index 285919cd14c4..b50a4f6a6fb8 100644 --- a/src/api/operator/op_utils.h +++ b/src/api/operator/op_utils.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/api/operator/random/np_gamma_op.cc b/src/api/operator/random/np_gamma_op.cc index 2778ff6450e6..a543e2b6c4d3 100644 --- a/src/api/operator/random/np_gamma_op.cc +++ b/src/api/operator/random/np_gamma_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,8 +29,7 @@ namespace mxnet { -MXNET_REGISTER_API("_npi.gamma") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.gamma").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_gamma"); nnvm::NodeAttrs attrs; @@ -53,7 +52,7 @@ MXNET_REGISTER_API("_npi.gamma") } } else { // 'shape' is numeric types but 'scale' is not - num_inputs = 1; + num_inputs = 1; param.scale = dmlc::nullopt; inputs.push_back(args[1].operator mxnet::NDArray*()); } @@ -70,15 +69,14 @@ MXNET_REGISTER_API("_npi.gamma") } } else { // nither 'shape' or 'scale' is numeric types - num_inputs = 2; + num_inputs = 2; param.scale = dmlc::nullopt; inputs.push_back(args[1].operator mxnet::NDArray*()); } } if (args[2].type_code() == kNull) { param.size = dmlc::optional>(); - } else if (args[2].type_code() == kDLInt || - args[2].type_code() == kDLFloat) { + } else if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) { param.size = Tuple(1, args[2].operator int64_t()); } else { param.size = Tuple(args[2].operator ObjectRef()); @@ -88,11 +86,11 @@ MXNET_REGISTER_API("_npi.gamma") } else { param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); } - NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray* out = args[5].operator mxnet::NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - attrs.parsed = param; - attrs.op = op; + int num_outputs = out != nullptr; + attrs.parsed = param; + attrs.op = op; if (args[3].type_code() != kNull) { attrs.dict["ctx"] = args[3].operator std::string(); } diff --git a/src/api/operator/random/np_normal_op.cc b/src/api/operator/random/np_normal_op.cc index 08cf4d0ec644..5fd22eed8048 100644 --- a/src/api/operator/random/np_normal_op.cc +++ b/src/api/operator/random/np_normal_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,69 +30,67 @@ namespace mxnet { MXNET_REGISTER_API("_npi.normal") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_normal"); - nnvm::NodeAttrs attrs; - op::NumpyNormalParam param; - int num_inputs = 0; - std::vector inputs; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { - // 'loc' and 'scale' are both numeric types - num_inputs = 0; - param.loc = args[0].operator double(); - param.scale = args[1].operator double(); - } else { - // 'loc' is numeric types but 'scale' is not numeric types - num_inputs = 1; - param.loc = args[0].operator double(); - param.scale = dmlc::nullopt; - inputs.push_back(args[1].operator mxnet::NDArray*()); - } - } else { - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { - // 'loc' is not numeric types but 'scale' is numeric types - num_inputs = 1; - param.loc = dmlc::nullopt; - param.scale = args[1].operator double(); - inputs.push_back(args[0].operator mxnet::NDArray*()); - } else { - // nither 'loc' or 'scale' is numeric types - num_inputs = 2; - inputs.push_back(args[0].operator mxnet::NDArray*()); - inputs.push_back(args[1].operator mxnet::NDArray*()); - } - } - if (args[2].type_code() == kNull) { - param.size = dmlc::optional>(); - } else if (args[2].type_code() == kDLInt || - args[2].type_code() == kDLFloat) { - param.size = Tuple(1, args[2].operator int64_t()); - } else { - param.size = Tuple(args[2].operator ObjectRef()); - } - if (args[4].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - if (args[3].type_code() != kNull) { - attrs.dict["ctx"] = args[3].operator std::string(); - } - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), - &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_normal"); + nnvm::NodeAttrs attrs; + op::NumpyNormalParam param; + int num_inputs = 0; + std::vector inputs; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + // 'loc' and 'scale' are both numeric types + num_inputs = 0; + param.loc = args[0].operator double(); + param.scale = args[1].operator double(); + } else { + // 'loc' is numeric types but 'scale' is not numeric types + num_inputs = 1; + param.loc = args[0].operator double(); + param.scale = dmlc::nullopt; + inputs.push_back(args[1].operator mxnet::NDArray*()); + } + } else { + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + // 'loc' is not numeric types but 'scale' is numeric types + num_inputs = 1; + param.loc = dmlc::nullopt; + param.scale = args[1].operator double(); + inputs.push_back(args[0].operator mxnet::NDArray*()); + } else { + // nither 'loc' or 'scale' is numeric types + num_inputs = 2; + inputs.push_back(args[0].operator mxnet::NDArray*()); + inputs.push_back(args[1].operator mxnet::NDArray*()); + } + } + if (args[2].type_code() == kNull) { + param.size = dmlc::optional>(); + } else if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) { + param.size = Tuple(1, args[2].operator int64_t()); + } else { + param.size = Tuple(args[2].operator ObjectRef()); + } + if (args[4].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[3].type_code() != kNull) { + attrs.dict["ctx"] = args[3].operator std::string(); + } + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/random/np_randint_op.cc b/src/api/operator/random/np_randint_op.cc index 8e05822fa907..4f6128cde038 100644 --- a/src/api/operator/random/np_randint_op.cc +++ b/src/api/operator/random/np_randint_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,39 +30,39 @@ namespace mxnet { MXNET_REGISTER_API("_npi.randint") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_random_randint"); - nnvm::NodeAttrs attrs; - op::SampleRandIntParam param; - int num_inputs = 0; - param.low = args[0].operator int(); - param.high = args[1].operator int(); - if (args[2].type_code() == kDLInt) { - param.shape = TShape(1, args[2].operator int64_t()); - } else { - param.shape = TShape(args[2].operator ObjectRef()); - } - if (args[3].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - if (args[4].type_code() != kNull) { - attrs.dict["ctx"] = args[4].operator std::string(); - } - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_random_randint"); + nnvm::NodeAttrs attrs; + op::SampleRandIntParam param; + int num_inputs = 0; + param.low = args[0].operator int(); + param.high = args[1].operator int(); + if (args[2].type_code() == kDLInt) { + param.shape = TShape(1, args[2].operator int64_t()); + } else { + param.shape = TShape(args[2].operator ObjectRef()); + } + if (args[3].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/random/np_uniform_op.cc b/src/api/operator/random/np_uniform_op.cc index cbe791ca2a72..3cb2daa720ea 100644 --- a/src/api/operator/random/np_uniform_op.cc +++ b/src/api/operator/random/np_uniform_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,68 +30,67 @@ namespace mxnet { MXNET_REGISTER_API("_npi.uniform") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_uniform"); - nnvm::NodeAttrs attrs; - op::NumpyUniformParam param; - int num_inputs = 0; - std::vector inputs; - if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { - // 'low' and 'high' are both numeric types - num_inputs = 0; - param.low = args[0].operator double(); - param.high = args[1].operator double(); - } else { - // 'low' is numeric types but 'high' is not numeric types - num_inputs = 1; - param.low = args[0].operator double(); - param.high = dmlc::nullopt; - } - } else { - if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { - // 'low' is not numeric types but 'high' is numeric types - num_inputs = 1; - param.low = dmlc::nullopt; - param.high = args[1].operator double(); - } else { - // nither 'low' or 'high' is numeric types - num_inputs = 2; - } - } - inputs.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - inputs.push_back(args[i].operator mxnet::NDArray*()); - } - if (args[2].type_code() == kNull) { - param.size = dmlc::optional>(); - } else if (args[2].type_code() == kDLInt || - args[2].type_code() == kDLFloat) { - param.size = Tuple(1, args[2].operator int64_t()); - } else { - param.size = Tuple(args[2].operator ObjectRef()); - } - if (args[4].type_code() == kNull) { - param.dtype = mxnet::common::GetDefaultDtype(); - } else { - param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); - } - attrs.parsed = param; - attrs.op = op; - if (args[3].type_code() != kNull) { - attrs.dict["ctx"] = args[3].operator std::string(); - } - NDArray* out = args[5].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - SetAttrDict(&attrs); - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); - if (out) { - *ret = PythonArg(5); - } else { - *ret = reinterpret_cast(ndoutputs[0]); - } -}); + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_uniform"); + nnvm::NodeAttrs attrs; + op::NumpyUniformParam param; + int num_inputs = 0; + std::vector inputs; + if (args[0].type_code() == kDLFloat || args[0].type_code() == kDLInt) { + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + // 'low' and 'high' are both numeric types + num_inputs = 0; + param.low = args[0].operator double(); + param.high = args[1].operator double(); + } else { + // 'low' is numeric types but 'high' is not numeric types + num_inputs = 1; + param.low = args[0].operator double(); + param.high = dmlc::nullopt; + } + } else { + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + // 'low' is not numeric types but 'high' is numeric types + num_inputs = 1; + param.low = dmlc::nullopt; + param.high = args[1].operator double(); + } else { + // nither 'low' or 'high' is numeric types + num_inputs = 2; + } + } + inputs.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + inputs.push_back(args[i].operator mxnet::NDArray*()); + } + if (args[2].type_code() == kNull) { + param.size = dmlc::optional>(); + } else if (args[2].type_code() == kDLInt || args[2].type_code() == kDLFloat) { + param.size = Tuple(1, args[2].operator int64_t()); + } else { + param.size = Tuple(args[2].operator ObjectRef()); + } + if (args[4].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[4].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[3].type_code() != kNull) { + attrs.dict["ctx"] = args[3].operator std::string(); + } + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } + }); } // namespace mxnet diff --git a/src/api/operator/random/shuffle_op.cc b/src/api/operator/random/shuffle_op.cc index 222451cb0f3b..54e59ba5e5bd 100644 --- a/src/api/operator/random/shuffle_op.cc +++ b/src/api/operator/random/shuffle_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,29 +29,29 @@ namespace mxnet { MXNET_REGISTER_API("_npi.shuffle") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_shuffle"); - nnvm::NodeAttrs attrs; - - NDArray* inputs[1]; - int num_inputs = 1; - - if (args[0].type_code() != kNull) { - inputs[0] = args[0].operator mxnet::NDArray *(); - } - - attrs.op = op; - - NDArray* out = args[1].operator mxnet::NDArray*(); - NDArray** outputs = out == nullptr ? 
nullptr : &out;
-  int num_outputs = out != nullptr;
-  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
-  if (out) {
-    *ret = PythonArg(1);
-  } else {
-    *ret = ndoutputs[0];
-  }
-});
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op = Op::Get("_npi_shuffle");
+      nnvm::NodeAttrs attrs;
+
+      NDArray* inputs[1];
+      int num_inputs = 1;
+
+      if (args[0].type_code() != kNull) {
+        inputs[0] = args[0].operator mxnet::NDArray*();
+      }
+
+      attrs.op = op;
+
+      NDArray* out = args[1].operator mxnet::NDArray*();
+      NDArray** outputs = out == nullptr ? nullptr : &out;
+      int num_outputs = out != nullptr;
+      auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs);
+      if (out) {
+        *ret = PythonArg(1);
+      } else {
+        *ret = ndoutputs[0];
+      }
+    });
 } // namespace mxnet
diff --git a/src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc b/src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc
index f25e30a8b081..14c6dbb922bf 100644
--- a/src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc
+++ b/src/api/operator/tensor/elemwise_binary_broadcast_op_extended.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  * http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,7 +19,8 @@
 /*!
  * \file elemwise_binary_broadcast_op_extended.cc
- * \brief Implementation of the API of functions in src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
+ * \brief Implementation of the API of functions in
+ * src/operator/tensor/elemwise_binary_broadcast_op_extended.cc
 */
 #include
 #include
@@ -29,19 +30,19 @@ namespace mxnet {
 MXNET_REGISTER_API("_npi.maximum")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
-  using namespace runtime;
-  const nnvm::Op* op = Op::Get("_npi_maximum");
-  const nnvm::Op* op_scalar = Op::Get("_npi_maximum_scalar");
-  UFuncHelper(args, ret, op, op_scalar, nullptr);
-});
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op = Op::Get("_npi_maximum");
+      const nnvm::Op* op_scalar = Op::Get("_npi_maximum_scalar");
+      UFuncHelper(args, ret, op, op_scalar, nullptr);
+    });
 MXNET_REGISTER_API("_npi.minimum")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
-  using namespace runtime;
-  const nnvm::Op* op = Op::Get("_npi_minimum");
-  const nnvm::Op* op_scalar = Op::Get("_npi_minimum_scalar");
-  UFuncHelper(args, ret, op, op_scalar, nullptr);
-});
+    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+      using namespace runtime;
+      const nnvm::Op* op = Op::Get("_npi_minimum");
+      const nnvm::Op* op_scalar = Op::Get("_npi_minimum_scalar");
+      UFuncHelper(args, ret, op, op_scalar, nullptr);
+    });
 } // namespace mxnet
diff --git a/src/api/operator/tensor/indexing_op.cc b/src/api/operator/tensor/indexing_op.cc
index df194018c712..bfd39aadfc34 100644
--- a/src/api/operator/tensor/indexing_op.cc
+++ b/src/api/operator/tensor/indexing_op.cc
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License. You may obtain a copy of the License at
- *
+ *
  * http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,8 +28,7 @@
 namespace mxnet {
-MXNET_REGISTER_API("_npi.take")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.take").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_take");
   nnvm::NodeAttrs attrs;
@@ -37,11 +36,11 @@ MXNET_REGISTER_API("_npi.take")
   NDArray* inputs[2];
   if (args[0].type_code() != kNull) {
-    inputs[0] = args[0].operator mxnet::NDArray *();
+    inputs[0] = args[0].operator mxnet::NDArray*();
   }
   if (args[1].type_code() != kNull) {
-    inputs[1] = args[1].operator mxnet::NDArray *();
+    inputs[1] = args[1].operator mxnet::NDArray*();
   }
   if (args[2].type_code() == kDLInt) {
@@ -60,14 +59,14 @@ MXNET_REGISTER_API("_npi.take")
   }
   attrs.parsed = param;
-  attrs.op = op;
+  attrs.op = op;
   SetAttrDict(&attrs);
-  NDArray* out = args[4].operator mxnet::NDArray*();
+  NDArray* out = args[4].operator mxnet::NDArray*();
   NDArray** outputs = out == nullptr ? nullptr : &out;
   // set the number of outputs provided by the `out` arugment
   int num_outputs = out != nullptr;
-  auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs);
+  auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, outputs);
   if (out) {
     *ret = PythonArg(4);
   } else {
diff --git a/src/api/operator/tensor/matrix_op.cc b/src/api/operator/tensor/matrix_op.cc
index 5b275d5c38a9..4b18ef15094f 100644
--- a/src/api/operator/tensor/matrix_op.cc
+++ b/src/api/operator/tensor/matrix_op.cc
@@ -28,8 +28,7 @@
 namespace mxnet {
-MXNET_REGISTER_API("_npi.clip")
-.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+MXNET_REGISTER_API("_npi.clip").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   using namespace runtime;
   const nnvm::Op* op = Op::Get("_npi_clip");
   nnvm::NodeAttrs attrs;
@@ -37,7 +36,7 @@ MXNET_REGISTER_API("_npi.clip")
   NDArray* inputs[1];
   if (args[0].type_code() != kNull) {
-    inputs[0] = args[0].operator mxnet::NDArray *();
+    inputs[0] = args[0].operator mxnet::NDArray*();
   }
   if (args[1].type_code() != kNull) {
@@ -53,14 +52,14 @@
   }
   attrs.parsed = param;
-  attrs.op = op;
+  attrs.op = op;
   SetAttrDict(&attrs);
-  NDArray* out = args[3].operator mxnet::NDArray*();
+  NDArray* out = args[3].operator mxnet::NDArray*();
   NDArray** outputs = out == nullptr ?
nullptr : &out; // set the number of outputs provided by the `out` arugment int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, outputs); if (out) { *ret = PythonArg(3); } else { @@ -68,8 +67,7 @@ MXNET_REGISTER_API("_npi.clip") } }); -MXNET_REGISTER_API("_npi.tile") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { +MXNET_REGISTER_API("_npi.tile").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_tile"); nnvm::NodeAttrs attrs; @@ -77,16 +75,16 @@ MXNET_REGISTER_API("_npi.tile") if (args[1].type_code() == kDLInt) { param.reps = Tuple(1, args[1].operator int64_t()); } else { - param.reps = Tuple(args[1].operator ObjectRef()); + param.reps = Tuple(args[1].operator ObjectRef()); } attrs.parsed = param; - attrs.op = op; + attrs.op = op; SetAttrDict(&attrs); - int num_outputs = 0; + int num_outputs = 0; NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - int num_inputs = 1; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; + int num_inputs = 1; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; }); } // namespace mxnet diff --git a/src/api/operator/ufunc_helper.cc b/src/api/operator/ufunc_helper.cc index b960267d4469..978e9d4840f7 100644 --- a/src/api/operator/ufunc_helper.cc +++ b/src/api/operator/ufunc_helper.cc @@ -28,23 +28,26 @@ namespace mxnet { -template<> +template <> void SetAttrDict(nnvm::NodeAttrs* attrs) { if (Imperative::Get()->is_recording()) { attrs->dict["scalar"] = std::to_string(::dmlc::get(attrs->parsed)); } } -void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out, - runtime::MXNetRetValue* ret, const nnvm::Op* op) { +void UFuncHelper(NDArray* lhs, + NDArray* rhs, + NDArray* out, + runtime::MXNetRetValue* ret, + const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; - attrs.op = op; + attrs.op = op; NDArray* inputs[] = {lhs, rhs}; - int num_inputs = 2; + int num_inputs = 2; NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { @@ -52,21 +55,24 @@ void UFuncHelper(NDArray* lhs, NDArray* rhs, NDArray* out, } } -void UFuncHelper(NDArray* lhs, int64_t rhs, NDArray* out, - runtime::MXNetRetValue* ret, const nnvm::Op* op) { +void UFuncHelper(NDArray* lhs, + int64_t rhs, + NDArray* out, + runtime::MXNetRetValue* ret, + const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; op::NumpyBinaryScalarParam param; param.scalar = rhs; param.is_int = true; - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - NDArray** inputs = &lhs; - int num_inputs = 1; + NDArray** inputs = &lhs; + int num_inputs = 1; NDArray** outputs = out == nullptr ? 
nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { @@ -74,21 +80,24 @@ void UFuncHelper(NDArray* lhs, int64_t rhs, NDArray* out, } } -void UFuncHelper(NDArray* lhs, double rhs, NDArray* out, - runtime::MXNetRetValue* ret, const nnvm::Op* op) { +void UFuncHelper(NDArray* lhs, + double rhs, + NDArray* out, + runtime::MXNetRetValue* ret, + const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; op::NumpyBinaryScalarParam param; param.scalar = rhs; param.is_int = false; - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - NDArray** inputs = &lhs; - int num_inputs = 1; + NDArray** inputs = &lhs; + int num_inputs = 1; NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { @@ -96,21 +105,24 @@ void UFuncHelper(NDArray* lhs, double rhs, NDArray* out, } } -void UFuncHelper(int64_t lhs, NDArray* rhs, NDArray* out, - runtime::MXNetRetValue* ret, const nnvm::Op* op) { +void UFuncHelper(int64_t lhs, + NDArray* rhs, + NDArray* out, + runtime::MXNetRetValue* ret, + const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; op::NumpyBinaryScalarParam param; param.scalar = lhs; param.is_int = true; - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - NDArray** inputs = &rhs; - int num_inputs = 1; + NDArray** inputs = &rhs; + int num_inputs = 1; NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { @@ -118,21 +130,24 @@ void UFuncHelper(int64_t lhs, NDArray* rhs, NDArray* out, } } -void UFuncHelper(double lhs, NDArray* rhs, NDArray* out, - runtime::MXNetRetValue* ret, const nnvm::Op* op) { +void UFuncHelper(double lhs, + NDArray* rhs, + NDArray* out, + runtime::MXNetRetValue* ret, + const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; op::NumpyBinaryScalarParam param; param.scalar = lhs; param.is_int = false; - attrs.op = op; + attrs.op = op; attrs.parsed = param; SetAttrDict(&attrs); - NDArray** inputs = &rhs; - int num_inputs = 1; + NDArray** inputs = &rhs; + int num_inputs = 1; NDArray** outputs = out == nullptr ? nullptr : &out; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(2); } else { @@ -156,26 +171,30 @@ void UFuncHelper(runtime::MXNetArgs args, UFuncHelper(args[0].operator NDArray*(), args[1].operator double(), out, ret, lfn_scalar); } } else if (args[0].type_code() == kDLInt) { - UFuncHelper(args[0].operator int64_t(), args[1].operator NDArray*(), out, ret, + UFuncHelper(args[0].operator int64_t(), + args[1].operator NDArray*(), + out, + ret, rfn_scalar ? 
rfn_scalar : lfn_scalar); } else { - UFuncHelper(args[0].operator double(), args[1].operator NDArray*(), out, ret, + UFuncHelper(args[0].operator double(), + args[1].operator NDArray*(), + out, + ret, rfn_scalar ? rfn_scalar : lfn_scalar); } } -void UFuncHelper(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret, - const nnvm::Op* op) { +void UFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, const nnvm::Op* op) { using namespace runtime; nnvm::NodeAttrs attrs; - attrs.op = op; + attrs.op = op; NDArray* inputs[] = {args[0].operator NDArray*()}; - NDArray* out = args[1].operator NDArray*(); + NDArray* out = args[1].operator NDArray*(); NDArray** outputs = out == nullptr ? nullptr : &out; - int num_inputs = 1; - int num_outputs = out != nullptr; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); + int num_inputs = 1; + int num_outputs = out != nullptr; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, outputs); if (outputs) { *ret = PythonArg(1); } else { diff --git a/src/api/operator/ufunc_helper.h b/src/api/operator/ufunc_helper.h index 67303200c559..848301de5365 100644 --- a/src/api/operator/ufunc_helper.h +++ b/src/api/operator/ufunc_helper.h @@ -29,9 +29,7 @@ namespace mxnet { /* * Ufunc helper for unary operators */ -void UFuncHelper(runtime::MXNetArgs args, - runtime::MXNetRetValue* ret, - const nnvm::Op* fn_array); +void UFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, const nnvm::Op* fn_array); /* * Ufunc helper for binary operators diff --git a/src/api/operator/utils.cc b/src/api/operator/utils.cc index 6cfbd27471f0..534b07dd2b7c 100644 --- a/src/api/operator/utils.cc +++ b/src/api/operator/utils.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -38,7 +38,7 @@ void SetInOut(std::vector* ndinputs, std::vector* ndoutputs, int num_inputs, NDArray** inputs, - int *num_outputs, + int* num_outputs, int infered_num_outputs, int num_visible_outputs, NDArray** out_array) { @@ -49,7 +49,7 @@ void SetInOut(std::vector* ndinputs, if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { if (shape_is_known(inp->shape())) { // Shape may be unknown after dynamic shape operators CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) - << "[SetInOut] Size of tensor you are trying to allocate is larger than " + << "[SetInOut] Size of tensor you are trying to allocate is larger than " "2^31 elements. 
Please build with flag USE_INT64_TENSOR_SIZE=1"; } } @@ -65,9 +65,8 @@ void SetInOut(std::vector* ndinputs, *num_outputs = num_visible_outputs; } else { CHECK(*num_outputs == infered_num_outputs || *num_outputs == num_visible_outputs) - << "Operator expects " << infered_num_outputs << " (all) or " - << num_visible_outputs << " (visible only) outputs, but got " - << *num_outputs << " instead."; + << "Operator expects " << infered_num_outputs << " (all) or " << num_visible_outputs + << " (visible only) outputs, but got " << *num_outputs << " instead."; for (int i = 0; i < *num_outputs; ++i) { ndoutputs->emplace_back(out_array[i]); } @@ -88,13 +87,19 @@ std::vector Invoke(const nnvm::Op* op, imperative::SetNumOutputs(op, *attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); std::vector ndinputs, ndoutputs; - SetInOut(&ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outputs); + SetInOut(&ndinputs, + &ndoutputs, + num_inputs, + inputs, + num_outputs, + infered_num_outputs, + num_visible_outputs, + outputs); if (Imperative::Get()->is_deferred_compute()) { Imperative::Get()->RecordDeferredCompute(std::move(*attrs), ndinputs, ndoutputs); } else { - for (NDArray *input : ndinputs) { + for (NDArray* input : ndinputs) { Imperative::DCInfo::Compute(*input); } auto state = Imperative::Get()->Invoke(Context::CPU(), *attrs, ndinputs, ndoutputs); @@ -102,7 +107,8 @@ std::vector Invoke(const nnvm::Op* op, Imperative::Get()->RecordOp(std::move(*attrs), ndinputs, ndoutputs, state); } } - for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i]; + for (int i = *num_outputs; i < infered_num_outputs; ++i) + delete ndoutputs[i]; return ndoutputs; } diff --git a/src/api/operator/utils.h b/src/api/operator/utils.h index 014ff15188b9..9b2085f76a43 100644 --- a/src/api/operator/utils.h +++ b/src/api/operator/utils.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -35,7 +35,7 @@ void SetInOut(std::vector* ndinputs, std::vector* ndoutputs, int num_inputs, NDArray** inputs, - int *num_outputs, + int* num_outputs, int infered_num_outputs, int num_visible_outputs, NDArray** out_array); @@ -50,14 +50,14 @@ std::vector Invoke(const nnvm::Op* op, bool is_recording(); bool is_deferred_compute(); -template +template void SetAttrDict(nnvm::NodeAttrs* attrs) { if (is_recording() || is_deferred_compute()) { ::dmlc::get(attrs->parsed).SetAttrDict(&(attrs->dict)); } } -template +template Tuple Obj2Tuple(const runtime::ObjectRef& src) { runtime::ADT adt = Downcast(src); Tuple ret(adt.size(), 0); diff --git a/src/base.cc b/src/base.cc index 96d66ad1d379..d67e130cc465 100644 --- a/src/base.cc +++ b/src/base.cc @@ -33,7 +33,7 @@ namespace mxnet { // Users that have rebuilt MXNet against older versions will we advised with a warning to upgrade // their systems to match the CI level. Minimally, users should rerun the CI locally. 
#if defined(_MSC_VER) -#define MXNET_CI_OLDEST_CUDA_VERSION 9020 +#define MXNET_CI_OLDEST_CUDA_VERSION 9020 #else #define MXNET_CI_OLDEST_CUDA_VERSION 10000 #endif @@ -82,7 +82,7 @@ void Context::CuDNNLibChecks() { << "Set MXNET_CUDNN_LIB_CHECKING=0 to quiet this warning."; if (CUDNN_VERSION < MXNET_CI_OLDEST_CUDNN_VERSION) LOG(WARNING) << "Upgrade advisory: this mxnet has been built against cuDNN lib version " - << CUDNN_VERSION << ", which is older than the oldest version tested by CI (" + << CUDNN_VERSION << ", which is older than the oldest version tested by CI (" << MXNET_CI_OLDEST_CUDNN_VERSION << "). " << "Set MXNET_CUDNN_LIB_CHECKING=0 to quiet this warning."; } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index c54cc0e6f470..65f8efa93457 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -78,28 +78,28 @@ #include #endif - using namespace mxnet; // Internal function to get the information // from function registry // Used to implement MXSymbolGetAtomicSymbolInfo and MXFuncGetInfo -template -inline int MXAPIGetFunctionRegInfo(const FunRegType *e, - const char **name, - const char **description, - uint32_t *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); - - API_BEGIN(); - *name = e->name.c_str(); +template +inline int MXAPIGetFunctionRegInfo(const FunRegType* e, + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions, + const char** return_type) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); + + API_BEGIN(); + *name = e->name.c_str(); *description = e->description.c_str(); - *num_args = static_cast(e->arguments.size()); - if (return_type) *return_type = e->return_type.c_str(); + *num_args = static_cast(e->arguments.size()); + if (return_type) + *return_type = e->return_type.c_str(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < e->arguments.size(); ++i) { ret->ret_vec_charp.push_back(e->arguments[i].name.c_str()); @@ -110,16 +110,15 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, for (size_t i = 0; i < e->arguments.size(); ++i) { ret->ret_vec_charp.push_back(e->arguments[i].description.c_str()); } - *arg_names = dmlc::BeginPtr(ret->ret_vec_charp); - *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + e->arguments.size(); + *arg_names = dmlc::BeginPtr(ret->ret_vec_charp); + *arg_type_infos = dmlc::BeginPtr(ret->ret_vec_charp) + e->arguments.size(); *arg_descriptions = dmlc::BeginPtr(ret->ret_vec_charp) + (e->arguments.size() * 2); API_END(); } // NOTE: return value is added in API_END -std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize, - mxnet::ext::msgGet_t msgGet) { +std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize, mxnet::ext::msgGet_t msgGet) { std::string str; if (msgSize() > 0) { str = "\nExtension Traceback:\n"; @@ -127,8 +126,8 @@ std::string getExtensionMsgs(mxnet::ext::msgSize_t msgSize, const char* tmp; msgGet(i, &tmp); // format: [i] message - str += std::string("\t[") + std::to_string(i) + std::string("] ") - + std::string(tmp) + std::string("\n"); + str += std::string("\t[") + std::to_string(i) + std::string("] ") + std::string(tmp) + + std::string("\n"); } } return str; @@ -192,15 +191,15 @@ void CustomFComputeDispatcher(const std::string op_name, in_dev_id.push_back(in_nd->ctx().real_dev_id()); if (inputs[i].storage_type() == 
mxnet::kRowSparseStorage) { - in_stypes[i] = 1; - in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_; + in_stypes[i] = 1; + in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_; in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size(); } else if (inputs[i].storage_type() == mxnet::kCSRStorage) { - in_stypes[i] = 2; - in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_; - in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_; + in_stypes[i] = 2; + in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_; + in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_; in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size(); - in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size(); + in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size(); } } @@ -215,53 +214,57 @@ void CustomFComputeDispatcher(const std::string op_name, out_dev_id.push_back(outputs[i].ctx().real_dev_id()); if (outputs[i].storage_type() == mxnet::kRowSparseStorage) { - out_stypes[i] = 1; - out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_; + out_stypes[i] = 1; + out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_; out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size(); } else if (outputs[i].storage_type() == mxnet::kCSRStorage) { - out_stypes[i] = 2; - out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_; - out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_; + out_stypes[i] = 2; + out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_; + out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_; out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size(); - out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size(); + out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size(); } } // get memory resource and mxnet backend streams CHECK(ctx.requested.size() >= 2) - << "Custom operator should register at least memory resource and parallel random resource"; - const Resource &resource = ctx.requested.at(0); - mshadow::Stream *cpu_stream = ctx.get_stream(); - mshadow::Stream *gpu_stream = ctx.get_stream(); + << "Custom operator should register at least memory resource and parallel random resource"; + const Resource& resource = ctx.requested.at(0); + mshadow::Stream* cpu_stream = ctx.get_stream(); + mshadow::Stream* gpu_stream = ctx.get_stream(); // create lambda that captures stream & resource objects // this temp workspace holds memory allocated by custom library via OpResource auto cpu_alloc = [&](int size) { mshadow::Tensor workspace = - resource.get_space_typed(mshadow::Shape1(size), cpu_stream); + resource.get_space_typed(mshadow::Shape1(size), cpu_stream); return workspace.dptr_; }; auto gpu_alloc = [&](int size) { mshadow::Tensor workspace = - resource.get_space_typed(mshadow::Shape1(size), gpu_stream); + resource.get_space_typed(mshadow::Shape1(size), gpu_stream); return workspace.dptr_; }; // create lambda that allocates memory for sparse and // returns allocated arrays for data, indices and indptr. 
- auto sparse_alloc = [&](int index, int indices_len, int idxptr_len, - void** data, int64_t** indices, int64_t** indptr) { + auto sparse_alloc = [&](int index, + int indices_len, + int idxptr_len, + void** data, + int64_t** indices, + int64_t** indptr) { if (idxptr_len == 0) { // Row Sparse outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)}); - *data = outputs[index].data().dptr_; + *data = outputs[index].data().dptr_; *indices = reinterpret_cast(outputs[index].aux_data(rowsparse::kIdx).dptr_); } else { // CSR outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), mshadow::Shape1(indices_len)}); - *data = outputs[index].data().dptr_; + *data = outputs[index].data().dptr_; *indices = reinterpret_cast(outputs[index].aux_data(csr::kIdx).dptr_); - *indptr = reinterpret_cast(outputs[index].aux_data(csr::kIndPtr).dptr_); + *indptr = reinterpret_cast(outputs[index].aux_data(csr::kIndPtr).dptr_); } }; @@ -277,20 +280,25 @@ void CustomFComputeDispatcher(const std::string op_name, }; using alloc_type_gpu = decltype(gpu_alloc); - auto gpu_malloc = [](void* _gpu_alloc, int size) { + auto gpu_malloc = [](void* _gpu_alloc, int size) { alloc_type_gpu* gpualloc = static_cast(_gpu_alloc); return static_cast((*gpualloc)(size)); }; using alloc_type_sparse = decltype(sparse_alloc); - auto sparse_malloc = [](void* _sparse_alloc, int index, int indices_len, int idxptr_len, - void** data, int64_t** indices, int64_t** indptr) { + auto sparse_malloc = [](void* _sparse_alloc, + int index, + int indices_len, + int idxptr_len, + void** data, + int64_t** indices, + int64_t** indptr) { alloc_type_sparse* sparsealloc = static_cast(_sparse_alloc); (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr); }; // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h - void *cuda_stream = nullptr; + void* cuda_stream = nullptr; #if MXNET_USE_CUDA if ((inputs.size() > 0 && inputs[0].ctx().dev_mask() == Context::kGPU) || (outputs.size() > 0 && outputs[0].ctx().dev_mask() == Context::kGPU)) { @@ -301,90 +309,158 @@ void CustomFComputeDispatcher(const std::string op_name, // get mxnet initialized and seeded RNG states and pass to lib_api.h void *rng_cpu_states = nullptr, *rng_gpu_states = nullptr; using mxnet::common::random::RandGenerator; - RandGenerator *pgen_cpu = ctx.requested.at(1).get_parallel_random(); - rng_cpu_states = pgen_cpu->GetStates(); + RandGenerator* pgen_cpu = ctx.requested.at(1).get_parallel_random(); + rng_cpu_states = pgen_cpu->GetStates(); #if MXNET_USE_CUDA - RandGenerator *pgen_gpu = ctx.requested.at(1).get_parallel_random(); - rng_gpu_states = pgen_gpu->GetStates(); + RandGenerator* pgen_gpu = ctx.requested.at(1).get_parallel_random(); + rng_gpu_states = pgen_gpu->GetStates(); #endif - CHECK((fcomp_fp != nullptr && state_ptr == nullptr) - || (fcomp_fp == nullptr && state_ptr != nullptr)) - << "Can only register either regular op or stateful op for '" << op_name << "'"; + CHECK((fcomp_fp != nullptr && state_ptr == nullptr) || + (fcomp_fp == nullptr && state_ptr != nullptr)) + << "Can only register either regular op or stateful op for '" << op_name << "'"; if (fcomp_fp != nullptr) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs->dict) { + for (auto& kv : attrs->dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } // call fcompute function - int retval = callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - in_shapes.data(), in_dims.data(), 
in_data.data(), in_types.data(), - in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), - out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), + int retval = callFComp(fcomp_fp, + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + in_shapes.data(), + in_dims.data(), + in_data.data(), + in_types.data(), + in_verIDs.data(), + in_dev_type.data(), + in_dev_id.data(), + in_data.size(), + out_shapes.data(), + out_dims.data(), + out_data.data(), + out_types.data(), + out_verIDs.data(), + out_dev_type.data(), + out_dev_id.data(), out_data.size(), - cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, - sparse_malloc, &sparse_alloc, in_stypes.data(), out_stypes.data(), - in_indices.data(), out_indices.data(), in_indptr.data(), + cpu_malloc, + &cpu_alloc, + gpu_malloc, + &gpu_alloc, + cuda_stream, + sparse_malloc, + &sparse_alloc, + in_stypes.data(), + out_stypes.data(), + in_indices.data(), + out_indices.data(), + in_indptr.data(), out_indptr.data(), - in_indices_shapes.data(), out_indices_shapes.data(), - in_indptr_shapes.data(), out_indptr_shapes.data(), - rng_cpu_states, rng_gpu_states); + in_indices_shapes.data(), + out_indices_shapes.data(), + in_indptr_shapes.data(), + out_indptr_shapes.data(), + rng_cpu_states, + rng_gpu_states); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling FCompute for custom operator '" << op_name << "'" << msgs; } if (state_ptr != nullptr) { // retrieve op state object created from CreateOpState - CustomStatefulOpWrapper& op = state_ptr->get_state(); + CustomStatefulOpWrapper& op = state_ptr->get_state(); CustomStatefulOp* state_op_inst = op.get_instance(); - std::string msgs = getExtensionMsgs(msgSize, msgGet); + std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(state_op_inst != nullptr) - << "Error custom stateful operator is null for operator '" << op_name << "'" << msgs; + << "Error custom stateful operator is null for operator '" << op_name << "'" << msgs; // call fcompute function - int retval = callFStatefulComp(stateful_forward_flag, state_op_inst, - in_shapes.data(), in_dims.data(), in_data.data(), + int retval = callFStatefulComp(stateful_forward_flag, + state_op_inst, + in_shapes.data(), + in_dims.data(), + in_data.data(), in_types.data(), - in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), + in_verIDs.data(), + in_dev_type.data(), + in_dev_id.data(), in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), + out_shapes.data(), + out_dims.data(), + out_data.data(), out_types.data(), - out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), + out_verIDs.data(), + out_dev_type.data(), + out_dev_id.data(), out_data.size(), - cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, cuda_stream, - sparse_malloc, &sparse_alloc, in_stypes.data(), - out_stypes.data(), in_indices.data(), out_indices.data(), - in_indptr.data(), out_indptr.data(), - in_indices_shapes.data(), out_indices_shapes.data(), - in_indptr_shapes.data(), out_indptr_shapes.data(), - rng_cpu_states, rng_gpu_states); - msgs = getExtensionMsgs(msgSize, msgGet); + cpu_malloc, + &cpu_alloc, + gpu_malloc, + &gpu_alloc, + cuda_stream, + sparse_malloc, + &sparse_alloc, + in_stypes.data(), + out_stypes.data(), + in_indices.data(), + out_indices.data(), + in_indptr.data(), + out_indptr.data(), + in_indices_shapes.data(), + out_indices_shapes.data(), + in_indptr_shapes.data(), + out_indptr_shapes.data(), + rng_cpu_states, + 
rng_gpu_states); + msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling FStatefulCompute for custom operator '" << op_name << "'" << msgs; } } -template -void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp, - RescReq resc_req, AttrParser attr_parser, NumInputs num_inputs, - NumOutputs num_outputs, NumInOuts num_inouts, InferType infer_type, - InferShape infer_shape, InferSType infer_storage_type, - MutateInputs mutate_inputs, SubgraphNumInputs num_subgraph_inputs, - SubgraphInferType infer_subgraph_type, SubgraphInferShape infer_subgraph_shape, - SubgraphInferSType infer_subgraph_storage_type, CreateOpState create_opstate, - GradReg grad_reg, mxnet::ext::mutateInputs_t mutate_fp, - const std::unordered_map &createop_map, - const std::unordered_map &forward_ctx_map, - const std::unordered_map &backward_ctx_map, + typename InferType, + typename InferShape, + typename InferSType, + typename MutateInputs, + typename SubgraphNumInputs, + typename SubgraphInferType, + typename SubgraphInferShape, + typename SubgraphInferSType, + typename CreateOpState, + typename GradReg> +void registerOp(const char* name, + const std::string& name_str, + bool isSubgraphOp, + RescReq resc_req, + AttrParser attr_parser, + NumInputs num_inputs, + NumOutputs num_outputs, + NumInOuts num_inouts, + InferType infer_type, + InferShape infer_shape, + InferSType infer_storage_type, + MutateInputs mutate_inputs, + SubgraphNumInputs num_subgraph_inputs, + SubgraphInferType infer_subgraph_type, + SubgraphInferShape infer_subgraph_shape, + SubgraphInferSType infer_subgraph_storage_type, + CreateOpState create_opstate, + GradReg grad_reg, + mxnet::ext::mutateInputs_t mutate_fp, + const std::unordered_map& createop_map, + const std::unordered_map& forward_ctx_map, + const std::unordered_map& backward_ctx_map, mxnet::ext::opCallFComp_t callFComp, mxnet::ext::opCallFStatefulComp_t callFStatefulComp, mxnet::ext::msgSize_t msgSize, @@ -392,9 +468,9 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp using namespace mxnet::ext; // check if operator is already registered - const nnvm::Op *regOpPtr = dmlc::Registry::Get()->Find(name); - nnvm::Op ®Op = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); - int plevel = 10; + const nnvm::Op* regOpPtr = dmlc::Registry::Get()->Find(name); + nnvm::Op& regOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(name); + int plevel = 10; if (regOpPtr != nullptr) { // overwrite registration of existing op with custom op regOp.arguments.clear(); @@ -420,10 +496,8 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp regOp.set_num_outputs(DefaultSubgraphOpNumOutputs); regOp.set_attr("FInferType", infer_subgraph_type, plevel); regOp.set_attr("FInferShape", infer_subgraph_shape, plevel); - regOp.set_attr("FInferStorageType", - infer_subgraph_storage_type, plevel); - regOp.set_attr("FMutateInputs", - DefaultSubgraphOpMutableInputs, plevel); + regOp.set_attr("FInferStorageType", infer_subgraph_storage_type, plevel); + regOp.set_attr("FMutateInputs", DefaultSubgraphOpMutableInputs, plevel); } // optionally add stateful forward if (createop_map.size() != 0) { @@ -433,9 +507,19 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, - callFStatefulComp, 1, &state_ptr, ctx, inputs, req, outputs, - msgSize, msgGet); + 
CustomFComputeDispatcher(name_str, + nullptr, + nullptr, + nullptr, + callFStatefulComp, + 1, + &state_ptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); }; if (createop_map.count("cpu") > 0) regOp.set_attr("FStatefulComputeEx", fstate_forward, plevel); @@ -450,13 +534,35 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp if (ctx.run_ctx.ctx.dev_mask() == Context::kCPU) { CHECK_GT(forward_ctx_map.count("cpu"), 0); fcomp_t fcomp = forward_ctx_map.at("cpu"); - CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + CustomFComputeDispatcher(name_str, + callFComp, + fcomp, + &attrs, + nullptr, + 0, + nullptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); } else if (ctx.run_ctx.ctx.dev_mask() == Context::kGPU) { CHECK_GT(forward_ctx_map.count("gpu"), 0); fcomp_t fcomp = forward_ctx_map.at("gpu"); - CustomFComputeDispatcher(name_str, callFComp, fcomp, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + CustomFComputeDispatcher(name_str, + callFComp, + fcomp, + &attrs, + nullptr, + 0, + nullptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); } }; if (forward_ctx_map.count("cpu") > 0) @@ -467,7 +573,7 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp // optionally add fgradient if user specified a function, or for stateful ops if (backward_ctx_map.size() != 0 || createop_map.size() != 0) { std::string grad_name = "_backward_" + name_str; - nnvm::Op &gradOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(grad_name); + nnvm::Op& gradOp = dmlc::Registry::Get()->__REGISTER_OR_GET__(grad_name); regOp.set_attr("FGradient", grad_reg, plevel); gradOp.set_attr("TIsBackward", true, plevel); gradOp.set_attr("FInferStorageType", infer_storage_type, plevel); @@ -499,44 +605,78 @@ void registerOp(const char* name, const std::string& name_str, bool isSubgraphOp const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CustomFComputeDispatcher(name_str, nullptr, nullptr, nullptr, - callFStatefulComp, 0, &state_ptr, ctx, inputs, req, outputs, - msgSize, msgGet); + CustomFComputeDispatcher(name_str, + nullptr, + nullptr, + nullptr, + callFStatefulComp, + 0, + &state_ptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); }; gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); gradOp.set_attr("FStatefulComputeEx", fstate_backward, plevel); } else { // for stateless operators if (backward_ctx_map.count("cpu") > 0) { - fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu"); + fcomp_t fcomp_back_cpu = backward_ctx_map.at("cpu"); auto backward_cpu_lambda = [=](const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CustomFComputeDispatcher(name_str, callFComp, fcomp_back_cpu, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + CustomFComputeDispatcher(name_str, + callFComp, + fcomp_back_cpu, + &attrs, + nullptr, + 0, + nullptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); }; gradOp.set_attr("FComputeEx", backward_cpu_lambda, plevel); } if (backward_ctx_map.count("gpu") > 0) { - fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu"); + fcomp_t fcomp_back_gpu = backward_ctx_map.at("gpu"); auto backward_gpu_lambda = [=](const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - 
CustomFComputeDispatcher(name_str, callFComp, fcomp_back_gpu, &attrs, - nullptr, 0, nullptr, ctx, inputs, req, outputs, msgSize, msgGet); + CustomFComputeDispatcher(name_str, + callFComp, + fcomp_back_gpu, + &attrs, + nullptr, + 0, + nullptr, + ctx, + inputs, + req, + outputs, + msgSize, + msgGet); }; gradOp.set_attr("FComputeEx", backward_gpu_lambda, plevel); } } - } + } regOp.add_argument("data", "NDArray[]", "Source inputs"); } -void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, +void registerOperators(void* lib, + int verbose, + mxnet::ext::msgSize_t msgSize, mxnet::ext::msgGet_t msgGet) { using namespace mxnet::ext; @@ -544,36 +684,36 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); opCallParseAttrs_t callParseAttrs = - get_func(lib, const_cast(MXLIB_OPCALLPARSEATTRS_STR)); + get_func(lib, const_cast(MXLIB_OPCALLPARSEATTRS_STR)); opCallInferShape_t callInferShape = - get_func(lib, const_cast(MXLIB_OPCALLINFERSHAPE_STR)); + get_func(lib, const_cast(MXLIB_OPCALLINFERSHAPE_STR)); opCallInferType_t callInferType = - get_func(lib, const_cast(MXLIB_OPCALLINFERTYPE_STR)); + get_func(lib, const_cast(MXLIB_OPCALLINFERTYPE_STR)); opCallInferSType_t callInferSType = - get_func(lib, const_cast(MXLIB_OPCALLINFERSTYPE_STR)); + get_func(lib, const_cast(MXLIB_OPCALLINFERSTYPE_STR)); - opCallFComp_t callFComp = - get_func(lib, const_cast(MXLIB_OPCALLFCOMP_STR)); + opCallFComp_t callFComp = get_func(lib, const_cast(MXLIB_OPCALLFCOMP_STR)); opCallMutateInputs_t callMutateInputs = - get_func(lib, const_cast(MXLIB_OPCALLMUTATEINPUTS_STR)); + get_func(lib, const_cast(MXLIB_OPCALLMUTATEINPUTS_STR)); opCallCreateOpState_t callCreateOpState = - get_func(lib, const_cast(MXLIB_OPCALLCREATEOPSTATE_STR)); + get_func(lib, const_cast(MXLIB_OPCALLCREATEOPSTATE_STR)); opCallDestroyOpState_t callDestroyOpState = - get_func(lib, const_cast(MXLIB_OPCALLDESTROYOPSTATE_STR)); + get_func(lib, const_cast(MXLIB_OPCALLDESTROYOPSTATE_STR)); opCallFStatefulComp_t callFStatefulComp = - get_func(lib, const_cast(MXLIB_OPCALLFSTATEFULCOMP_STR)); + get_func(lib, const_cast(MXLIB_OPCALLFSTATEFULCOMP_STR)); // get number of operators registered in the library opRegSize_t opRegSize = get_func(lib, const_cast(MXLIB_OPREGSIZE_STR)); - int numOps = opRegSize(); - if (verbose) LOG(INFO) << "Found " << numOps << " operators in library"; + int numOps = opRegSize(); + if (verbose) + LOG(INFO) << "Found " << numOps << " operators in library"; /* * Get all custom operators implementation from custom library @@ -584,39 +724,51 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, const char* name; // function pointers holding implementation from custom library parseAttrs_t parse_fp = nullptr; - inferType_t type_fp = nullptr; + inferType_t type_fp = nullptr; inferSType_t stype_fp = nullptr; inferShape_t shape_fp = nullptr; // optional attributes mutateInputs_t mutate_fp = nullptr; - bool isSubgraphOp = false; - int _isSubgraphOp = 0; + bool isSubgraphOp = false; + int _isSubgraphOp = 0; // lists of forward and backward function associated with each context const char **forward_ctx, **backward_ctx, **createop_ctx; fcomp_t *forward_fcomp, *backward_fcomp; - createOpState_t *createop_fp; + createOpState_t* createop_fp; int forward_count, backward_count, createop_count; // main function to get custom operator implemenation from the custom library - opRegGet(i, &name, &_isSubgraphOp, - 
&forward_ctx, &forward_fcomp, &forward_count, - &backward_ctx, &backward_fcomp, &backward_count, - &createop_ctx, &createop_fp, &createop_count, - &parse_fp, &type_fp, &stype_fp, &shape_fp, &mutate_fp); + opRegGet(i, + &name, + &_isSubgraphOp, + &forward_ctx, + &forward_fcomp, + &forward_count, + &backward_ctx, + &backward_fcomp, + &backward_count, + &createop_ctx, + &createop_fp, + &createop_count, + &parse_fp, + &type_fp, + &stype_fp, + &shape_fp, + &mutate_fp); // construct maps of context to forward/backward custom library function std::unordered_map forward_ctx_map; std::unordered_map backward_ctx_map; std::unordered_map createop_map; - for (int i=0; i < forward_count; i++) { + for (int i = 0; i < forward_count; i++) { std::string ctx_str(forward_ctx[i]); forward_ctx_map[ctx_str] = forward_fcomp[i]; } - for (int i=0; i < backward_count; i++) { + for (int i = 0; i < backward_count; i++) { std::string ctx_str(backward_ctx[i]); backward_ctx_map[ctx_str] = backward_fcomp[i]; } - for (int i=0; i < createop_count; i++) { + for (int i = 0; i < createop_count; i++) { std::string ctx_str(createop_ctx[i]); createop_map[ctx_str] = createop_fp[i]; } @@ -628,18 +780,21 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, CHECK(parse_fp != nullptr) << "Error loading '" << name << "' custom op, ParseAttrs function was not set."; CHECK(forward_ctx_map.size() != 0 || createop_map.size() != 0) - << "Error loading '" << name - << "' custom op, Forward or CreateOpState function was not set."; + << "Error loading '" << name + << "' custom op, Forward or CreateOpState function was not set."; CHECK(type_fp != nullptr) << "Error loading '" << name - << "' custom op, InferType function was not set."; + << "' custom op, InferType function was not set."; CHECK(shape_fp != nullptr) << "Error loading '" << name - << "' custom op, InferShape function was not set."; + << "' custom op, InferShape function was not set."; } else { - CHECK(createop_map.size() != 0) << "Error loading '" << name - << "' custom subgraph op, CreateOpState function was not set."; + CHECK(createop_map.size() != 0) + << "Error loading '" << name + << "' custom subgraph op, CreateOpState function was not set."; } - if (verbose) LOG(INFO) << "\tOp[" << i << "] " << name; - if (verbose && isSubgraphOp) LOG(INFO) << "\t\tisSubgraphOp"; + if (verbose) + LOG(INFO) << "\tOp[" << i << "] " << name; + if (verbose && isSubgraphOp) + LOG(INFO) << "\t\tisSubgraphOp"; std::string name_str(name); /* @@ -652,7 +807,7 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, auto attr_parser = [=](const NodeAttrs* attrs) { // convert attributes to vector of char std::vector attr_keys, attr_vals; - for (auto &kv : attrs->dict) { + for (auto& kv : attrs->dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } @@ -660,16 +815,16 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, std::string subgraph_json; if (!attrs->subgraphs.empty()) { nnvm::Graph g; - g.outputs = attrs->subgraphs[0].get()->outputs; + g.outputs = attrs->subgraphs[0].get()->outputs; subgraph_json = nnvm::pass::SaveJSON(g); attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON); attr_vals.push_back(subgraph_json.c_str()); } - int num_in = -1; + int num_in = -1; int num_out = -1; - int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out); + int retval = callParseAttrs( + parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), 
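// Illustrative sketch, not part of this patch: the context->function maps built from
// opRegGet's parallel arrays (forward_ctx/forward_fcomp, backward_ctx/backward_fcomp,
// createop_ctx/createop_fp) all follow the same pattern. A stand-alone helper for that
// pattern could look like the following; the name buildCtxMap and the generic FnPtr
// parameter are invented for this example only.
#include <string>
#include <unordered_map>

template <typename FnPtr>
std::unordered_map<std::string, FnPtr> buildCtxMap(const char** ctxs, FnPtr* fns, int count) {
  std::unordered_map<std::string, FnPtr> map;
  for (int i = 0; i < count; ++i)
    map.emplace(std::string(ctxs[i]), fns[i]);  // e.g. "cpu" -> forward_fcomp[i]
  return map;
}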
&num_in, &num_out); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling ParseAttrs for custom operator '" << name_str << "'" << msgs; @@ -680,18 +835,18 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, auto num_inputs = [=](const NodeAttrs& attrs) { // convert attributes to vector of char std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } - int num_in = -1; + int num_in = -1; int num_out = -1; - int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out); + int retval = callParseAttrs( + parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling ParseAttrs::num_inputs for custom operator '" << name_str - << "'" << msgs; + << "'" << msgs; // get extra inputs, if exists int extra_inputs = 0; @@ -718,18 +873,18 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, auto num_outputs = [=](const NodeAttrs& attrs) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } - int num_in = -1; + int num_in = -1; int num_out = -1; - int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out); + int retval = callParseAttrs( + parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str - << "'" << msgs; + << "'" << msgs; return num_out; }; @@ -739,18 +894,18 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, auto num_inouts = [=](const NodeAttrs& attrs) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } - int num_in = -1; + int num_in = -1; int num_out = -1; - int retval = callParseAttrs(parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &num_in, &num_out); + int retval = callParseAttrs( + parse_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), &num_in, &num_out); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling ParseAttrs::num_outputs for custom operator '" << name_str - << "'" << msgs; + << "'" << msgs; // for backward passes, inputs + outputs + input gradients (one for each output) // get extra inputs, if exists @@ -762,12 +917,12 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, }; // lambda function to call infer shape - auto infer_shape = [=] (const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { + auto infer_shape = [=](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } @@ -788,10 +943,10 @@ void registerOperators(void *lib, int verbose, 
mxnet::ext::msgSize_t msgSize, // copy input shapes from ShapeVector to raw memory layout std::vector inbuff(buff_size); - uint32_t *ptr = inbuff.data(); + uint32_t* ptr = inbuff.data(); for (size_t i = 0; i < num_inputs; ++i) { inshapes[i] = ptr; - indims[i] = (*in_shape)[i].ndim(); + indims[i] = (*in_shape)[i].ndim(); for (int j = 0; j < (*in_shape)[i].ndim(); ++j, ++ptr) { *ptr = static_cast((*in_shape)[i][j]); } @@ -799,15 +954,23 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // modified input shapes will be allocated by infer shape function uint32_t** mod_inshapes = nullptr; - int* mod_indims = nullptr; + int* mod_indims = nullptr; // output shapes will be allocated by infer shape function uint32_t** outshapes = nullptr; - int* outdims = nullptr; - - int retval = callInferShape(shape_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - inshapes.data(), indims.data(), num_inputs, - &mod_inshapes, &mod_indims, - &outshapes, &outdims, out_shape->size()); + int* outdims = nullptr; + + int retval = callInferShape(shape_fp, + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + inshapes.data(), + indims.data(), + num_inputs, + &mod_inshapes, + &mod_indims, + &outshapes, + &outdims, + out_shape->size()); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling InferShape for custom operator '" << name_str << "'" << msgs; @@ -830,8 +993,7 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // assign modified input shapes to ShapeVector for (unsigned i = 0; i < num_inputs; ++i) { - SHAPE_ASSIGN_CHECK(*in_shape, i, - mxnet::TShape(in_shapes[i], in_shapes[i]+mod_indims[i])); + SHAPE_ASSIGN_CHECK(*in_shape, i, mxnet::TShape(in_shapes[i], in_shapes[i] + mod_indims[i])); } std::vector out_shapes(out_shape->size()); @@ -853,8 +1015,7 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // assign output shapes to ShapeVector for (unsigned i = 0; i < out_shape->size(); ++i) { - SHAPE_ASSIGN_CHECK(*out_shape, i, - mxnet::TShape(out_shapes[i], out_shapes[i]+outdims[i])); + SHAPE_ASSIGN_CHECK(*out_shape, i, mxnet::TShape(out_shapes[i], out_shapes[i] + outdims[i])); } // free memory used by custom op to allocate shapes/dims @@ -874,12 +1035,12 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, }; // lambda function to call infer shape for subgraph ops - auto infer_subgraph_shape = [=] (const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { + auto infer_subgraph_shape = [=](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } @@ -889,9 +1050,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); - auto in_first = in_shape->begin(); - auto in_last = in_first + in_shape->size() - extra_inputs; - mxnet::ShapeVector *sg_in_shapes = new mxnet::ShapeVector(in_first, in_last); + auto in_first = in_shape->begin(); + auto in_last = in_first + in_shape->size() - extra_inputs; + mxnet::ShapeVector* sg_in_shapes = new mxnet::ShapeVector(in_first, in_last); bool res = mxnet::op::DefaultSubgraphOpShape(attrs, sg_in_shapes, 
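// Illustrative sketch, not part of this patch: the infer-shape path above packs the
// input ShapeVector into one contiguous uint32_t buffer plus a per-tensor pointer and
// ndim array, so the extension library can read shapes without MXNet types. A
// self-contained version of that packing, with plain nested vectors standing in for
// mxnet::ShapeVector, might look like this.
#include <cstdint>
#include <vector>

void packShapes(const std::vector<std::vector<int64_t>>& shapes,
                std::vector<uint32_t>* buffer,
                std::vector<uint32_t*>* ptrs,
                std::vector<int>* dims) {
  size_t total = 0;
  for (const auto& s : shapes)
    total += s.size();
  buffer->resize(total);
  uint32_t* p = buffer->data();
  for (const auto& s : shapes) {
    ptrs->push_back(p);                          // start of this tensor's dims
    dims->push_back(static_cast<int>(s.size()));  // number of dims for this tensor
    for (int64_t d : s)
      *p++ = static_cast<uint32_t>(d);
  }
}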
out_shape); // assign modified input shapes to ShapeVector @@ -902,12 +1063,12 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, }; // lambda function to call infer type - auto infer_type = [=] (const nnvm::NodeAttrs& attrs, - std::vector *in_type, - std::vector *out_type) { + auto infer_type = [=](const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } @@ -924,9 +1085,14 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // output types will be populated by inferType function std::vector outtypes(out_type->size()); - int retval = callInferType(type_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - intypes.data(), num_inputs, - outtypes.data(), out_type->size()); + int retval = callInferType(type_fp, + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + intypes.data(), + num_inputs, + outtypes.data(), + out_type->size()); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling InferType for custom operator '" << name_str << "'" << msgs; @@ -943,55 +1109,58 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, }; // lambda function to call infer type for subgraph ops - auto infer_subgraph_type = [=] (const nnvm::NodeAttrs& attrs, - std::vector *in_type, - std::vector *out_type) { - // convert attributes to vector of char* - std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { - attr_keys.push_back(kv.first.c_str()); - attr_vals.push_back(kv.second.c_str()); - } + auto infer_subgraph_type = + [=](const nnvm::NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { + // convert attributes to vector of char* + std::vector attr_keys, attr_vals; + for (auto& kv : attrs.dict) { + attr_keys.push_back(kv.first.c_str()); + attr_vals.push_back(kv.second.c_str()); + } - // get extra inputs, if exists - int extra_inputs = 0; - if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) - extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); - auto in_first = in_type->begin(); - auto in_last = in_first + in_type->size() - extra_inputs; - std::vector *sg_in_types = new std::vector(in_first, in_last); + auto in_first = in_type->begin(); + auto in_last = in_first + in_type->size() - extra_inputs; + std::vector* sg_in_types = new std::vector(in_first, in_last); - bool res = mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type); - // copy and assign modified input types - for (size_t i = 0; i < sg_in_types->size(); i++) { - TYPE_ASSIGN_CHECK(*in_type, i, sg_in_types->at(i)); - } - return res; - }; + bool res = mxnet::op::DefaultSubgraphOpType(attrs, sg_in_types, out_type); + // copy and assign modified input types + for (size_t i = 0; i < sg_in_types->size(); i++) { + TYPE_ASSIGN_CHECK(*in_type, i, sg_in_types->at(i)); + } + return res; + }; // lambda function to convert from external mutate_inputs to internal MXNet types auto mutate_inputs = [=](const nnvm::NodeAttrs& attrs) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { 
attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } // C type placeholder for mutate input indices vector int* mutate_indices = nullptr; - int indices_size = 0; + int indices_size = 0; // call mutate inputs function - int retval = callMutateInputs(mutate_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - &mutate_indices, &indices_size); + int retval = callMutateInputs(mutate_fp, + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + &mutate_indices, + &indices_size); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling MutateInputs for custom operator '" << name_str << "'" - << msgs; + << msgs; std::vector mutate_indices_list(indices_size); - for (int i=0; i < indices_size; i++) { + for (int i = 0; i < indices_size; i++) { mutate_indices_list[i] = static_cast(mutate_indices[i]); } @@ -1000,17 +1169,17 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // lambda function to set storage types auto infer_storage_type = [=](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { if (stype_fp == nullptr) { // InferSType is not defined in customized lib. CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, mxnet::kDefaultStorage)) - << "Error input tensors are not dense for custom operator '" << name_str << "'"; + << "Error input tensors are not dense for custom operator '" << name_str << "'"; // set outputs as dense - return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage, - dispatch_mode, DispatchMode::kFComputeEx); + return op::storage_type_assign( + out_stypes, mxnet::kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } else { // InferSType is defined in customized lib. // convert attributes to vector of char* @@ -1031,12 +1200,17 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // output types will be populated by inferType function std::vector outstypes(out_stypes->size()); - int retval = callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), - instypes.data(), num_inputs, - outstypes.data(), out_stypes->size()); + int retval = callInferSType(stype_fp, + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + instypes.data(), + num_inputs, + outstypes.data(), + out_stypes->size()); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling InferSType for custom operator '" << name_str << "'" - << msgs; + << msgs; // copy and assign modified input storage types from custom op to MXNet memory. 
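// Illustrative sketch, not part of this patch: the recurring attr_keys/attr_vals code
// flattens the attribute dictionary into two parallel const char* arrays so it can
// cross the C ABI into the extension library. A minimal stand-alone form of that
// conversion, assuming a plain std::map for the dictionary, is shown below; the
// pointers stay valid only while the dictionary itself is alive.
#include <map>
#include <string>
#include <vector>

void flattenAttrs(const std::map<std::string, std::string>& dict,
                  std::vector<const char*>* keys,
                  std::vector<const char*>* vals) {
  for (const auto& kv : dict) {
    keys->push_back(kv.first.c_str());
    vals->push_back(kv.second.c_str());
  }
}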
for (size_t i = 0; i < num_inputs; i++) { @@ -1058,31 +1232,31 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, DispatchMode* dispatch_mode, std::vector* in_stypes, std::vector* out_stypes) { - // get extra inputs, if exists - int extra_inputs = 0; - if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) - extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); + // get extra inputs, if exists + int extra_inputs = 0; + if (attrs.dict.count(MX_STR_EXTRA_INPUTS) > 0) + extra_inputs = std::stoi(attrs.dict.at(MX_STR_EXTRA_INPUTS)); - auto in_first = in_stypes->begin(); - auto in_last = in_first + in_stypes->size() - extra_inputs; - std::vector *sg_in_stypes = new std::vector(in_first, in_last); + auto in_first = in_stypes->begin(); + auto in_last = in_first + in_stypes->size() - extra_inputs; + std::vector* sg_in_stypes = new std::vector(in_first, in_last); - bool res = mxnet::op::DefaultSubgraphOpStorageType(attrs, dev_mask, dispatch_mode, - sg_in_stypes, out_stypes); - // copy and assign modified input storage types - for (size_t i = 0; i < sg_in_stypes->size(); i++) { - STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, sg_in_stypes->at(i)); - } - return res; + bool res = mxnet::op::DefaultSubgraphOpStorageType( + attrs, dev_mask, dispatch_mode, sg_in_stypes, out_stypes); + // copy and assign modified input storage types + for (size_t i = 0; i < sg_in_stypes->size(); i++) { + STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, i, sg_in_stypes->at(i)); + } + return res; }; // FGradient register lambda auto grad_reg = [=](const nnvm::ObjectPtr& n, const std::vector& ograds) { // create node for gradient - auto p = nnvm::Node::Create(); + auto p = nnvm::Node::Create(); std::string grad_name = "_backward_" + name_str; - p->attrs.op = nnvm::Op::Get(grad_name.c_str()); - p->attrs.name = n->attrs.name + "_backward"; + p->attrs.op = nnvm::Op::Get(grad_name.c_str()); + p->attrs.name = n->attrs.name + "_backward"; // copy attributes and subgraphs p->attrs.dict = n->attrs.dict; for (const auto& s : n->attrs.subgraphs) @@ -1106,9 +1280,9 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // set inputs to gradient node p->inputs = heads; CHECK_EQ(p->num_inputs(), p->inputs.size()) - << "Number of inputs to operator " << grad_name << " (" << p->num_inputs() - << ") does not match the actual number of inputs provided to operator " - << p->attrs.name << " (" << p->inputs.size() << ")."; + << "Number of inputs to operator " << grad_name << " (" << p->num_inputs() + << ") does not match the actual number of inputs provided to operator " << p->attrs.name + << " (" << p->inputs.size() << ")."; // create output node entries return mxnet::op::CreateNodeEntries(p); }; @@ -1120,13 +1294,13 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // library author should implement and return a 'state' which points to an instance // in lambda we create OpStatePtr using the returned 'state' - auto create_opstate = [=] (const NodeAttrs& attrs, - Context ctx, - const std::vector& in_shapes, - const std::vector& in_types) { + auto create_opstate = [=](const NodeAttrs& attrs, + Context ctx, + const std::vector& in_shapes, + const std::vector& in_types) { // convert attributes to vector of char* std::vector attr_keys, attr_vals; - for (auto &kv : attrs.dict) { + for (auto& kv : attrs.dict) { attr_keys.push_back(kv.first.c_str()); attr_vals.push_back(kv.second.c_str()); } @@ -1139,15 +1313,15 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t 
msgSize, // determine amount of memory needed to store all the input shapes size_t buff_size = 0; - for (const auto & in_shape : in_shapes) + for (const auto& in_shape : in_shapes) buff_size += in_shape.ndim(); // copy input shapes to raw memory layout std::vector inbuff(buff_size); - uint32_t *ptr = inbuff.data(); + uint32_t* ptr = inbuff.data(); for (size_t i = 0; i < in_shapes.size(); ++i) { inshapes[i] = ptr; - indims[i] = in_shapes[i].ndim(); + indims[i] = in_shapes[i].ndim(); for (int j = 0; j < in_shapes[i].ndim(); ++j, ++ptr) { *ptr = static_cast(in_shapes[i][j]); } @@ -1157,7 +1331,7 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, std::string subgraph_json; if (!attrs.subgraphs.empty()) { nnvm::Graph g; - g.outputs = attrs.subgraphs[0].get()->outputs; + g.outputs = attrs.subgraphs[0].get()->outputs; subgraph_json = nnvm::pass::SaveJSON(g); attr_keys.push_back(MX_STR_SUBGRAPH_SYM_JSON); attr_vals.push_back(subgraph_json.c_str()); @@ -1169,29 +1343,43 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, void* state_op_inst = nullptr; if (ctx.dev_mask() == Context::kCPU) { CHECK(createop_map.count("cpu") > 0) - << "CPU CreateOpState not implemented for '" << name_str << "'"; - int retval = callCreateOpState(createop_map.at("cpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), ctx_str, ctx.real_dev_id(), - inshapes.data(), indims.data(), - in_shapes.size(), in_types.data(), &state_op_inst); + << "CPU CreateOpState not implemented for '" << name_str << "'"; + int retval = callCreateOpState(createop_map.at("cpu"), + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + ctx_str, + ctx.real_dev_id(), + inshapes.data(), + indims.data(), + in_shapes.size(), + in_types.data(), + &state_op_inst); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling CreateOpState CPU for custom operator '" << name_str << "'" << msgs; } else if (ctx.dev_mask() == Context::kGPU) { CHECK(createop_map.count("gpu") > 0) - << "GPU CreateOpState not implemented for '" << name_str << "'"; - int retval = callCreateOpState(createop_map.at("gpu"), attr_keys.data(), attr_vals.data(), - attr_keys.size(), ctx_str, ctx.real_dev_id(), - inshapes.data(), indims.data(), - in_shapes.size(), in_types.data(), &state_op_inst); + << "GPU CreateOpState not implemented for '" << name_str << "'"; + int retval = callCreateOpState(createop_map.at("gpu"), + attr_keys.data(), + attr_vals.data(), + attr_keys.size(), + ctx_str, + ctx.real_dev_id(), + inshapes.data(), + indims.data(), + in_shapes.size(), + in_types.data(), + &state_op_inst); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling CreateOpState GPU for custom operator '" << name_str << "'" - << msgs; + << msgs; } std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(state_op_inst != nullptr) - << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs; + << "Error custom library failed to create stateful operator '" << name_str << "'" << msgs; CustomStatefulOp* state_op = reinterpret_cast(state_op_inst); if (!state_op->wasCreated() && !state_op->ignore_warn) @@ -1205,115 +1393,152 @@ void registerOperators(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, /* -------------- BELOW IS THE REGISTRATION FOR CUSTOM OPERATORS --------------- */ - registerOp(name, name_str, isSubgraphOp, resc_req, attr_parser, num_inputs, num_outputs, - num_inouts, infer_type, infer_shape, infer_storage_type, 
mutate_inputs, - num_subgraph_inputs, infer_subgraph_type, infer_subgraph_shape, - infer_subgraph_storage_type, create_opstate, grad_reg, mutate_fp, - createop_map, forward_ctx_map, backward_ctx_map, callFComp, callFStatefulComp, - msgSize, msgGet); - } -} - -void registerPartitioners(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, - mxnet::ext::msgGet_t msgGet) { + registerOp(name, + name_str, + isSubgraphOp, + resc_req, + attr_parser, + num_inputs, + num_outputs, + num_inouts, + infer_type, + infer_shape, + infer_storage_type, + mutate_inputs, + num_subgraph_inputs, + infer_subgraph_type, + infer_subgraph_shape, + infer_subgraph_storage_type, + create_opstate, + grad_reg, + mutate_fp, + createop_map, + forward_ctx_map, + backward_ctx_map, + callFComp, + callFStatefulComp, + msgSize, + msgGet); + } +} // NOLINT + +void registerPartitioners(void* lib, + int verbose, + mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { using namespace mxnet::ext; // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); partCallSupportedOps_t callSupportedOps = - get_func(lib, const_cast(MXLIB_PARTCALLSUPPORTEDOPS_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLSUPPORTEDOPS_STR)); partCallCreateSelector_t callCreateSelector = - get_func(lib, const_cast(MXLIB_PARTCALLCREATESELECTOR_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLCREATESELECTOR_STR)); partCallSelect_t callSelect = - get_func(lib, const_cast(MXLIB_PARTCALLSELECT_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLSELECT_STR)); partCallSelectInput_t callSelectInput = - get_func(lib, const_cast(MXLIB_PARTCALLSELECTINPUT_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLSELECTINPUT_STR)); partCallSelectOutput_t callSelectOutput = - get_func(lib, const_cast(MXLIB_PARTCALLSELECTOUTPUT_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLSELECTOUTPUT_STR)); partCallFilter_t callFilter = - get_func(lib, const_cast(MXLIB_PARTCALLFILTER_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLFILTER_STR)); partCallReset_t callReset = - get_func(lib, const_cast(MXLIB_PARTCALLRESET_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLRESET_STR)); partCallReviewSubgraph_t callReviewSubgraph = - get_func(lib, const_cast(MXLIB_PARTCALLREVIEWSUBGRAPH_STR)); + get_func(lib, const_cast(MXLIB_PARTCALLREVIEWSUBGRAPH_STR)); // get number of partitioners registered in the library - partRegSize_t partRegSize = get_func(lib, - const_cast(MXLIB_PARTREGSIZE_STR)); + partRegSize_t partRegSize = + get_func(lib, const_cast(MXLIB_PARTREGSIZE_STR)); int numParts = partRegSize(); - if (verbose) LOG(INFO) << "Found " << numParts << " partitioners in library"; + if (verbose) + LOG(INFO) << "Found " << numParts << " partitioners in library"; /* * Get all custom partitioners implementation from custom library * loop and register each partitioner in the library to NNVM */ - partRegGetCount_t partRegGetCount = get_func(lib, - const_cast(MXLIB_PARTREGGETCOUNT_STR)); + partRegGetCount_t partRegGetCount = + get_func(lib, const_cast(MXLIB_PARTREGGETCOUNT_STR)); partRegGet_t partRegGet = get_func(lib, const_cast(MXLIB_PARTREGGET_STR)); for (int i = 0; i < numParts; i++) { const char* name; // get custom partitioner strategy count from the dynamic library int count = partRegGetCount(i, &name); - CHECK(count > 0) << "Error loading '" << name - << "' custom partitioner, no strategies defined"; + CHECK(count > 0) << "Error loading '" << name << "' custom partitioner, no strategies defined"; std::string name_str(name); - if (verbose) LOG(INFO) << 
"\tPartitioner[" << i << "] " << name; + if (verbose) + LOG(INFO) << "\tPartitioner[" << i << "] " << name; mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_BACKEND__(name); for (int j = 0; j < count; j++) { const char* strategy; // function pointers holding implementation from custom library - supportedOps_t supportedOps_fp = nullptr; + supportedOps_t supportedOps_fp = nullptr; createSelector_t createSelector_fp = nullptr; reviewSubgraph_t reviewSubgraph_fp = nullptr; // name of subgraph op const char* op_name = nullptr; // get custom partitioner strategy from the dynamic library - partRegGet(i, j, &strategy, &supportedOps_fp, &createSelector_fp, - &reviewSubgraph_fp, &op_name); + partRegGet( + i, j, &strategy, &supportedOps_fp, &createSelector_fp, &reviewSubgraph_fp, &op_name); // validate custom partitioner functions from the dynamic library if (supportedOps_fp == nullptr && createSelector_fp == nullptr) - LOG(ERROR) << "Error loading '" << name << "' custom partitioner strategy '" - << strategy << "', must implement supportedOps or createSelector"; + LOG(ERROR) << "Error loading '" << name << "' custom partitioner strategy '" << strategy + << "', must implement supportedOps or createSelector"; std::string strategy_str(strategy); std::string op_name_str(op_name); - if (verbose) LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str - << " subgraphOp: '" << op_name_str << "'"; - mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__ - (name_str, std::make_shared - (strategy_str, callSupportedOps, supportedOps_fp, callCreateSelector, - createSelector_fp, callSelect, callSelectInput, callSelectOutput, - callFilter, callReset, callReviewSubgraph, reviewSubgraph_fp, callFree, - op_name_str)); + if (verbose) + LOG(INFO) << "\t\tStrategy[" << j << "] " << strategy_str << " subgraphOp: '" << op_name_str + << "'"; + mxnet::op::SubgraphBackendRegistry::Get()->__REGISTER_CUSTOM_PROPERTY__( + name_str, + std::make_shared(strategy_str, + callSupportedOps, + supportedOps_fp, + callCreateSelector, + createSelector_fp, + callSelect, + callSelectInput, + callSelectOutput, + callFilter, + callReset, + callReviewSubgraph, + reviewSubgraph_fp, + callFree, + op_name_str)); } } } -void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, - mxnet::ext::msgGet_t msgGet) { +void registerPasses(void* lib, + int verbose, + mxnet::ext::msgSize_t msgSize, + mxnet::ext::msgGet_t msgGet) { using namespace mxnet::ext; // get C type interface functions opCallFree_t callFree = get_func(lib, const_cast(MXLIB_OPCALLFREE_STR)); passCallGraphPass_t callGraphPass = - get_func(lib, const_cast(MXLIB_PASSCALLGRAPHPASS_STR)); + get_func(lib, const_cast(MXLIB_PASSCALLGRAPHPASS_STR)); // get number of passes registered in the library - partRegSize_t passRegSize = get_func(lib, - const_cast(MXLIB_PASSREGSIZE_STR)); + partRegSize_t passRegSize = + get_func(lib, const_cast(MXLIB_PASSREGSIZE_STR)); int numPasses = passRegSize(); - if (verbose) LOG(INFO) << "Found " << numPasses << " graph passes in library"; + if (verbose) + LOG(INFO) << "Found " << numPasses << " graph passes in library"; /* * Get all custom pass implementation from custom library @@ -1328,14 +1553,15 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // main function to get custom pass implemenation from the custom library passRegGet(i, &pass_fp, &name); - if (verbose) LOG(INFO) << "\tGraph Pass [" << i << "] " << name; + if (verbose) + LOG(INFO) << "\tGraph Pass [" << i << "] " << name; - auto 
pass_lambda = [=] (nnvm::Graph&& g) { + auto pass_lambda = [=](nnvm::Graph&& g) { // get pass name const char* pass_name = g.GetAttr("pass_name"); // get options const std::unordered_map& options_map = - g.GetAttr>("options_map"); + g.GetAttr>("options_map"); // convert options_map_ to char* to pass to backend library std::vector opt_keys, opt_vals; for (auto& kv : options_map) { @@ -1346,8 +1572,8 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // get input args and arg names std::vector in_arg_names = g.GetAttr>("in_arg_names"); std::vector in_aux_names = g.GetAttr>("in_aux_names"); - NDArray **in_args_ptr = g.GetAttr("in_args"); - NDArray **in_aux_ptr = g.GetAttr("in_aux"); + NDArray** in_args_ptr = g.GetAttr("in_args"); + NDArray** in_aux_ptr = g.GetAttr("in_aux"); // get shapes/types mxnet::ShapeVector shapes; @@ -1355,7 +1581,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, shapes = g.GetAttr("shape"); std::vector dtypes; if (g.HasAttr("dtype")) - dtypes = g.GetAttr >("dtype"); + dtypes = g.GetAttr>("dtype"); g.attrs.clear(); const nnvm::IndexedGraph& indexed_graph = g.indexed_graph(); @@ -1368,9 +1594,10 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // set the output shapes for this node for (unsigned oid = 0; oid < node->num_outputs(); oid++) { const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid); - mxnet::TShape& shape = shapes[out_entry_id]; + mxnet::TShape& shape = shapes[out_entry_id]; ss << shape; - if (oid < node->num_outputs()-1) ss << ","; + if (oid < node->num_outputs() - 1) + ss << ","; } ss << "]"; node->attrs.dict[MX_STR_SHAPE] = ss.str(); @@ -1385,9 +1612,10 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // set the output dtypes for this node for (unsigned oid = 0; oid < node->num_outputs(); oid++) { const uint32_t out_entry_id = indexed_graph.entry_id(nid, oid); - int dtype = dtypes[out_entry_id]; + int dtype = dtypes[out_entry_id]; ss << dtype; - if (oid < node->num_outputs()-1) ss << ","; + if (oid < node->num_outputs() - 1) + ss << ","; } ss << "]"; node->attrs.dict[MX_STR_DTYPE] = ss.str(); @@ -1404,10 +1632,10 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, std::vector arg_dev_id, aux_dev_id; // convert input args - for (size_t i=0; i < in_arg_names.size(); i++) { + for (size_t i = 0; i < in_arg_names.size(); i++) { if (in_args_ptr[i] != nullptr) { arg_names.push_back(in_arg_names[i].c_str()); - const NDArray &in_arg = *(in_args_ptr[i]); + const NDArray& in_arg = *(in_args_ptr[i]); #if MXNET_USE_ONEDNN == 1 // reorder data if in MKLDNN format @@ -1430,10 +1658,10 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, } // convert input aux - for (size_t i=0; i < in_aux_names.size(); i++) { + for (size_t i = 0; i < in_aux_names.size(); i++) { if (in_aux_ptr[i] != nullptr) { aux_names.push_back(in_aux_names[i].c_str()); - const auto &in_aux = *(in_aux_ptr[i]); + const auto& in_aux = *(in_aux_ptr[i]); #if MXNET_USE_ONEDNN == 1 // reorder data if in MKLDNN format @@ -1463,26 +1691,32 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, // create lambda that captures stream & resource objects // this temp workspace holds memory allocated by custom library via OpResource - auto ndarray_alloc = [&](const mxnet::TShape &shape, Context ctx, int dtype, - std::string name, bool isArg) { - NDArray* arr = new NDArray(shape, ctx, false, dtype); - if (isArg) { - 
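// Illustrative sketch, not part of this patch: the graph-pass lambda above serializes
// each node's output shapes into a bracketed, comma-separated string stored under
// MX_STR_SHAPE. A stand-alone approximation of that formatting, with plain int vectors
// standing in for mxnet::TShape, could look like this.
#include <sstream>
#include <string>
#include <vector>

std::string shapesToAttr(const std::vector<std::vector<int>>& out_shapes) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < out_shapes.size(); ++i) {
    ss << "[";
    for (size_t j = 0; j < out_shapes[i].size(); ++j) {
      ss << out_shapes[i][j];
      if (j + 1 < out_shapes[i].size())
        ss << ",";
    }
    ss << "]";
    if (i + 1 < out_shapes.size())
      ss << ",";
  }
  ss << "]";
  return ss.str();  // e.g. "[[2,3],[4]]"
}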
new_args.push_back(arr); - new_arg_names.push_back(name); - } else { - new_aux.push_back(arr); - new_aux_names.push_back(name); - } - return arr; - }; + auto ndarray_alloc = + [&](const mxnet::TShape& shape, Context ctx, int dtype, std::string name, bool isArg) { + NDArray* arr = new NDArray(shape, ctx, false, dtype); + if (isArg) { + new_args.push_back(arr); + new_arg_names.push_back(name); + } else { + new_aux.push_back(arr); + new_aux_names.push_back(name); + } + return arr; + }; // create no-capture lambda so that we can cast it to function pointer // lambda with captures cannot be cast to function pointer and pass to lib_api.h // this needs to be a lambda function so that we can do the decltype cast using alloc_type_ndarray = decltype(ndarray_alloc); - auto ndarray_malloc = [](const void* _ndarray_alloc, const int64_t* shapes, int num_shapes, - const char* dev_str, int dev_id, int dtype, const char* name, - int isArg, void** data) { + auto ndarray_malloc = [](const void* _ndarray_alloc, + const int64_t* shapes, + int num_shapes, + const char* dev_str, + int dev_id, + int dtype, + const char* name, + int isArg, + void** data) { mxnet::TShape shape(num_shapes, 0); for (int i = 0; i < num_shapes; i++) shape[i] = shapes[i]; @@ -1497,29 +1731,46 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, const alloc_type_ndarray* ndalloc = static_cast(_ndarray_alloc); // call cpu_alloc to actually allocate memory and return the pointer NDArray* arr = (*ndalloc)(shape, ctx, dtype, name, isArg); - *data = arr->data().dptr_; + *data = arr->data().dptr_; }; char* out_json; - int retval = callGraphPass(pass_fp, in_json.c_str(), &out_json, opt_keys.data(), - opt_vals.data(), opt_keys.size(), pass_name, - arg_names.data(), arg_names.size(), arg_data.data(), - arg_shapes.data(), arg_dims.data(), arg_types.data(), - arg_verIDs.data(), arg_dev_type.data(), - arg_dev_id.data(), aux_names.data(), aux_names.size(), - aux_data.data(), aux_shapes.data(), aux_dims.data(), - aux_types.data(), aux_verIDs.data(), - aux_dev_type.data(), aux_dev_id.data(), - ndarray_malloc, &ndarray_alloc); + int retval = callGraphPass(pass_fp, + in_json.c_str(), + &out_json, + opt_keys.data(), + opt_vals.data(), + opt_keys.size(), + pass_name, + arg_names.data(), + arg_names.size(), + arg_data.data(), + arg_shapes.data(), + arg_dims.data(), + arg_types.data(), + arg_verIDs.data(), + arg_dev_type.data(), + arg_dev_id.data(), + aux_names.data(), + aux_names.size(), + aux_data.data(), + aux_shapes.data(), + aux_dims.data(), + aux_types.data(), + aux_verIDs.data(), + aux_dev_type.data(), + aux_dev_id.data(), + ndarray_malloc, + &ndarray_alloc); std::string msgs = getExtensionMsgs(msgSize, msgGet); CHECK(retval) << "Error calling graph pass for '" << pass_name << "'" << msgs; std::string out_string(out_json); nnvm::Graph out_graph = nnvm::pass::LoadJSON(out_string); - out_graph.attrs["new_args"] = std::make_shared(new_args); + out_graph.attrs["new_args"] = std::make_shared(new_args); out_graph.attrs["new_arg_names"] = std::make_shared(new_arg_names); - out_graph.attrs["new_aux"] = std::make_shared(new_aux); + out_graph.attrs["new_aux"] = std::make_shared(new_aux); out_graph.attrs["new_aux_names"] = std::make_shared(new_aux_names); callFree(out_json); @@ -1536,7 +1787,7 @@ void registerPasses(void *lib, int verbose, mxnet::ext::msgSize_t msgSize, * \brief Loads dynamic custom library and initializes it * \param path library path */ -int MXLoadLib(const char *path, unsigned verbose, void** lib) { +int 
MXLoadLib(const char* path, unsigned verbose, void** lib) { API_BEGIN(); *lib = LibraryInitializer::Get()->lib_load(path); if (!*lib) @@ -1544,21 +1795,21 @@ int MXLoadLib(const char *path, unsigned verbose, void** lib) { // check that library and MXNet use same version of library API mxnet::ext::opVersion_t opVersion = - get_func(*lib, const_cast(MXLIB_OPVERSION_STR)); - int libVersion = opVersion(); + get_func(*lib, const_cast(MXLIB_OPVERSION_STR)); + int libVersion = opVersion(); if (MX_LIBRARY_VERSION != libVersion) LOG(FATAL) << "Library version (" << libVersion << ") does not match MXNet version (" << MX_LIBRARY_VERSION << ")"; // get error messaging APIs mxnet::ext::msgSize_t msgSize = - get_func(*lib, const_cast(MXLIB_MSGSIZE_STR)); + get_func(*lib, const_cast(MXLIB_MSGSIZE_STR)); mxnet::ext::msgGet_t msgGet = - get_func(*lib, const_cast(MXLIB_MSGGET_STR)); + get_func(*lib, const_cast(MXLIB_MSGGET_STR)); // initialize library by passing MXNet version mxnet::ext::initialize_t initialize = - get_func(*lib, const_cast(MXLIB_INITIALIZE_STR)); + get_func(*lib, const_cast(MXLIB_INITIALIZE_STR)); if (!initialize(static_cast(MXNET_VERSION))) { std::string msgs = getExtensionMsgs(msgSize, msgGet); LOG(FATAL) << "Library failed to initialize" << msgs; @@ -1571,12 +1822,12 @@ int MXLoadLib(const char *path, unsigned verbose, void** lib) { API_END(); } -int MXLibInfoFeatures(const struct LibFeature **lib_features, size_t *size) { +int MXLibInfoFeatures(const struct LibFeature** lib_features, size_t* size) { using namespace features; API_BEGIN(); LibInfo* lib_info = LibInfo::getInstance(); - *lib_features = lib_info->getFeatures().data(); - *size = lib_info->getFeatures().size(); + *lib_features = lib_info->getFeatures().data(); + *size = lib_info->getFeatures().size(); API_END(); } @@ -1590,7 +1841,6 @@ int MXLibInfoCompiledWithCXX11ABI(int* result) { API_END(); } - int MXRandomSeed(int seed) { API_BEGIN(); mxnet::RandomSeed(seed); @@ -1608,31 +1858,31 @@ int MXSetFlushDenorms(bool value, bool* prev_state) { API_BEGIN(); *prev_state = false; - #if SUPPORT_FTZ_DMZ - std::function is_dmz_flag_available = []() { - // Intel 64 and IA-32 Architectures Software Developer’s Manual: Vol. 1 - // "Checking for the DAZ Flag in the MXCSR Register" - constexpr unsigned int mxcsr_mask_offset = 28; - constexpr unsigned int dmz_flag_offset = 5; - constexpr unsigned int fxsave_req_bytes = 512; - - char* fxsave_area_ptr = reinterpret_cast(malloc(fxsave_req_bytes)); - memset(fxsave_area_ptr, 0, fxsave_req_bytes); // fill memory with 0 - _fxsave(fxsave_area_ptr); - - char* mxcsr_mask_ptr = fxsave_area_ptr + mxcsr_mask_offset; - uint32_t mxcsr_mask = *(reinterpret_cast((mxcsr_mask_ptr))); - // DMZ flag is supported if sixth bit of MXCSR_MASK is hot - bool dmz_flag = (mxcsr_mask >> dmz_flag_offset) & 0x1; - free(fxsave_area_ptr); - return dmz_flag; - }; +#if SUPPORT_FTZ_DMZ + std::function is_dmz_flag_available = []() { + // Intel 64 and IA-32 Architectures Software Developer’s Manual: Vol. 
1 + // "Checking for the DAZ Flag in the MXCSR Register" + constexpr unsigned int mxcsr_mask_offset = 28; + constexpr unsigned int dmz_flag_offset = 5; + constexpr unsigned int fxsave_req_bytes = 512; + + char* fxsave_area_ptr = reinterpret_cast(malloc(fxsave_req_bytes)); + memset(fxsave_area_ptr, 0, fxsave_req_bytes); // fill memory with 0 + _fxsave(fxsave_area_ptr); + + char* mxcsr_mask_ptr = fxsave_area_ptr + mxcsr_mask_offset; + uint32_t mxcsr_mask = *(reinterpret_cast((mxcsr_mask_ptr))); + // DMZ flag is supported if sixth bit of MXCSR_MASK is hot + bool dmz_flag = (mxcsr_mask >> dmz_flag_offset) & 0x1; + free(fxsave_area_ptr); + return dmz_flag; + }; - Engine::Get()->PushSync( + Engine::Get()->PushSync( [value, prev_state, is_dmz_flag_available](RunContext rctx) { const unsigned int DMZ_STATE = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF; const unsigned int FTZ_STATE = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF; - *prev_state = _MM_GET_FLUSH_ZERO_MODE(); + *prev_state = _MM_GET_FLUSH_ZERO_MODE(); _MM_SET_FLUSH_ZERO_MODE(FTZ_STATE); // If the DAZ flag is not supported, then it is a reserved bit and attempting to write a 1 @@ -1640,12 +1890,17 @@ int MXSetFlushDenorms(bool value, bool* prev_state) { if (is_dmz_flag_available()) { _MM_SET_DENORMALS_ZERO_MODE(DMZ_STATE); } - }, Context::CPU(), {}, {}, - FnProperty::kNormal, 0, "SetFlushDenorms"); + }, + Context::CPU(), + {}, + {}, + FnProperty::kNormal, + 0, + "SetFlushDenorms"); - Engine::Get()->WaitForAll(); + Engine::Get()->WaitForAll(); - #endif +#endif API_END(); } @@ -1677,33 +1932,33 @@ int MXGetGPUCount(int* out) { } // Deprecated: use MXGetGPUMemoryInformation64() instead. -int MXGetGPUMemoryInformation(int dev, int *free_mem, int *total_mem) { +int MXGetGPUMemoryInformation(int dev, int* free_mem, int* total_mem) { API_BEGIN(); - uint64_t free_mem64 = 0UL; + uint64_t free_mem64 = 0UL; uint64_t total_mem64 = 0UL; Context::GetGPUMemoryInformation(dev, &free_mem64, &total_mem64); - *free_mem = static_cast(free_mem64); + *free_mem = static_cast(free_mem64); *total_mem = static_cast(total_mem64); API_END(); } -int MXGetGPUMemoryInformation64(int dev, uint64_t *free_mem, uint64_t *total_mem) { +int MXGetGPUMemoryInformation64(int dev, uint64_t* free_mem, uint64_t* total_mem) { API_BEGIN(); Context::GetGPUMemoryInformation(dev, free_mem, total_mem); API_END(); } -int MXGetVersion(int *out) { +int MXGetVersion(int* out) { API_BEGIN(); *out = static_cast(MXNET_VERSION); API_END(); } #if MXNET_USE_TVM_OP -int MXLoadTVMOp(const char *libpath) { +int MXLoadTVMOp(const char* libpath) { API_BEGIN(); tvm::runtime::TVMOpModule::Get()->Load(libpath); - tvm::runtime::TVMOpModule *global_module = tvm::runtime::TVMOpModule::Get(); + tvm::runtime::TVMOpModule* global_module = tvm::runtime::TVMOpModule::Get(); global_module->Load(libpath); #if MXNET_USE_CUDA std::string libpathstr(libpath); @@ -1718,8 +1973,9 @@ int MXLoadTVMOp(const char *libpath) { int MXLoadTVMConfig(ConfigSpaces config) { API_BEGIN(); for (int k = 0; k < config.spaces_size; ++k) { - tvm::runtime::TVMOpConfig& entry = ::dmlc::Registry::Get() - ->__REGISTER_OR_GET__(std::string(config.spaces_key[k])); + tvm::runtime::TVMOpConfig& entry = + ::dmlc::Registry::Get()->__REGISTER_OR_GET__( + std::string(config.spaces_key[k])); const ConfigSpace& c = config.spaces_val[k]; for (int i = 0; i < c.entity_map_size; ++i) { entry.add_entity(std::string(c.entity_map_key[i]), c.entity_map_val[i].val); @@ -1739,13 +1995,13 @@ int MXLoadTVMConfig(ConfigSpaces config) { #endif // 
MXNET_USE_TVM_OP -int MXNDArrayCreateNone(NDArrayHandle *out) { +int MXNDArrayCreateNone(NDArrayHandle* out) { API_BEGIN(); *out = new NDArray(); API_END(); } -template +template void CreateNDArray(const DataType* shape, int ndim, int dev_type, @@ -1755,55 +2011,56 @@ void CreateNDArray(const DataType* shape, NDArrayHandle* out) { mxnet::TShape requested_shape = mxnet::TShape(shape, shape + ndim); if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { - CHECK_LT(requested_shape.Size(), (int64_t{1} << 31) - 1) << - "[CreateNDArray] Size of tensor you are trying to allocate is larger than " - "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; + CHECK_LT(requested_shape.Size(), (int64_t{1} << 31) - 1) + << "[CreateNDArray] Size of tensor you are trying to allocate is larger than " + "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; } NDArray* nd = new NDArray(requested_shape, Context::Create(static_cast(dev_type), dev_id), - delay_alloc != 0, dtype); + delay_alloc != 0, + dtype); nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(), MXNET_STORAGE_DEFAULT_NAME_CSTR); *out = nd; } -int MXNDArrayCreate64(const int64_t *shape, +int MXNDArrayCreate64(const int64_t* shape, int ndim, int dev_type, int dev_id, int delay_alloc, int dtype, - NDArrayHandle *out) { + NDArrayHandle* out) { API_BEGIN(); CreateNDArray(shape, ndim, dev_type, dev_id, delay_alloc, dtype, out); API_END(); } -int MXNDArrayCreate(const uint32_t *shape, +int MXNDArrayCreate(const uint32_t* shape, uint32_t ndim, int dev_type, int dev_id, int delay_alloc, int dtype, - NDArrayHandle *out) { + NDArrayHandle* out) { API_BEGIN(); CreateNDArray(shape, static_cast(ndim), dev_type, dev_id, delay_alloc, dtype, out); API_END(); } -template +template void CreateSparseNDArray(int storage_type, - const DType *shape, + const DType* shape, int ndim, int dev_type, int dev_id, int delay_alloc, int dtype, uint32_t num_aux, - int *aux_type, - int *aux_ndims, - const DType *aux_shape, - NDArrayHandle *out) { + int* aux_type, + int* aux_ndims, + const DType* aux_shape, + NDArrayHandle* out) { std::vector aux_types; mxnet::ShapeVector aux_shapes; auto shape_start = aux_shape; @@ -1814,63 +2071,78 @@ void CreateSparseNDArray(int storage_type, aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]); shape_start += aux_ndims[i]; } - NDArray* nd = new NDArray( - NDArrayStorageType(storage_type), - mxnet::TShape(shape, shape + ndim), - Context::Create(static_cast(dev_type), dev_id), - delay_alloc != 0, - dtype, aux_types, aux_shapes); + NDArray* nd = new NDArray(NDArrayStorageType(storage_type), + mxnet::TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype, + aux_types, + aux_shapes); nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(), MXNET_STORAGE_DEFAULT_NAME_CSTR); *out = nd; } int MXNDArrayCreateSparseEx(int storage_type, - const uint32_t *shape, + const uint32_t* shape, uint32_t ndim, int dev_type, int dev_id, int delay_alloc, int dtype, uint32_t num_aux, - int *aux_type, - uint32_t *aux_ndims, - const uint32_t *aux_shape, - NDArrayHandle *out) { - API_BEGIN(); - CreateSparseNDArray(storage_type, shape, static_cast(ndim), dev_type, dev_id, - delay_alloc, dtype, num_aux, aux_type, - reinterpret_cast(aux_ndims), aux_shape, out); + int* aux_type, + uint32_t* aux_ndims, + const uint32_t* aux_shape, + NDArrayHandle* out) { + API_BEGIN(); + CreateSparseNDArray(storage_type, + shape, + static_cast(ndim), + 
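// Illustrative sketch, not part of this patch: CreateSparseNDArray receives the
// auxiliary shapes as a single flat array plus a per-aux ndim array; unpacking them
// into one shape per auxiliary tensor follows the pattern below. Plain int64_t vectors
// stand in for mxnet::TShape to keep the example self-contained.
#include <vector>

std::vector<std::vector<int64_t>> unpackAuxShapes(const int64_t* aux_shape,
                                                  const int* aux_ndims,
                                                  unsigned num_aux) {
  std::vector<std::vector<int64_t>> shapes;
  const int64_t* start = aux_shape;
  for (unsigned i = 0; i < num_aux; ++i) {
    shapes.emplace_back(start, start + aux_ndims[i]);  // one shape per aux tensor
    start += aux_ndims[i];
  }
  return shapes;
}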
dev_type, + dev_id, + delay_alloc, + dtype, + num_aux, + aux_type, + reinterpret_cast(aux_ndims), + aux_shape, + out); API_END(); } - int MXNDArrayCreateSparseEx64(int storage_type, - const int64_t *shape, - int ndim, - int dev_type, - int dev_id, - int delay_alloc, - int dtype, - uint32_t num_aux, - int *aux_type, - int *aux_ndims, - const int64_t *aux_shape, - NDArrayHandle *out) { - API_BEGIN(); - CreateSparseNDArray(storage_type, shape, static_cast(ndim), dev_type, dev_id, - delay_alloc, dtype, num_aux, aux_type, - reinterpret_cast(aux_ndims), aux_shape, out); - API_END(); -} - - -int MXNDArrayLoadFromRawBytes(const void *buf, - size_t size, - NDArrayHandle *out) { - NDArray *ptr = nullptr; - API_BEGIN(); - dmlc::MemoryFixedSizeStream strm((void*)buf, size); // NOLINT(*) + const int64_t* shape, + int ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + uint32_t num_aux, + int* aux_type, + int* aux_ndims, + const int64_t* aux_shape, + NDArrayHandle* out) { + API_BEGIN(); + CreateSparseNDArray(storage_type, + shape, + static_cast(ndim), + dev_type, + dev_id, + delay_alloc, + dtype, + num_aux, + aux_type, + reinterpret_cast(aux_ndims), + aux_shape, + out); + API_END(); +} + +int MXNDArrayLoadFromRawBytes(const void* buf, size_t size, NDArrayHandle* out) { + NDArray* ptr = nullptr; + API_BEGIN(); + dmlc::MemoryFixedSizeStream strm((void*)buf, size); // NOLINT(*) ptr = new NDArray(); if (!ptr->Load(&strm)) { throw dmlc::Error("Invalid NDArray serialization format"); @@ -1879,30 +2151,24 @@ int MXNDArrayLoadFromRawBytes(const void *buf, API_END_HANDLE_ERROR(delete ptr); } -int MXNDArraySaveRawBytes(NDArrayHandle handle, - size_t *out_size, - const char **out_buf) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXNDArraySaveRawBytes(NDArrayHandle handle, size_t* out_size, const char** out_buf) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); ret->ret_str.resize(0); dmlc::MemoryStringStream strm(&ret->ret_str); static_cast(handle)->Save(&strm); *out_size = ret->ret_str.length(); - *out_buf = ret->ret_str.c_str(); + *out_buf = ret->ret_str.c_str(); API_END(); } -int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, - const void *data, - size_t size) { +int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, const void* data, size_t size) { API_BEGIN(); static_cast(handle)->SyncCopyFromCPU(data, size); API_END(); } -int MXNDArraySyncCopyToCPU(NDArrayHandle handle, - void *data, - size_t size) { +int MXNDArraySyncCopyToCPU(NDArrayHandle handle, void* data, size_t size) { API_BEGIN(); static_cast(handle)->SyncCopyToCPU(data, size); API_END(); @@ -1927,7 +2193,7 @@ int MXNDArraySyncCopyFromNDArray(NDArrayHandle handle_dst, int MXNDArraySyncCheckFormat(NDArrayHandle handle, const bool full_check) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); arr->SyncCheckFormat(full_check); API_END(); } @@ -1973,10 +2239,7 @@ int MXNDArrayLegacySave(const char* fname, API_END(); } -int MXNDArraySave(const char* fname, - uint32_t num_args, - NDArrayHandle* args, - const char** keys) { +int MXNDArraySave(const char* fname, uint32_t num_args, NDArrayHandle* args, const char** keys) { API_BEGIN(); CHECK_NOTNULL(fname); @@ -1985,142 +2248,141 @@ int MXNDArraySave(const char* fname, // and write an adapter for DMLC stream based on pZip->m_pWrite (and // pZip->m_pIO_opaque) if (num_args == 1 && keys == nullptr) { - NDArray *array = static_cast(args[0]); - if (array->storage_type() == kDefaultStorage) { - 
npy::save_array(fname, *array); - } else { - mz_zip_archive archive {}; - CHECK(mz_zip_writer_init_file(&archive, fname, 0)) - << "Failed to open archive " << fname << ": " - << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); - npz::save_array(&archive, "", *array); - CHECK(mz_zip_writer_finalize_archive(&archive)) - << "Failed to finalize archive " << fname - << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); - CHECK(mz_zip_writer_end(&archive)) - << "Failed to end archive " << fname - << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); - } - } else { - mz_zip_archive archive {}; + NDArray* array = static_cast(args[0]); + if (array->storage_type() == kDefaultStorage) { + npy::save_array(fname, *array); + } else { + mz_zip_archive archive{}; CHECK(mz_zip_writer_init_file(&archive, fname, 0)) << "Failed to open archive " << fname << ": " << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); - for (uint32_t i = 0; i < num_args; ++i) { - NDArray *array = static_cast(args[i]); - const std::string array_key = keys == nullptr ? "arr_" + std::to_string(i) : keys[i]; - npz::save_array(&archive, array_key, *array); - } + npz::save_array(&archive, "", *array); CHECK(mz_zip_writer_finalize_archive(&archive)) << "Failed to finalize archive " << fname << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); CHECK(mz_zip_writer_end(&archive)) << "Failed to end archive " << fname << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); + } + } else { + mz_zip_archive archive{}; + CHECK(mz_zip_writer_init_file(&archive, fname, 0)) + << "Failed to open archive " << fname << ": " + << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); + for (uint32_t i = 0; i < num_args; ++i) { + NDArray* array = static_cast(args[i]); + const std::string array_key = keys == nullptr ? 
"arr_" + std::to_string(i) : keys[i]; + npz::save_array(&archive, array_key, *array); + } + CHECK(mz_zip_writer_finalize_archive(&archive)) + << "Failed to finalize archive " << fname + << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); + CHECK(mz_zip_writer_end(&archive)) << "Failed to end archive " << fname + << mz_zip_get_error_string(mz_zip_get_last_error(&archive)); } API_END(); } int MXNDArrayLoad(const char* fname, - uint32_t *out_size, + uint32_t* out_size, NDArrayHandle** out_arr, - uint32_t *out_name_size, + uint32_t* out_name_size, const char*** out_names) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str.clear(); API_BEGIN(); uint32_t magic; { - std::unique_ptr strm(dmlc::Stream::Create(fname, "r")); - CHECK_EQ(strm->Read(&magic, sizeof(uint32_t)), sizeof(uint32_t)) + std::unique_ptr strm(dmlc::Stream::Create(fname, "r")); + CHECK_EQ(strm->Read(&magic, sizeof(uint32_t)), sizeof(uint32_t)) << "Failed to read 32 bits from file."; } - if (magic == 0x04034b50 || magic == 0x504b0304 || - magic == 0x06054b50 || magic == 0x504b0506) { // zip file format; assumed to be npz - auto[data, names] = npz::load_arrays(fname); - ret->ret_handles.resize(data.size()); - for (size_t i = 0; i < data.size(); ++i) { - NDArray *ptr = new NDArray(); - *ptr = data[i]; - ret->ret_handles[i] = ptr; - } - ret->ret_vec_str.resize(names.size()); - for (size_t i = 0; i < names.size(); ++i) { - ret->ret_vec_str[i] = names[i]; - } - ret->ret_vec_charp.resize(names.size()); - for (size_t i = 0; i < names.size(); ++i) { - ret->ret_vec_charp[i] = ret->ret_vec_str[i].c_str(); - } - *out_size = static_cast(data.size()); - *out_arr = dmlc::BeginPtr(ret->ret_handles); - *out_name_size = static_cast(names.size()); - *out_names = dmlc::BeginPtr(ret->ret_vec_charp); + if (magic == 0x04034b50 || magic == 0x504b0304 || magic == 0x06054b50 || + magic == 0x504b0506) { // zip file format; assumed to be npz + auto [data, names] = npz::load_arrays(fname); // NOLINT + ret->ret_handles.resize(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + NDArray* ptr = new NDArray(); + *ptr = data[i]; + ret->ret_handles[i] = ptr; + } + ret->ret_vec_str.resize(names.size()); + for (size_t i = 0; i < names.size(); ++i) { + ret->ret_vec_str[i] = names[i]; + } + ret->ret_vec_charp.resize(names.size()); + for (size_t i = 0; i < names.size(); ++i) { + ret->ret_vec_charp[i] = ret->ret_vec_str[i].c_str(); + } + *out_size = static_cast(data.size()); + *out_arr = dmlc::BeginPtr(ret->ret_handles); + *out_name_size = static_cast(names.size()); + *out_names = dmlc::BeginPtr(ret->ret_vec_charp); } else if (magic == 0x4d554e93 || magic == 0x934e554d) { // first bytes of npy format - *out_size = 1; - ret->ret_handles.resize(1); - NDArray *ptr = new NDArray(); - *ptr = npy::load_array(fname); // Only supports local filesystem at this point in time - ret->ret_handles[0] = ptr; - *out_arr = dmlc::BeginPtr(ret->ret_handles); + *out_size = 1; + ret->ret_handles.resize(1); + NDArray* ptr = new NDArray(); + *ptr = npy::load_array(fname); // Only supports local filesystem at this point in time + ret->ret_handles[0] = ptr; + *out_arr = dmlc::BeginPtr(ret->ret_handles); } else { - std::vector data; - std::vector &names = ret->ret_vec_str; - { - std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); - mxnet::NDArray::Load(fi.get(), &data, &names); - } - ret->ret_handles.resize(data.size()); - for (size_t i = 0; i < data.size(); ++i) { - NDArray *ptr = 
new NDArray(); - *ptr = data[i]; - ret->ret_handles[i] = ptr; - } - ret->ret_vec_charp.resize(names.size()); - for (size_t i = 0; i < names.size(); ++i) { - ret->ret_vec_charp[i] = names[i].c_str(); - } - *out_size = static_cast(data.size()); - *out_arr = dmlc::BeginPtr(ret->ret_handles); - *out_name_size = static_cast(names.size()); - *out_names = dmlc::BeginPtr(ret->ret_vec_charp); + std::vector data; + std::vector& names = ret->ret_vec_str; + { + std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); + mxnet::NDArray::Load(fi.get(), &data, &names); + } + ret->ret_handles.resize(data.size()); + for (size_t i = 0; i < data.size(); ++i) { + NDArray* ptr = new NDArray(); + *ptr = data[i]; + ret->ret_handles[i] = ptr; + } + ret->ret_vec_charp.resize(names.size()); + for (size_t i = 0; i < names.size(); ++i) { + ret->ret_vec_charp[i] = names[i].c_str(); + } + *out_size = static_cast(data.size()); + *out_arr = dmlc::BeginPtr(ret->ret_handles); + *out_name_size = static_cast(names.size()); + *out_names = dmlc::BeginPtr(ret->ret_vec_charp); } API_END(); } -int MXNDArrayLoadFromBuffer(const void *ndarray_buffer, +int MXNDArrayLoadFromBuffer(const void* ndarray_buffer, size_t size, - uint32_t *out_size, + uint32_t* out_size, NDArrayHandle** out_arr, - uint32_t *out_name_size, + uint32_t* out_name_size, const char*** out_names) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); ret->ret_vec_str.clear(); API_BEGIN(); CHECK_NOTNULL(ndarray_buffer); std::vector data; - std::vector &names = ret->ret_vec_str; + std::vector& names = ret->ret_vec_str; { - std::unique_ptr fi(new dmlc::MemoryFixedSizeStream( - const_cast(ndarray_buffer), size)); + std::unique_ptr fi( + new dmlc::MemoryFixedSizeStream(const_cast(ndarray_buffer), size)); mxnet::NDArray::Load(fi.get(), &data, &names); } ret->ret_handles.resize(data.size()); for (size_t i = 0; i < data.size(); ++i) { - NDArray *ptr = new NDArray(); - *ptr = data[i]; + NDArray* ptr = new NDArray(); + *ptr = data[i]; ret->ret_handles[i] = ptr; } ret->ret_vec_charp.resize(names.size()); for (size_t i = 0; i < names.size(); ++i) { ret->ret_vec_charp[i] = names[i].c_str(); } - *out_size = static_cast(data.size()); - *out_arr = dmlc::BeginPtr(ret->ret_handles); + *out_size = static_cast(data.size()); + *out_arr = dmlc::BeginPtr(ret->ret_handles); *out_name_size = static_cast(names.size()); - *out_names = dmlc::BeginPtr(ret->ret_vec_charp); + *out_names = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } @@ -2130,8 +2392,11 @@ int MXNDArrayFree(NDArrayHandle handle) { API_END(); } -template -void SliceArray(NDArrayHandle handle, dtype slice_begin, dtype slice_end, NDArray* ptr, +template +void SliceArray(NDArrayHandle handle, + dtype slice_begin, + dtype slice_end, + NDArray* ptr, NDArrayHandle* out) { *ptr = static_cast(handle)->SliceWithRecord(slice_begin, slice_end); *out = ptr; @@ -2140,8 +2405,8 @@ void SliceArray(NDArrayHandle handle, dtype slice_begin, dtype slice_end, NDArra int MXNDArraySlice(NDArrayHandle handle, uint32_t slice_begin, uint32_t slice_end, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); + NDArrayHandle* out) { + NDArray* ptr = new NDArray(); API_BEGIN(); SliceArray(handle, slice_begin, slice_end, ptr, out); API_END_HANDLE_ERROR(delete ptr); @@ -2150,55 +2415,45 @@ int MXNDArraySlice(NDArrayHandle handle, int MXNDArraySlice64(NDArrayHandle handle, int64_t slice_begin, int64_t slice_end, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); + NDArrayHandle* 
out) { + NDArray* ptr = new NDArray(); API_BEGIN(); SliceArray(handle, slice_begin, slice_end, ptr, out); API_END_HANDLE_ERROR(delete ptr); } -int MXNDArrayAt(NDArrayHandle handle, - uint32_t idx, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); +int MXNDArrayAt(NDArrayHandle handle, uint32_t idx, NDArrayHandle* out) { + NDArray* ptr = new NDArray(); API_BEGIN(); *ptr = static_cast(handle)->AtWithRecord(idx); *out = ptr; API_END_HANDLE_ERROR(delete ptr); } -int MXNDArrayAt64(NDArrayHandle handle, - int64_t idx, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); +int MXNDArrayAt64(NDArrayHandle handle, int64_t idx, NDArrayHandle* out) { + NDArray* ptr = new NDArray(); API_BEGIN(); *ptr = static_cast(handle)->AtWithRecord(idx); *out = ptr; API_END_HANDLE_ERROR(delete ptr); } -int MXNDArrayReshape(NDArrayHandle handle, - int ndim, - int *dims, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); +int MXNDArrayReshape(NDArrayHandle handle, int ndim, int* dims, NDArrayHandle* out) { + NDArray* ptr = new NDArray(); API_BEGIN(); - NDArray *arr = static_cast(handle); - mxnet::TShape new_shape(dims, dims+ndim); + NDArray* arr = static_cast(handle); + mxnet::TShape new_shape(dims, dims + ndim); int size = 1; - int pos = -1; + int pos = -1; for (int i = 0; i < ndim; ++i) { int dim = dims[i]; if (dim == -1) { - CHECK_EQ(pos, -1) - << "Invalid new shape " << new_shape - << ": more than one dimensions are -1"; + CHECK_EQ(pos, -1) << "Invalid new shape " << new_shape << ": more than one dimensions are -1"; pos = i; } else { if (dim == 0) { - CHECK_LT(i, arr->shape().ndim()) - << "Invalid new shape " << new_shape - << ": 0 dimension exceeds original shape " << arr->shape(); + CHECK_LT(i, arr->shape().ndim()) << "Invalid new shape " << new_shape + << ": 0 dimension exceeds original shape " << arr->shape(); dim = arr->shape()[i]; } size *= dim; @@ -2214,24 +2469,23 @@ int MXNDArrayReshape(NDArrayHandle handle, } int MXNDArrayReshape64(NDArrayHandle handle, - int ndim, - dim_t *dims, - bool reverse, - NDArrayHandle *out) { - NDArray *ptr = new NDArray(); + int ndim, + dim_t* dims, + bool reverse, + NDArrayHandle* out) { + NDArray* ptr = new NDArray(); API_BEGIN(); - NDArray *arr = static_cast(handle); - mxnet::Tuple shape(dims, dims+ndim); + NDArray* arr = static_cast(handle); + mxnet::Tuple shape(dims, dims + ndim); mxnet::TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse); - *ptr = arr->ReshapeWithRecord(new_shape); - *out = ptr; + *ptr = arr->ReshapeWithRecord(new_shape); + *out = ptr; API_END_HANDLE_ERROR(delete ptr); } -int MXNDArrayGetStorageType(NDArrayHandle handle, - int *out_storage_type) { +int MXNDArrayGetStorageType(NDArrayHandle handle, int* out_storage_type) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); if (!arr->is_none()) { *out_storage_type = arr->storage_type(); } else { @@ -2240,8 +2494,10 @@ int MXNDArrayGetStorageType(NDArrayHandle handle, API_END(); } -template -inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim, +template +inline void GetShape(NDArrayHandle handle, + const dtype** out_pdata, + int* out_dim, MXAPIThreadLocalEntry* ret) { NDArray* arr = static_cast(handle); if (!arr->is_none()) { @@ -2255,9 +2511,9 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim } if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { - CHECK_LT(s.Size(), (int64_t{1} << 31) - 1) << - "[Get Shape] Size of tensor you are trying to allocate is larger 
than " - "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; + CHECK_LT(s.Size(), (int64_t{1} << 31) - 1) + << "[Get Shape] Size of tensor you are trying to allocate is larger than " + "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; } if (!Imperative::Get()->is_np_shape()) { @@ -2265,7 +2521,7 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim } *out_dim = s.ndim(); if (s.ndim() >= 0) { - std::vector &buffer = ret->arg_shape_buffer_ex; + std::vector& buffer = ret->arg_shape_buffer_ex; buffer.resize(s.ndim()); mxnet::ShapeTypeCast(s.begin(), s.end(), buffer.data()); *out_pdata = buffer.data(); @@ -2279,28 +2535,23 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim } } -int MXNDArrayGetShape(NDArrayHandle handle, - int *out_dim, - const int **out_pdata) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXNDArrayGetShape(NDArrayHandle handle, int* out_dim, const int** out_pdata) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); GetShape(handle, out_pdata, out_dim, ret); API_END(); } -int MXNDArrayGetShape64(NDArrayHandle handle, - int *out_dim, - const int64_t **out_pdata) { - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); +int MXNDArrayGetShape64(NDArrayHandle handle, int* out_dim, const int64_t** out_pdata) { + MXAPIThreadLocalEntry* ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); GetShape(handle, out_pdata, out_dim, ret); API_END(); } -int MXNDArrayGetData(NDArrayHandle handle, - void **out_pdata) { +int MXNDArrayGetData(NDArrayHandle handle, void** out_pdata) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); #if MXNET_USE_ONEDNN == 1 if (arr->IsMKLDNNData()) { arr->Reorder2DefaultAsync(); @@ -2315,37 +2566,34 @@ int MXNDArrayGetData(NDArrayHandle handle, API_END(); } -int MXNDArrayToDLPack(NDArrayHandle handle, - DLManagedTensorHandle *out_dlpack) { +int MXNDArrayToDLPack(NDArrayHandle handle, DLManagedTensorHandle* out_dlpack) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out_dlpack = arr->ToDLPack(); + NDArray* arr = static_cast(handle); + *out_dlpack = arr->ToDLPack(); API_END(); } int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack, const bool transient_handle, - NDArrayHandle *out_handle) { + NDArrayHandle* out_handle) { API_BEGIN(); - *out_handle = new NDArray(NDArray::FromDLPack( - static_cast(dlpack), - transient_handle)); + *out_handle = + new NDArray(NDArray::FromDLPack(static_cast(dlpack), transient_handle)); API_END(); } int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) { API_BEGIN(); if (dlpack != nullptr) { - DLManagedTensor *p_dlpack = static_cast(dlpack); + DLManagedTensor* p_dlpack = static_cast(dlpack); p_dlpack->deleter(p_dlpack); } API_END(); } -int MXNDArrayGetDType(NDArrayHandle handle, - int *out_dtype) { +int MXNDArrayGetDType(NDArrayHandle handle, int* out_dtype) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); if (!arr->is_none()) { *out_dtype = arr->dtype(); } else { @@ -2354,12 +2602,10 @@ int MXNDArrayGetDType(NDArrayHandle handle, API_END(); } -int MXNDArrayGetAuxType(NDArrayHandle handle, - uint32_t i, - int *out_type) { +int MXNDArrayGetAuxType(NDArrayHandle handle, uint32_t i, int* out_type) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out_type = arr->aux_type(i); + NDArray* arr = static_cast(handle); + *out_type = arr->aux_type(i); API_END(); } @@ -2368,12 +2614,10 @@ int 
MXNDArrayGetAuxType(NDArrayHandle handle, * in the form of an NDArray of default storage type. * This function blocks. Do not use it in performance critical code. */ -int MXNDArrayGetAuxNDArray(NDArrayHandle handle, - uint32_t i, - NDArrayHandle *out) { +int MXNDArrayGetAuxNDArray(NDArrayHandle handle, uint32_t i, NDArrayHandle* out) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out = new NDArray(arr->aux_ndarray(i)); + NDArray* arr = static_cast(handle); + *out = new NDArray(arr->aux_ndarray(i)); API_END(); } @@ -2382,35 +2626,31 @@ int MXNDArrayGetAuxNDArray(NDArrayHandle handle, * in the form of an NDArray of default storage type. * This function blocks. Do not use it in performance critical code. */ -int MXNDArrayGetDataNDArray(NDArrayHandle handle, - NDArrayHandle *out) { +int MXNDArrayGetDataNDArray(NDArrayHandle handle, NDArrayHandle* out) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out = new NDArray(arr->data_ndarray()); + NDArray* arr = static_cast(handle); + *out = new NDArray(arr->data_ndarray()); API_END(); } -int MXNDArrayGetContext(NDArrayHandle handle, - int *out_dev_type, - int *out_dev_id) { +int MXNDArrayGetContext(NDArrayHandle handle, int* out_dev_type, int* out_dev_id) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); if (!arr->is_none()) { - const Context &ctx = arr->ctx(); - *out_dev_type = ctx.dev_type; - *out_dev_id = ctx.dev_id; + const Context& ctx = arr->ctx(); + *out_dev_type = ctx.dev_type; + *out_dev_id = ctx.dev_id; } else { *out_dev_type = 0; - *out_dev_id = 0; + *out_dev_id = 0; } API_END(); } - -int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) { +int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle* out) { API_BEGIN(); - NDArray *arr = static_cast(handle); - NDArray ret = arr->grad(); + NDArray* arr = static_cast(handle); + NDArray ret = arr->grad(); if (ret.is_none()) { *out = nullptr; } else { @@ -2419,80 +2659,82 @@ int MXNDArrayGetGrad(NDArrayHandle handle, NDArrayHandle *out) { API_END(); } -int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle *out) { +int MXNDArrayDetach(NDArrayHandle handle, NDArrayHandle* out) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out = new NDArray(arr->Detach()); + NDArray* arr = static_cast(handle); + *out = new NDArray(arr->Detach()); API_END(); } int MXNDArraySetGradState(NDArrayHandle handle, int state) { API_BEGIN(); - NDArray *arr = static_cast(handle); + NDArray* arr = static_cast(handle); arr->set_fresh_out_grad(static_cast(state)); API_END(); } -int MXNDArrayGetGradState(NDArrayHandle handle, int *out) { +int MXNDArrayGetGradState(NDArrayHandle handle, int* out) { API_BEGIN(); - NDArray *arr = static_cast(handle); - *out = arr->fresh_out_grad(); + NDArray* arr = static_cast(handle); + *out = arr->fresh_out_grad(); API_END(); } -int MXListFunctions(uint32_t *out_size, - FunctionHandle **out_array) { +int MXListFunctions(uint32_t* out_size, FunctionHandle** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); - *out_size = static_cast(vec.size()); + auto& vec = dmlc::Registry::List(); + *out_size = static_cast(vec.size()); *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int MXGetFunction(const char *name, - FunctionHandle *out) { +int MXGetFunction(const char* name, FunctionHandle* out) { API_BEGIN(); *out = dmlc::Registry::Find(name); API_END(); } int MXFuncGetInfo(FunctionHandle fun, - const char **name, - const char **description, - uint32_t *num_args, - const char 
***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - return MXAPIGetFunctionRegInfo(static_cast(fun), - name, description, num_args, - arg_names, arg_type_infos, arg_descriptions, + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions, + const char** return_type) { + return MXAPIGetFunctionRegInfo(static_cast(fun), + name, + description, + num_args, + arg_names, + arg_type_infos, + arg_descriptions, return_type); } int MXFuncDescribe(FunctionHandle fun, - uint32_t *num_use_vars, - uint32_t *num_scalars, - uint32_t *num_mutate_vars, - int *type_mask) { - API_BEGIN(); - auto *f = static_cast(fun); - *num_use_vars = f->num_use_vars; - *num_scalars = f->num_scalars; + uint32_t* num_use_vars, + uint32_t* num_scalars, + uint32_t* num_mutate_vars, + int* type_mask) { + API_BEGIN(); + auto* f = static_cast(fun); + *num_use_vars = f->num_use_vars; + *num_scalars = f->num_scalars; *num_mutate_vars = f->num_mutate_vars; - *type_mask = f->type_mask; + *type_mask = f->type_mask; API_END(); } int MXFuncInvoke(FunctionHandle fun, - NDArrayHandle *use_vars, - float *scalar_args, - NDArrayHandle *mutate_vars, + NDArrayHandle* use_vars, + float* scalar_args, + NDArrayHandle* mutate_vars, int num_params, - char **param_keys, - char **param_vals) { + char** param_keys, + char** param_vals) { API_BEGIN(); - auto *f = static_cast(fun); + auto* f = static_cast(fun); f->body((NDArray**)(use_vars), // NOLINT(*) scalar_args, (NDArray**)(mutate_vars), // NOLINT(*) @@ -2505,38 +2747,36 @@ int MXFuncInvoke(FunctionHandle fun, //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- -int MXListDataIters(uint32_t *out_size, - DataIterCreator **out_array) { +int MXListDataIters(uint32_t* out_size, DataIterCreator** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); - *out_size = static_cast(vec.size()); + auto& vec = dmlc::Registry::List(); + *out_size = static_cast(vec.size()); *out_array = (DataIterCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXDataIterGetIterInfo(DataIterCreator creator, - const char **name, - const char **description, - uint32_t *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions) { - DataIteratorReg *e = static_cast(creator); - return MXAPIGetFunctionRegInfo(e, name, description, num_args, - arg_names, arg_type_infos, arg_descriptions, - nullptr); + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions) { + DataIteratorReg* e = static_cast(creator); + return MXAPIGetFunctionRegInfo( + e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions, nullptr); } int MXDataIterCreateIter(DataIterCreator creator, uint32_t num_param, - const char **keys, - const char **vals, - DataIterHandle *out) { - IIterator *iter = nullptr; - API_BEGIN(); - DataIteratorReg *e = static_cast(creator); - iter = e->body(); - std::vector > kwargs; + const char** keys, + const char** vals, + DataIterHandle* out) { + IIterator* iter = nullptr; + API_BEGIN(); + DataIteratorReg* e = static_cast(creator); + iter = e->body(); + std::vector> kwargs; for (uint32_t i = 0; i < num_param; ++i) { kwargs.emplace_back(std::string(keys[i]), std::string(vals[i])); } @@ -2547,36 +2787,39 @@ int 
MXDataIterCreateIter(DataIterCreator creator, int MXDataIterFree(DataIterHandle handle) { API_BEGIN(); - delete static_cast *>(handle); + delete static_cast*>(handle); API_END(); } int MXDataIterBeforeFirst(DataIterHandle handle) { API_BEGIN(); - static_cast* >(handle)->BeforeFirst(); + static_cast*>(handle)->BeforeFirst(); API_END(); } -int MXDataIterGetLenHint(DataIterHandle handle, int64_t *len) { +int MXDataIterGetLenHint(DataIterHandle handle, int64_t* len) { API_BEGIN(); - *len = static_cast* >(handle)->GetLenHint(); + *len = static_cast*>(handle)->GetLenHint(); API_END(); } -int MXDataIterNext(DataIterHandle handle, int *out) { +int MXDataIterNext(DataIterHandle handle, int* out) { API_BEGIN(); - *out = static_cast* >(handle)->Next(); + *out = static_cast*>(handle)->Next(); API_END(); } -int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { +int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle* out) { API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - bool no_label = db.data.size() < 2U; - NDArray* pndarray = new NDArray(); + const DataBatch& db = static_cast*>(handle)->Value(); + bool no_label = db.data.size() < 2U; + NDArray* pndarray = new NDArray(); // temp hack to make label 1D // TODO(tianjun) make label 1D when label_width=0 - mxnet::TShape shape = no_label ? TShape({1, }) : db.data[1].shape(); + mxnet::TShape shape = no_label ? TShape({ + 1, + }) + : db.data[1].shape(); if (no_label || shape.Size() < 1) { // it's possible that label is not available and not required // but we need to bypass the invalid copy @@ -2590,26 +2833,27 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_END(); } -int MXDataIterGetItems(DataIterHandle handle, int* num_outputs, NDArrayHandle **outputs) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXDataIterGetItems(DataIterHandle handle, int* num_outputs, NDArrayHandle** outputs) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); + const DataBatch& db = static_cast*>(handle)->Value(); std::vector ndoutputs; ndoutputs.reserve(db.data.size()); if (*outputs == nullptr) { *num_outputs = db.data.size(); - for (int i = 0; i < *num_outputs; ++i) ndoutputs.push_back(new NDArray()); + for (int i = 0; i < *num_outputs; ++i) + ndoutputs.push_back(new NDArray()); } else { - CHECK_EQ(*num_outputs, db.data.size()) - << "MXDataIterGetItems expects " << db.data.size() << " outputs, but " - << *num_outputs << " was given."; + CHECK_EQ(*num_outputs, db.data.size()) << "MXDataIterGetItems expects " << db.data.size() + << " outputs, but " << *num_outputs << " was given."; for (int i = 0; i < *num_outputs; ++i) { ndoutputs.push_back(reinterpret_cast((*outputs)[i])); } } // copy outputs - for (int i = 0; i < *num_outputs; ++i) *ndoutputs[i] = db.data[i]; + for (int i = 0; i < *num_outputs; ++i) + *ndoutputs[i] = db.data[i]; if (*outputs == nullptr) { ret->ret_handles.clear(); @@ -2622,67 +2866,65 @@ int MXDataIterGetItems(DataIterHandle handle, int* num_outputs, NDArrayHandle ** API_END(); } -int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) { +int MXDataIterGetIndex(DataIterHandle handle, uint64_t** out_index, uint64_t* out_size) { API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - *out_size = db.index.size(); - *out_index = const_cast(db.index.data()); + const DataBatch& db = static_cast*>(handle)->Value(); + *out_size = db.index.size(); 
+ *out_index = const_cast(db.index.data()); API_END(); } -int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { +int MXDataIterGetData(DataIterHandle handle, NDArrayHandle* out) { API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - NDArray* pndarray = new NDArray(); - *pndarray = db.data[0]; - *out = pndarray; + const DataBatch& db = static_cast*>(handle)->Value(); + NDArray* pndarray = new NDArray(); + *pndarray = db.data[0]; + *out = pndarray; API_END(); } -int MXDataIterGetPadNum(DataIterHandle handle, int *pad) { +int MXDataIterGetPadNum(DataIterHandle handle, int* pad) { API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - *pad = db.num_batch_padd; + const DataBatch& db = static_cast*>(handle)->Value(); + *pad = db.num_batch_padd; API_END(); } -int MXListDatasets(uint32_t *out_size, - DatasetCreator **out_array) { +int MXListDatasets(uint32_t* out_size, DatasetCreator** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); - *out_size = static_cast(vec.size()); + auto& vec = dmlc::Registry::List(); + *out_size = static_cast(vec.size()); *out_array = (DatasetCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXDatasetCreateDataset(DatasetCreator handle, uint32_t num_param, - const char **keys, - const char **vals, - DatasetHandle *out) { - Dataset *dataset = nullptr; + const char** keys, + const char** vals, + DatasetHandle* out) { + Dataset* dataset = nullptr; API_BEGIN(); - DatasetReg *e = static_cast(handle); - std::vector > kwargs; + DatasetReg* e = static_cast(handle); + std::vector> kwargs; for (uint32_t i = 0; i < num_param; ++i) { kwargs.emplace_back(std::string(keys[i]), std::string(vals[i])); } dataset = e->body(kwargs); - *out = new std::shared_ptr(dataset); + *out = new std::shared_ptr(dataset); API_END_HANDLE_ERROR(delete dataset); } int MXDatasetGetDatasetInfo(DatasetCreator creator, - const char **name, - const char **description, - uint32_t *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions) { - DatasetReg *e = static_cast(creator); - return MXAPIGetFunctionRegInfo(e, name, description, num_args, - arg_names, arg_type_infos, arg_descriptions, - nullptr); + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions) { + DatasetReg* e = static_cast(creator); + return MXAPIGetFunctionRegInfo( + e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions, nullptr); } int MXDatasetFree(DatasetHandle handle) { @@ -2691,37 +2933,38 @@ int MXDatasetFree(DatasetHandle handle) { API_END(); } -int MXDatasetGetLen(DatasetHandle handle, uint64_t *out) { +int MXDatasetGetLen(DatasetHandle handle, uint64_t* out) { API_BEGIN(); uint64_t len = (*static_cast*>(handle))->GetLen(); - *out = len; + *out = len; API_END(); } int MXDatasetGetItems(DatasetHandle handle, uint64_t index, int* num_outputs, - NDArrayHandle **outputs) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + NDArrayHandle** outputs) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::vector res; CHECK((*static_cast*>(handle))->GetItem(index, &res)) - << "Error getting item at index: " << index; + << "Error getting item at index: " << index; std::vector ndoutputs; ndoutputs.reserve(res.size()); if (*outputs == nullptr) { *num_outputs = res.size(); - for (int i = 0; i < *num_outputs; ++i) ndoutputs.push_back(new NDArray()); + for (int i 
= 0; i < *num_outputs; ++i) + ndoutputs.push_back(new NDArray()); } else { - CHECK_EQ(*num_outputs, res.size()) - << "MXDatasetGetItems expects " << res.size() << " outputs, but " - << *num_outputs << " was given."; + CHECK_EQ(*num_outputs, res.size()) << "MXDatasetGetItems expects " << res.size() + << " outputs, but " << *num_outputs << " was given."; for (int i = 0; i < *num_outputs; ++i) { ndoutputs.push_back(reinterpret_cast((*outputs)[i])); } } // copy ndarrays - for (int i = 0; i < *num_outputs; ++i) *(ndoutputs[i]) = res[i]; + for (int i = 0; i < *num_outputs; ++i) + *(ndoutputs[i]) = res[i]; if (*outputs == nullptr) { ret->ret_handles.clear(); @@ -2734,54 +2977,52 @@ int MXDatasetGetItems(DatasetHandle handle, API_END(); } -int MXListBatchifyFunctions(uint32_t *out_size, - BatchifyFunctionCreator **out_array) { +int MXListBatchifyFunctions(uint32_t* out_size, BatchifyFunctionCreator** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); - *out_size = static_cast(vec.size()); + auto& vec = dmlc::Registry::List(); + *out_size = static_cast(vec.size()); *out_array = (BatchifyFunctionCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle, uint32_t num_param, - const char **keys, - const char **vals, - BatchifyFunctionHandle *out) { - BatchifyFunction *bf = nullptr; + const char** keys, + const char** vals, + BatchifyFunctionHandle* out) { + BatchifyFunction* bf = nullptr; API_BEGIN(); - BatchifyFunctionReg *e = static_cast(handle); - std::vector > kwargs; + BatchifyFunctionReg* e = static_cast(handle); + std::vector> kwargs; for (uint32_t i = 0; i < num_param; ++i) { kwargs.emplace_back(std::string(keys[i]), std::string(vals[i])); } - bf = e->body(kwargs); + bf = e->body(kwargs); *out = new BatchifyFunctionPtr(bf); API_END_HANDLE_ERROR(delete bf); } int MXBatchifyFunctionGetFunctionInfo(BatchifyFunctionCreator creator, - const char **name, - const char **description, - uint32_t *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions) { - BatchifyFunctionReg *e = static_cast(creator); - return MXAPIGetFunctionRegInfo(e, name, description, num_args, - arg_names, arg_type_infos, arg_descriptions, - nullptr); + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions) { + BatchifyFunctionReg* e = static_cast(creator); + return MXAPIGetFunctionRegInfo( + e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions, nullptr); } int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle, int batch_size, int num_output, - NDArrayHandle *inputs, - NDArrayHandle **outputs) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + NDArrayHandle* inputs, + NDArrayHandle** outputs) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); CHECK_GT(batch_size, 0); CHECK_GT(num_output, 0); - std::vector > ndinputs; + std::vector> ndinputs; ndinputs.reserve(batch_size); int pos = 0; for (int i = 0; i < batch_size; ++i) { @@ -2795,22 +3036,23 @@ int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle, } std::vector res; CHECK((*static_cast(handle))->Batchify(ndinputs, &res)) - << "Error call batchify with " << ndinputs.size() << " inputs"; + << "Error call batchify with " << ndinputs.size() << " inputs"; std::vector ndoutputs; ndoutputs.reserve(res.size()); if (*outputs == nullptr) { - for (int i = 0; i < num_output; ++i) 
ndoutputs.push_back(new NDArray()); + for (int i = 0; i < num_output; ++i) + ndoutputs.push_back(new NDArray()); } else { - CHECK_EQ(num_output, res.size()) - << "MXBatchifyFunctionInvoke expects " << res.size() << " outputs, but " - << num_output << " was given."; + CHECK_EQ(num_output, res.size()) << "MXBatchifyFunctionInvoke expects " << res.size() + << " outputs, but " << num_output << " was given."; for (int i = 0; i < num_output; ++i) { ndoutputs.push_back(reinterpret_cast((*outputs)[i])); } } // copy ndarrays - for (int i = 0; i < num_output; ++i) *(ndoutputs[i]) = res[i]; + for (int i = 0; i < num_output; ++i) + *(ndoutputs[i]) = res[i]; if (*outputs == nullptr) { ret->ret_handles.clear(); @@ -2832,20 +3074,21 @@ int MXBatchifyFunctionFree(BatchifyFunctionHandle handle) { // Part 6: basic KVStore interface //-------------------------------------------- -int MXKVStoreCreate(const char *type, - KVStoreHandle *out) { +int MXKVStoreCreate(const char* type, KVStoreHandle* out) { API_BEGIN(); *out = KVStore::Create(type); API_END(); } -int MXKVStoreSetGradientCompression(KVStoreHandle handle, uint32_t num_params, - const char** keys, const char** vals) { +int MXKVStoreSetGradientCompression(KVStoreHandle handle, + uint32_t num_params, + const char** keys, + const char** vals) { API_BEGIN(); - std::vector > params; + std::vector> params; for (uint32_t i = 0; i < num_params; ++i) { std::pair p; - p.first = keys[i]; + p.first = keys[i]; p.second = vals[i]; params.push_back(p); } @@ -2859,10 +3102,7 @@ int MXKVStoreFree(KVStoreHandle handle) { API_END(); } -int MXKVStoreInit(KVStoreHandle handle, - uint32_t num, - const int* keys, - NDArrayHandle* vals) { +int MXKVStoreInit(KVStoreHandle handle, uint32_t num, const int* keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); @@ -2874,10 +3114,7 @@ int MXKVStoreInit(KVStoreHandle handle, API_END(); } -int MXKVStoreInitEx(KVStoreHandle handle, - uint32_t num, - const char** keys, - NDArrayHandle* vals) { +int MXKVStoreInitEx(KVStoreHandle handle, uint32_t num, const char** keys, NDArrayHandle* vals) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); @@ -2906,10 +3143,10 @@ int MXKVStorePush(KVStoreHandle handle, } int MXKVStorePushEx(KVStoreHandle handle, - uint32_t num, - const char** keys, - NDArrayHandle* vals, - int priority) { + uint32_t num, + const char** keys, + NDArrayHandle* vals, + int priority) { API_BEGIN(); std::vector v_keys(num); std::vector v_vals(num); @@ -2968,14 +3205,13 @@ int MXKVStoreBroadcast(KVStoreHandle handle, std::vector v_outs(onum); for (mx_uint i = 0; i < vnum; ++i) { v_vkeys[i] = vkeys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } for (mx_uint i = 0; i < onum; ++i) { v_okeys[i] = okeys[i]; - v_outs[i] = static_cast(outs[i]); + v_outs[i] = static_cast(outs[i]); } - static_cast(handle)->Broadcast(v_vkeys, v_okeys, v_vals, v_outs, - priority); + static_cast(handle)->Broadcast(v_vkeys, v_okeys, v_vals, v_outs, priority); API_END(); } @@ -2994,14 +3230,13 @@ int MXKVStoreBroadcastEx(KVStoreHandle handle, std::vector v_outs(onum); for (mx_uint i = 0; i < vnum; ++i) { v_vkeys[i] = vkeys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } for (mx_uint i = 0; i < onum; ++i) { v_okeys[i] = okeys[i]; - v_outs[i] = static_cast(outs[i]); + v_outs[i] = static_cast(outs[i]); } - static_cast(handle)->Broadcast(v_vkeys, v_okeys, v_vals, v_outs, - priority); + static_cast(handle)->Broadcast(v_vkeys, v_okeys, v_vals, 
v_outs, priority); API_END(); } @@ -3020,14 +3255,13 @@ int MXKVStorePushPull(KVStoreHandle handle, std::vector v_outs(onum); for (mx_uint i = 0; i < vnum; ++i) { v_vkeys[i] = vkeys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } for (mx_uint i = 0; i < onum; ++i) { v_okeys[i] = okeys[i]; - v_outs[i] = static_cast(outs[i]); + v_outs[i] = static_cast(outs[i]); } - static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, - priority); + static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, priority); API_END(); } @@ -3046,14 +3280,13 @@ int MXKVStorePushPullEx(KVStoreHandle handle, std::vector v_outs(onum); for (mx_uint i = 0; i < vnum; ++i) { v_vkeys[i] = vkeys[i]; - v_vals[i] = *static_cast(vals[i]); + v_vals[i] = *static_cast(vals[i]); } for (mx_uint i = 0; i < onum; ++i) { v_okeys[i] = okeys[i]; - v_outs[i] = static_cast(outs[i]); + v_outs[i] = static_cast(outs[i]); } - static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, - priority); + static_cast(handle)->PushPull(v_vkeys, v_okeys, v_vals, v_outs, priority); API_END(); } @@ -3102,8 +3335,8 @@ int MXKVStorePullRowSparse(KVStoreHandle handle, std::vector> v_val_rowids(num); for (uint32_t i = 0; i < num; ++i) { v_keys[i] = keys[i]; - v_val_rowids[i] = std::make_pair(static_cast(vals[i]), - *static_cast(row_ids[i])); + v_val_rowids[i] = + std::make_pair(static_cast(vals[i]), *static_cast(row_ids[i])); } static_cast(handle)->PullRowSparse(v_keys, v_val_rowids, priority); API_END(); @@ -3120,32 +3353,28 @@ int MXKVStorePullRowSparseEx(KVStoreHandle handle, std::vector> v_val_rowids(num); for (uint32_t i = 0; i < num; ++i) { v_keys[i] = keys[i]; - v_val_rowids[i] = std::make_pair(static_cast(vals[i]), - *static_cast(row_ids[i])); + v_val_rowids[i] = + std::make_pair(static_cast(vals[i]), *static_cast(row_ids[i])); } static_cast(handle)->PullRowSparse(v_keys, v_val_rowids, priority); API_END(); } -void MXKVStoreSetUpdaterImpl(KVStoreHandle handle, - MXKVStoreUpdater updater, - void* updater_handle) { - MXKVStoreUpdater * updater_temp = updater; - void* updater_handle_temp = updater_handle; - std::function updt - = [updater_temp, updater_handle_temp](int key, const NDArray& recv, NDArray* local) { - NDArray* recv_copy = new NDArray(); - *recv_copy = recv; - NDArray* local_copy = new NDArray(); - *local_copy = *local; - updater_temp(key, recv_copy, local_copy, updater_handle_temp); - }; +void MXKVStoreSetUpdaterImpl(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { + MXKVStoreUpdater* updater_temp = updater; + void* updater_handle_temp = updater_handle; + std::function updt = + [updater_temp, updater_handle_temp](int key, const NDArray& recv, NDArray* local) { + NDArray* recv_copy = new NDArray(); + *recv_copy = recv; + NDArray* local_copy = new NDArray(); + *local_copy = *local; + updater_temp(key, recv_copy, local_copy, updater_handle_temp); + }; static_cast(handle)->set_updater(updt); } -int MXKVStoreSetUpdater(KVStoreHandle handle, - MXKVStoreUpdater updater, - void* updater_handle) { +int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { API_BEGIN(); MXKVStoreSetUpdaterImpl(handle, updater, updater_handle); API_END(); @@ -3159,28 +3388,28 @@ int MXKVStoreSetUpdaterEx(KVStoreHandle handle, // set updater with int keys MXKVStoreSetUpdaterImpl(handle, updater, updater_handle); // set updater with string keys - MXKVStoreStrUpdater * updater_temp = str_updater; - void* updater_handle_temp = updater_handle; - std::function 
updt - = [updater_temp, updater_handle_temp] - (const std::string& key, const NDArray& recv, NDArray* local) { - NDArray* recv_copy = new NDArray(); - *recv_copy = recv; - NDArray* local_copy = new NDArray(); - *local_copy = *local; - updater_temp(key.c_str(), recv_copy, local_copy, updater_handle_temp); - }; + MXKVStoreStrUpdater* updater_temp = str_updater; + void* updater_handle_temp = updater_handle; + std::function updt = + [updater_temp, updater_handle_temp]( + const std::string& key, const NDArray& recv, NDArray* local) { + NDArray* recv_copy = new NDArray(); + *recv_copy = recv; + NDArray* local_copy = new NDArray(); + *local_copy = *local; + updater_temp(key.c_str(), recv_copy, local_copy, updater_handle_temp); + }; static_cast(handle)->set_updater(updt); API_END(); } -int MXKVStoreGetRank(KVStoreHandle handle, int *rank) { +int MXKVStoreGetRank(KVStoreHandle handle, int* rank) { API_BEGIN(); *rank = static_cast(handle)->get_rank(); API_END(); } -int MXKVStoreGetGroupSize(KVStoreHandle handle, int *size) { +int MXKVStoreGetGroupSize(KVStoreHandle handle, int* size) { API_BEGIN(); *size = static_cast(handle)->get_group_size(); API_END(); @@ -3192,16 +3421,13 @@ int MXKVStoreBarrier(KVStoreHandle handle) { API_END(); } -int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, - const int barrier_before_exit) { +int MXKVStoreSetBarrierBeforeExit(KVStoreHandle handle, const int barrier_before_exit) { API_BEGIN(); static_cast(handle)->set_barrier_before_exit(barrier_before_exit); API_END(); } -int MXInitPSEnv(uint32_t num_vars, - const char **keys, - const char **vals) { +int MXInitPSEnv(uint32_t num_vars, const char** keys, const char** vals) { API_BEGIN(); std::unordered_map kwargs; for (uint32_t i = 0; i < num_vars; ++i) { @@ -3211,19 +3437,19 @@ int MXInitPSEnv(uint32_t num_vars, API_END(); } -int MXKVStoreIsWorkerNode(int *ret) { +int MXKVStoreIsWorkerNode(int* ret) { API_BEGIN(); *ret = KVStore::IsWorkerNode(); API_END(); } -int MXKVStoreIsServerNode(int *ret) { +int MXKVStoreIsServerNode(int* ret) { API_BEGIN(); *ret = KVStore::IsServerNode(); API_END(); } -int MXKVStoreIsSchedulerNode(int *ret) { +int MXKVStoreIsSchedulerNode(int* ret) { API_BEGIN(); *ret = KVStore::IsSchedulerNode(); API_END(); @@ -3231,28 +3457,24 @@ int MXKVStoreIsSchedulerNode(int *ret) { int MXKVStoreRunServer(KVStoreHandle handle, MXKVStoreServerController controller, - void *controller_handle) { + void* controller_handle) { API_BEGIN(); - MXKVStoreServerController *controller_temp = controller; - void *controller_handle_temp = controller_handle; + MXKVStoreServerController* controller_temp = controller; + void* controller_handle_temp = controller_handle; auto ctrl = [controller_temp, controller_handle_temp](int head, const std::string& body) { - controller_temp(head, body.c_str(), controller_handle_temp); + controller_temp(head, body.c_str(), controller_handle_temp); }; static_cast(handle)->RunServer(ctrl); API_END(); } -int MXKVStoreSendCommmandToServers(KVStoreHandle handle, - int cmd_id, - const char* cmd_body) { +int MXKVStoreSendCommmandToServers(KVStoreHandle handle, int cmd_id, const char* cmd_body) { API_BEGIN(); - static_cast(handle)->SendCommandToServers( - cmd_id, std::string(cmd_body)); + static_cast(handle)->SendCommandToServers(cmd_id, std::string(cmd_body)); API_END(); } -int MXKVStoreGetType(KVStoreHandle handle, - const char** type) { +int MXKVStoreGetType(KVStoreHandle handle, const char** type) { API_BEGIN(); *CHECK_NOTNULL(type) = static_cast(handle)->type().c_str(); API_END(); @@ 
-3260,7 +3482,7 @@ int MXKVStoreGetType(KVStoreHandle handle, int MXKVStoreGetNumDeadNode(KVStoreHandle handle, const int node_id, - int *number, + int* number, const int timeout_sec) { API_BEGIN(); *number = static_cast(handle)->get_num_dead_node(node_id, timeout_sec); @@ -3268,69 +3490,62 @@ int MXKVStoreGetNumDeadNode(KVStoreHandle handle, } struct MXRecordIOContext { - dmlc::RecordIOWriter *writer; - dmlc::RecordIOReader *reader; - dmlc::Stream *stream; - std::string *read_buff; + dmlc::RecordIOWriter* writer; + dmlc::RecordIOReader* reader; + dmlc::Stream* stream; + std::string* read_buff; }; -int MXRecordIOWriterCreate(const char *uri, - RecordIOHandle *out) { +int MXRecordIOWriterCreate(const char* uri, RecordIOHandle* out) { API_BEGIN(); - dmlc::Stream *stream = dmlc::Stream::Create(uri, "w"); - MXRecordIOContext *context = new MXRecordIOContext; - context->writer = new dmlc::RecordIOWriter(stream); - context->reader = nullptr; - context->stream = stream; - context->read_buff = nullptr; - *out = reinterpret_cast(context); + dmlc::Stream* stream = dmlc::Stream::Create(uri, "w"); + MXRecordIOContext* context = new MXRecordIOContext; + context->writer = new dmlc::RecordIOWriter(stream); + context->reader = nullptr; + context->stream = stream; + context->read_buff = nullptr; + *out = reinterpret_cast(context); API_END(); } int MXRecordIOWriterFree(RecordIOHandle handle) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); + MXRecordIOContext* context = reinterpret_cast(handle); delete context->writer; delete context->stream; delete context; API_END(); } -int MXRecordIOWriterWriteRecord(RecordIOHandle handle, - const char *buf, size_t size) { +int MXRecordIOWriterWriteRecord(RecordIOHandle handle, const char* buf, size_t size) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); + MXRecordIOContext* context = reinterpret_cast(handle); context->writer->WriteRecord(reinterpret_cast(buf), size); API_END(); } -int MXRecordIOWriterTell(RecordIOHandle handle, size_t *pos) { +int MXRecordIOWriterTell(RecordIOHandle handle, size_t* pos) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); - *pos = context->writer->Tell(); + MXRecordIOContext* context = reinterpret_cast(handle); + *pos = context->writer->Tell(); API_END(); } -int MXRecordIOReaderCreate(const char *uri, - RecordIOHandle *out) { +int MXRecordIOReaderCreate(const char* uri, RecordIOHandle* out) { API_BEGIN(); - dmlc::Stream *stream = dmlc::Stream::Create(uri, "r"); - MXRecordIOContext *context = new MXRecordIOContext; - context->reader = new dmlc::RecordIOReader(stream); - context->writer = nullptr; - context->stream = stream; - context->read_buff = new std::string(); - *out = reinterpret_cast(context); + dmlc::Stream* stream = dmlc::Stream::Create(uri, "r"); + MXRecordIOContext* context = new MXRecordIOContext; + context->reader = new dmlc::RecordIOReader(stream); + context->writer = nullptr; + context->stream = stream; + context->read_buff = new std::string(); + *out = reinterpret_cast(context); API_END(); } int MXRecordIOReaderFree(RecordIOHandle handle) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); + MXRecordIOContext* context = reinterpret_cast(handle); delete context->reader; delete context->stream; delete context->read_buff; @@ -3338,16 +3553,14 @@ int MXRecordIOReaderFree(RecordIOHandle handle) { API_END(); } -int MXRecordIOReaderReadRecord(RecordIOHandle handle, - char const **buf, size_t *size) { +int 
MXRecordIOReaderReadRecord(RecordIOHandle handle, char const** buf, size_t* size) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); + MXRecordIOContext* context = reinterpret_cast(handle); if (context->reader->NextRecord(context->read_buff)) { - *buf = context->read_buff->c_str(); + *buf = context->read_buff->c_str(); *size = context->read_buff->size(); } else { - *buf = nullptr; + *buf = nullptr; *size = 0; } API_END(); @@ -3355,31 +3568,37 @@ int MXRecordIOReaderReadRecord(RecordIOHandle handle, int MXRecordIOReaderSeek(RecordIOHandle handle, size_t pos) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); + MXRecordIOContext* context = reinterpret_cast(handle); context->reader->Seek(pos); API_END(); } -int MXRecordIOReaderTell(RecordIOHandle handle, size_t *pos) { +int MXRecordIOReaderTell(RecordIOHandle handle, size_t* pos) { API_BEGIN(); - MXRecordIOContext *context = - reinterpret_cast(handle); - *pos = context->reader->Tell(); + MXRecordIOContext* context = reinterpret_cast(handle); + *pos = context->reader->Tell(); API_END(); } -int MXRtcCreate(char* name, uint32_t num_input, uint32_t num_output, - char** input_names, char** output_names, - NDArrayHandle* inputs, NDArrayHandle* outputs, - char* kernel, RtcHandle *out) { +int MXRtcCreate(char* name, + uint32_t num_input, + uint32_t num_output, + char** input_names, + char** output_names, + NDArrayHandle* inputs, + NDArrayHandle* outputs, + char* kernel, + RtcHandle* out) { API_BEGIN(); LOG(FATAL) << "Old rtc API is deprecated. Please use CudaModule"; API_END(); } -int MXRtcPush(RtcHandle handle, uint32_t num_input, uint32_t num_output, - NDArrayHandle* inputs, NDArrayHandle* outputs, +int MXRtcPush(RtcHandle handle, + uint32_t num_input, + uint32_t num_output, + NDArrayHandle* inputs, + NDArrayHandle* outputs, uint32_t gridDimX, uint32_t gridDimY, uint32_t gridDimZ, @@ -3403,16 +3622,20 @@ int MXCustomOpRegister(const char* op_type, CustomOpPropCreator creator) { API_END(); } - -int MXRtcCudaModuleCreate(const char* source, int num_options, - const char** options, int num_exports, - const char** exports, CudaModuleHandle *out) { +int MXRtcCudaModuleCreate(const char* source, + int num_options, + const char** options, + int num_exports, + const char** exports, + CudaModuleHandle* out) { API_BEGIN(); #if MXNET_USE_CUDA std::vector str_opts; - for (int i = 0; i < num_options; ++i) str_opts.emplace_back(options[i]); + for (int i = 0; i < num_options; ++i) + str_opts.emplace_back(options[i]); std::vector str_exports; - for (int i = 0; i < num_exports; ++i) str_exports.emplace_back(exports[i]); + for (int i = 0; i < num_exports; ++i) + str_exports.emplace_back(exports[i]); *out = new rtc::CudaModule(source, str_opts, str_exports); #else LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; @@ -3430,20 +3653,24 @@ int MXRtcCudaModuleFree(CudaModuleHandle handle) { API_END(); } -int MXRtcCudaKernelCreate(CudaModuleHandle handle, const char* name, int num_args, - int* is_ndarray, int* is_const, int* arg_types, - CudaKernelHandle *out) { +int MXRtcCudaKernelCreate(CudaModuleHandle handle, + const char* name, + int num_args, + int* is_ndarray, + int* is_const, + int* arg_types, + CudaKernelHandle* out) { API_BEGIN(); #if MXNET_USE_CUDA auto module = reinterpret_cast(handle); std::vector signature; for (int i = 0; i < num_args; ++i) { - signature.push_back(rtc::CudaModule::ArgType{ - static_cast(is_ndarray[i]), static_cast(is_const[i]), - static_cast(arg_types[i])}); + 
signature.push_back(rtc::CudaModule::ArgType{static_cast(is_ndarray[i]), + static_cast(is_const[i]), + static_cast(arg_types[i])}); } auto kernel = module->GetKernel(name, signature); - *out = new std::shared_ptr(kernel); + *out = new std::shared_ptr(kernel); #else LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif @@ -3460,27 +3687,38 @@ int MXRtcCudaKernelFree(CudaKernelHandle handle) { API_END(); } -int MXRtcCudaKernelCall(CudaKernelHandle handle, int dev_id, void** args, - uint32_t grid_dim_x, uint32_t grid_dim_y, - uint32_t grid_dim_z, uint32_t block_dim_x, - uint32_t block_dim_y, uint32_t block_dim_z, +int MXRtcCudaKernelCall(CudaKernelHandle handle, + int dev_id, + void** args, + uint32_t grid_dim_x, + uint32_t grid_dim_y, + uint32_t grid_dim_z, + uint32_t block_dim_x, + uint32_t block_dim_y, + uint32_t block_dim_z, uint32_t shared_mem) { API_BEGIN(); #if MXNET_USE_CUDA - auto kernel = reinterpret_cast*>(handle); + auto kernel = reinterpret_cast*>(handle); const auto& signature = (*kernel)->signature(); std::vector any_args; for (size_t i = 0; i < signature.size(); ++i) { if (signature[i].is_ndarray) { any_args.emplace_back(*static_cast(args[i])); } else { - MSHADOW_TYPE_SWITCH(signature[i].dtype, DType, { - any_args.emplace_back(*static_cast(args[i])); - }); + MSHADOW_TYPE_SWITCH( + signature[i].dtype, DType, { any_args.emplace_back(*static_cast(args[i])); }); } } - (*kernel)->Launch(Context::GPU(dev_id), any_args, grid_dim_x, grid_dim_y, - grid_dim_z, block_dim_x, block_dim_y, block_dim_z, shared_mem); + (*kernel)->Launch(Context::GPU(dev_id), + any_args, + grid_dim_x, + grid_dim_y, + grid_dim_z, + block_dim_x, + block_dim_y, + block_dim_z, + shared_mem); #else LOG(FATAL) << "Compile with USE_CUDA=1 to have CUDA runtime compilation."; #endif @@ -3503,12 +3741,16 @@ int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shar Storage::Get()->SharedIncrementRefCount(shandle); } *shared_pid = shandle.shared_pid; - *shared_id = shandle.shared_id; + *shared_id = shandle.shared_id; API_END(); } -int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape, - int ndim, int dtype, NDArrayHandle *out) { +int MXNDArrayCreateFromSharedMem(int shared_pid, + int shared_id, + const int* shape, + int ndim, + int dtype, + NDArrayHandle* out) { API_BEGIN(); NDArray* nd = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype); nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(), @@ -3517,7 +3759,7 @@ int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const int *shape API_END(); } -using VarHandle = Engine::VarHandle; +using VarHandle = Engine::VarHandle; using CallbackOnComplete = Engine::CallbackOnComplete; void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) { @@ -3525,34 +3767,38 @@ void AssertValidNumberVars(int num_const_vars, int num_mutable_vars) { CHECK_GE(num_mutable_vars, 0) << "Non-negative number of mutable vars expected."; } -int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - EngineVarHandle const_vars_handle, int num_const_vars, - EngineVarHandle mutable_vars_handle, int num_mutable_vars, - EngineFnPropertyHandle prop_handle, int priority, - const char* opr_name, bool wait) { +int MXEnginePushAsync(EngineAsyncFunc async_func, + void* func_param, + EngineFuncParamDeleter deleter, + ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, + int 
num_const_vars, + EngineVarHandle mutable_vars_handle, + int num_mutable_vars, + EngineFnPropertyHandle prop_handle, + int priority, + const char* opr_name, + bool wait) { API_BEGIN(); - auto exec_ctx = *static_cast(ctx_handle); - auto const_vars = static_cast(const_vars_handle); + auto exec_ctx = *static_cast(ctx_handle); + auto const_vars = static_cast(const_vars_handle); auto mutable_vars = static_cast(mutable_vars_handle); - auto prop = FnProperty::kNormal; + auto prop = FnProperty::kNormal; if (prop_handle) { prop = *static_cast(prop_handle); } Engine::AsyncFn exec_fn; if (deleter == nullptr) { - exec_fn = [async_func, func_param](RunContext rctx, - CallbackOnComplete on_complete) { + exec_fn = [async_func, func_param](RunContext rctx, CallbackOnComplete on_complete) { async_func(&rctx, &on_complete, func_param); }; } else { // Wrap func_param in a shared_ptr with deleter such that deleter // will be called when the lambda goes out of scope. std::shared_ptr shared_func_param(func_param, deleter); - exec_fn = [async_func, shared_func_param](RunContext rctx, - CallbackOnComplete on_complete) { + exec_fn = [async_func, shared_func_param](RunContext rctx, CallbackOnComplete on_complete) { async_func(&rctx, &on_complete, shared_func_param.get()); }; } @@ -3560,33 +3806,36 @@ int MXEnginePushAsync(EngineAsyncFunc async_func, void* func_param, AssertValidNumberVars(num_const_vars, num_mutable_vars); std::vector const_var_vec(const_vars, const_vars + num_const_vars); std::vector mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars); - Engine::Get()->PushAsync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec, - prop, priority, opr_name, wait); + Engine::Get()->PushAsync( + exec_fn, exec_ctx, const_var_vec, mutable_var_vec, prop, priority, opr_name, wait); API_END(); } -int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - EngineVarHandle const_vars_handle, int num_const_vars, - EngineVarHandle mutable_vars_handle, int num_mutable_vars, - EngineFnPropertyHandle prop_handle, int priority, +int MXEnginePushSync(EngineSyncFunc sync_func, + void* func_param, + EngineFuncParamDeleter deleter, + ContextHandle ctx_handle, + EngineVarHandle const_vars_handle, + int num_const_vars, + EngineVarHandle mutable_vars_handle, + int num_mutable_vars, + EngineFnPropertyHandle prop_handle, + int priority, const char* opr_name) { API_BEGIN(); - auto exec_ctx = *static_cast(ctx_handle); - auto const_vars = static_cast(const_vars_handle); + auto exec_ctx = *static_cast(ctx_handle); + auto const_vars = static_cast(const_vars_handle); auto mutable_vars = static_cast(mutable_vars_handle); - auto prop = FnProperty::kNormal; + auto prop = FnProperty::kNormal; if (prop_handle) { prop = *static_cast(prop_handle); } Engine::SyncFn exec_fn; if (deleter == nullptr) { - exec_fn = [sync_func, func_param](RunContext rctx) { - sync_func(&rctx, func_param); - }; + exec_fn = [sync_func, func_param](RunContext rctx) { sync_func(&rctx, func_param); }; } else { // Wrap func_param in a shared_ptr with deleter such that deleter // will be called when the lambda goes out of scope. 
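The comment above describes the ownership trick used by MXEnginePushAsync and MXEnginePushSync: when the caller supplies a deleter for func_param, the pointer is wrapped in a std::shared_ptr with that deleter and captured by value, so the deleter fires exactly once, when the last copy of the engine lambda is destroyed. The following is a minimal standalone sketch of that pattern, not MXNet code; Callback, Deleter, Wrap, and the values in main are illustrative stand-ins for EngineAsyncFunc, EngineFuncParamDeleter, and the lambda construction in the functions above.

    #include <cstdio>
    #include <functional>
    #include <memory>

    using Callback = void (*)(void* param);  // stand-in for the C callback type
    using Deleter  = void (*)(void* param);  // stand-in for the C deleter type

    std::function<void()> Wrap(Callback cb, void* param, Deleter deleter) {
      if (deleter == nullptr) {
        // No ownership transfer: capture the raw pointer, the caller keeps it alive.
        return [cb, param]() { cb(param); };
      }
      // Ownership transfer: the shared_ptr's custom deleter runs once the last
      // copy of the returned lambda (and so the last shared_ptr copy) is destroyed.
      std::shared_ptr<void> owned(param, deleter);
      return [cb, owned]() { cb(owned.get()); };
    }

    int main() {
      auto fn = Wrap(
          [](void* p) { std::printf("value = %d\n", *static_cast<int*>(p)); },
          new int(42),
          [](void* p) { std::printf("deleter ran\n"); delete static_cast<int*>(p); });
      fn();          // prints "value = 42"
      fn = nullptr;  // last copy released, so "deleter ran" is printed here
      return 0;
    }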
@@ -3599,49 +3848,79 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, AssertValidNumberVars(num_const_vars, num_mutable_vars); std::vector const_var_vec(const_vars, const_vars + num_const_vars); std::vector mutable_var_vec(mutable_vars, mutable_vars + num_mutable_vars); - Engine::Get()->PushSync(exec_fn, exec_ctx, const_var_vec, mutable_var_vec, - prop, priority, opr_name); + Engine::Get()->PushSync( + exec_fn, exec_ctx, const_var_vec, mutable_var_vec, prop, priority, opr_name); API_END(); } -int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle* const_nds_handle, int num_const_nds, - NDArrayHandle* mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle, int priority, - const char* opr_name, bool wait) { +int MXEnginePushAsyncND(EngineAsyncFunc async_func, + void* func_param, + EngineFuncParamDeleter deleter, + ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, + int num_const_nds, + NDArrayHandle* mutable_nds_handle, + int num_mutable_nds, + EngineFnPropertyHandle prop_handle, + int priority, + const char* opr_name, + bool wait) { API_BEGIN(); - NDArray** const_nds = reinterpret_cast(const_nds_handle); + NDArray** const_nds = reinterpret_cast(const_nds_handle); NDArray** mutable_nds = reinterpret_cast(mutable_nds_handle); std::vector const_var_vec(num_const_nds); - for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var(); + for (int i = 0; i < num_const_nds; ++i) + const_var_vec[i] = const_nds[i]->var(); std::vector mutable_var_vec(num_mutable_nds); - for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var(); - return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle, - const_var_vec.data(), num_const_nds, - mutable_var_vec.data(), num_mutable_nds, - prop_handle, priority, opr_name, wait); - API_END(); -} - -int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param, - EngineFuncParamDeleter deleter, ContextHandle ctx_handle, - NDArrayHandle* const_nds_handle, int num_const_nds, - NDArrayHandle* mutable_nds_handle, int num_mutable_nds, - EngineFnPropertyHandle prop_handle, int priority, + for (int i = 0; i < num_mutable_nds; ++i) + mutable_var_vec[i] = mutable_nds[i]->var(); + return MXEnginePushAsync(async_func, + func_param, + deleter, + ctx_handle, + const_var_vec.data(), + num_const_nds, + mutable_var_vec.data(), + num_mutable_nds, + prop_handle, + priority, + opr_name, + wait); + API_END(); +} + +int MXEnginePushSyncND(EngineSyncFunc sync_func, + void* func_param, + EngineFuncParamDeleter deleter, + ContextHandle ctx_handle, + NDArrayHandle* const_nds_handle, + int num_const_nds, + NDArrayHandle* mutable_nds_handle, + int num_mutable_nds, + EngineFnPropertyHandle prop_handle, + int priority, const char* opr_name) { API_BEGIN(); - NDArray** const_nds = reinterpret_cast(const_nds_handle); + NDArray** const_nds = reinterpret_cast(const_nds_handle); NDArray** mutable_nds = reinterpret_cast(mutable_nds_handle); std::vector const_var_vec(num_const_nds); - for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var(); + for (int i = 0; i < num_const_nds; ++i) + const_var_vec[i] = const_nds[i]->var(); std::vector mutable_var_vec(num_mutable_nds); - for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var(); - return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle, - const_var_vec.data(), num_const_nds, - mutable_var_vec.data(), 
num_mutable_nds, - prop_handle, priority, opr_name); + for (int i = 0; i < num_mutable_nds; ++i) + mutable_var_vec[i] = mutable_nds[i]->var(); + return MXEnginePushSync(sync_func, + func_param, + deleter, + ctx_handle, + const_var_vec.data(), + num_const_nds, + mutable_var_vec.data(), + num_mutable_nds, + prop_handle, + priority, + opr_name); API_END(); } @@ -3656,7 +3935,7 @@ int MXShallowCopyNDArray(NDArrayHandle src_handle, NDArrayHandle* out) { NDArray* ret = nullptr; API_BEGIN(); NDArray* src_array = static_cast(src_handle); - ret = new NDArray(*src_array); - *out = ret; + ret = new NDArray(*src_array); + *out = ret; API_END_HANDLE_ERROR(delete ret); } diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 7f673ee871b2..781b51f195c3 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -57,16 +57,16 @@ using namespace mxnet; /*! \brief entry to to easily hold returning information */ -template +template struct MXAPIThreadLocalEntry { /*! \brief result holder for returning string */ std::string ret_str; /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; /*! \brief result holder for returning handles */ - std::vector ret_handles; + std::vector ret_handles; /*! \brief holder for NDArray handles */ std::vector ndinputs, ndoutputs; /*! \brief result holder for returning shapes */ @@ -91,29 +91,28 @@ struct MXAPIThreadLocalEntry { std::vector save_inputs, save_outputs; // DEPRECATED. Use SetupShapeArrayReturnWithBufferEx instead. // helper function to setup return value of shape array - inline static void SetupShapeArrayReturnWithBuffer( - const mxnet::ShapeVector &shapes, - std::vector *ndim, - std::vector *data, - std::vector *buffer) { + inline static void SetupShapeArrayReturnWithBuffer(const mxnet::ShapeVector& shapes, + std::vector* ndim, + std::vector* data, + std::vector* buffer) { ndim->resize(shapes.size()); data->resize(shapes.size()); size_t size = 0; - for (const auto& s : shapes) size += s.ndim(); + for (const auto& s : shapes) + size += s.ndim(); buffer->resize(size); - uint32_t *ptr = buffer->data(); + uint32_t* ptr = buffer->data(); for (size_t i = 0; i < shapes.size(); ++i) { ndim->at(i) = shapes[i].ndim(); data->at(i) = ptr; - ptr = nnvm::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); + ptr = nnvm::ShapeTypeCast(shapes[i].begin(), shapes[i].end(), ptr); } } // helper function to setup return value of shape array - inline static void SetupShapeArrayReturnWithBufferEx( - const mxnet::ShapeVector &shapes, - std::vector *ndim, - std::vector *data, - std::vector *buffer) { + inline static void SetupShapeArrayReturnWithBufferEx(const mxnet::ShapeVector& shapes, + std::vector* ndim, + std::vector* data, + std::vector* buffer) { ndim->resize(shapes.size()); data->resize(shapes.size()); size_t size = 0; @@ -135,12 +134,12 @@ struct MXAPIThreadLocalEntry { }; // define the threadlocal store. -template +template using MXAPIThreadLocalStore = dmlc::ThreadLocalStore>; namespace mxnet { // copy attributes from inferred vector back to the vector of each type. 
-template +template inline void CopyAttr(const nnvm::IndexedGraph& idx, const std::vector& attr_vec, std::vector* in_attr, diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index f1dd8d98eaa1..2e3e84e8491a 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -21,7 +21,7 @@ * \file custom.cc * \brief * \author Junyuan Xie -*/ + */ #include #include #include @@ -41,15 +41,14 @@ struct CustomFunctionParam { std::vector out_dtypes; }; -std::vector Gradient( - const nnvm::ObjectPtr& n, - const std::vector& out_grads) { +std::vector Gradient(const nnvm::ObjectPtr& n, + const std::vector& out_grads) { const CustomFunctionParam& params = nnvm::get(n->attrs.parsed); nnvm::ObjectPtr g = nnvm::Node::Create(); - g->attrs.op = nnvm::Op::Get("_backward_CustomFunction"); - g->attrs.name = n->attrs.name + "_backward"; - g->attrs.parsed = params; + g->attrs.op = nnvm::Op::Get("_backward_CustomFunction"); + g->attrs.name = n->attrs.name + "_backward"; + g->attrs.parsed = params; g->control_deps.emplace_back(n); g->inputs = out_grads; @@ -107,19 +106,30 @@ void Backward(const OpStatePtr& state, } op::custom::CustomOperator::Get()->Push( - [=]() { - CHECK(reinterpret_cast( - params.info->callbacks[kCustomFunctionBackward])( - inputs.size(), outputs.size(), - const_cast(ptrs.data()), - reinterpret_cast(req.data()), ctx.is_train, - params.info->contexts[kCustomFunctionBackward])); - }, ctx, false, ctx.is_train, cpys, tags, output_tags, outputs); + [=]() { + CHECK(reinterpret_cast( + params.info->callbacks[kCustomFunctionBackward])( + inputs.size(), + outputs.size(), + const_cast(ptrs.data()), + reinterpret_cast(req.data()), + ctx.is_train, + params.info->contexts[kCustomFunctionBackward])); + }, + ctx, + false, + ctx.is_train, + cpys, + tags, + output_tags, + outputs); } -inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, +inline bool InferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, DispatchMode* dispatch_mode, - std::vector* iattr, std::vector* oattr) { + std::vector* iattr, + std::vector* oattr) { using namespace op; for (size_t i = 0; i < iattr->size(); ++i) { @@ -133,71 +143,70 @@ inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, } NNVM_REGISTER_OP(_CustomFunction) -.set_num_inputs([](const NodeAttrs& attrs) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - return params.num_args; - }) -.set_num_outputs([](const NodeAttrs& attrs) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - return params.num_outs; - }) -.set_attr("FInferShape", - [](const NodeAttrs& attrs, mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - *out_shape = params.out_shapes; - return true; - }) -.set_attr("FInferType", - [](const NodeAttrs& attrs, std::vector *in_type, - std::vector *out_type) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - *out_type = params.out_dtypes; - return true; - }) -.set_attr("FCreateOpState", CreateState) -.set_attr("FGradient", Gradient) -.set_attr("FStatefulComputeEx", Forward) -.set_attr("FStatefulComputeEx", Forward) -.set_attr("FInferStorageType", InferStorageType); - + .set_num_inputs([](const NodeAttrs& attrs) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + return params.num_outs; + }) + .set_attr( 
+ "FInferShape", + [](const NodeAttrs& attrs, mxnet::ShapeVector* in_shape, mxnet::ShapeVector* out_shape) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + *out_shape = params.out_shapes; + return true; + }) + .set_attr( + "FInferType", + [](const NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + *out_type = params.out_dtypes; + return true; + }) + .set_attr("FCreateOpState", CreateState) + .set_attr("FGradient", Gradient) + .set_attr("FStatefulComputeEx", Forward) + .set_attr("FStatefulComputeEx", Forward) + .set_attr("FInferStorageType", InferStorageType); NNVM_REGISTER_OP(_backward_CustomFunction) -.set_num_inputs([](const NodeAttrs& attrs) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - return params.num_outs; - }) -.set_num_outputs([](const NodeAttrs& attrs) { - const CustomFunctionParam& params = nnvm::get(attrs.parsed); - return params.num_args; - }) -.set_attr("TIsBackward", true) -.set_attr("TIsLayerOpBackward", true) -.set_attr("FExecType", [](const NodeAttrs& attrs) { - return ExecType::kAsync; - }) -.set_attr("FStatefulComputeEx", Backward) -.set_attr("FStatefulComputeEx", Backward) -.set_attr("FInferStorageType", InferStorageType); + .set_num_inputs([](const NodeAttrs& attrs) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + return params.num_outs; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const CustomFunctionParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_attr("TIsBackward", true) + .set_attr("TIsLayerOpBackward", true) + .set_attr("FExecType", [](const NodeAttrs& attrs) { return ExecType::kAsync; }) + .set_attr("FStatefulComputeEx", Backward) + .set_attr("FStatefulComputeEx", Backward) + .set_attr("FInferStorageType", InferStorageType); } // namespace custom_function } // namespace mxnet -int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, - int num_outputs, NDArrayHandle *outputs, - MXCallbackList *callbacks) { +int MXCustomFunctionRecord(int num_inputs, + NDArrayHandle* inputs, + int num_outputs, + NDArrayHandle* outputs, + MXCallbackList* callbacks) { using namespace mxnet; using namespace mxnet::custom_function; API_BEGIN(); CHECK(Imperative::Get()->is_recording()); - auto state = OpStatePtr::Create(); + auto state = OpStatePtr::Create(); CustomFunctionParam& params = state.get_state(); - params.num_args = num_inputs; - params.num_outs = num_outputs; - params.info.reset(callbacks, [](MXCallbackList* ptr){ - reinterpret_cast(ptr->callbacks[kCustomFunctionDelete])( + params.num_args = num_inputs; + params.num_outs = num_outputs; + params.info.reset(callbacks, [](MXCallbackList* ptr) { + reinterpret_cast(ptr->callbacks[kCustomFunctionDelete])( ptr->contexts[kCustomFunctionDelete]); - }); + }); std::vector ndinputs, ndoutputs; ndinputs.reserve(num_inputs); ndoutputs.reserve(num_outputs); @@ -213,10 +222,9 @@ int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs, params.out_dtypes.emplace_back(arr->dtype()); } nnvm::NodeAttrs attrs; - attrs.op = nnvm::Op::Get("_CustomFunction"); + attrs.op = nnvm::Op::Get("_CustomFunction"); attrs.parsed = params; - Imperative::Get()->RecordOp( - std::move(attrs), ndinputs, ndoutputs, state); + Imperative::Get()->RecordOp(std::move(attrs), ndinputs, ndoutputs, state); API_END(); } diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index a03868ad594a..d967ae6e12b3 100644 --- a/src/c_api/c_api_ndarray.cc +++ 
b/src/c_api/c_api_ndarray.cc @@ -46,11 +46,11 @@ void SetNDInputsOutputs(const nnvm::Op* op, std::vector* ndinputs, std::vector* ndoutputs, int num_inputs, - const NDArrayHandle *inputs, - int *num_outputs, + const NDArrayHandle* inputs, + int* num_outputs, int infered_num_outputs, int num_visible_outputs, - NDArrayHandle **outputs) { + NDArrayHandle** outputs) { NDArray** out_array = *reinterpret_cast(outputs); ndinputs->clear(); @@ -59,9 +59,9 @@ void SetNDInputsOutputs(const nnvm::Op* op, NDArray* inp = reinterpret_cast(inputs[i]); if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { if (shape_is_known(inp->shape())) { // Shape may be unknown after dynamic shape operators - CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) << - "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than " - "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; + CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) + << "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than " + "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; } } ndinputs->emplace_back(inp); @@ -76,9 +76,8 @@ void SetNDInputsOutputs(const nnvm::Op* op, *num_outputs = num_visible_outputs; } else { CHECK(*num_outputs == infered_num_outputs || *num_outputs == num_visible_outputs) - << "Operator expects " << infered_num_outputs << " (all) or " - << num_visible_outputs << " (visible only) outputs, but got " - << *num_outputs << " instead."; + << "Operator expects " << infered_num_outputs << " (all) or " << num_visible_outputs + << " (visible only) outputs, but got " << *num_outputs << " instead."; for (int i = 0; i < *num_outputs; ++i) { ndoutputs->emplace_back(out_array[i]); } @@ -90,17 +89,17 @@ void SetNDInputsOutputs(const nnvm::Op* op, void MXImperativeInvokeImpl(AtomicSymbolCreator creator, int num_inputs, - NDArrayHandle *inputs, - int *num_outputs, - NDArrayHandle **outputs, + NDArrayHandle* inputs, + int* num_outputs, + NDArrayHandle** outputs, int num_params, - const char **param_keys, - const char **param_vals) { - const nnvm::Op* op = static_cast(creator); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + const char** param_keys, + const char** param_vals) { + const nnvm::Op* op = static_cast(creator); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); - nnvm::NodeAttrs attrs = imperative::ParseAttrs(op, num_inputs, num_params, - param_keys, param_vals); + nnvm::NodeAttrs attrs = + imperative::ParseAttrs(op, num_inputs, num_params, param_keys, param_vals); attrs.dict["__profiler_scope__"] = profiler::ProfilerScope::Get()->GetCurrentProfilerScope(); if (attrs.op) { attrs.name = attrs.op->name; @@ -111,8 +110,15 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator, imperative::SetNumOutputs(op, attrs, num_inputs, &infered_num_outputs, &num_visible_outputs); std::vector ndinputs, ndoutputs; - SetNDInputsOutputs(op, &ndinputs, &ndoutputs, num_inputs, inputs, - num_outputs, infered_num_outputs, num_visible_outputs, outputs); + SetNDInputsOutputs(op, + &ndinputs, + &ndoutputs, + num_inputs, + inputs, + num_outputs, + infered_num_outputs, + num_visible_outputs, + outputs); if (Imperative::Get()->is_deferred_compute()) { Imperative::Get()->RecordDeferredCompute(std::move(attrs), ndinputs, ndoutputs); @@ -126,29 +132,31 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator, } } - for (int i = *num_outputs; i < infered_num_outputs; ++i) delete ndoutputs[i]; + for (int i = *num_outputs; i < infered_num_outputs; ++i) + 
delete ndoutputs[i]; if (*outputs == nullptr) { ret->ret_handles.clear(); ret->ret_handles.reserve(*num_outputs); - for (int i = 0; i < *num_outputs; ++i) ret->ret_handles.push_back(ndoutputs[i]); + for (int i = 0; i < *num_outputs; ++i) + ret->ret_handles.push_back(ndoutputs[i]); *outputs = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); } } int MXImperativeInvoke(AtomicSymbolCreator creator, int num_inputs, - NDArrayHandle *inputs, - int *num_outputs, - NDArrayHandle **outputs, + NDArrayHandle* inputs, + int* num_outputs, + NDArrayHandle** outputs, int num_params, - const char **param_keys, - const char **param_vals, - const int **out_stypes) { // outputs storage types - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + const char** param_keys, + const char** param_vals, + const int** out_stypes) { // outputs storage types + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); - MXImperativeInvokeImpl(creator, num_inputs, inputs, num_outputs, outputs, - num_params, param_keys, param_vals); + MXImperativeInvokeImpl( + creator, num_inputs, inputs, num_outputs, outputs, num_params, param_keys, param_vals); if (out_stypes != nullptr) { NDArray** out_array = *reinterpret_cast(outputs); ret->out_types.clear(); @@ -165,7 +173,7 @@ int MXCreateCachedOp(SymbolHandle handle, int num_flags, const char** keys, const char** vals, - CachedOpHandle *out, + CachedOpHandle* out, bool thread_safe) { nnvm::Symbol* sym = static_cast(handle); API_BEGIN(); @@ -192,25 +200,24 @@ int MXFreeCachedOp(CachedOpHandle handle) { /*! * \brief get optimized graph from the cached op */ -int MXCachedOpGetOptimizedSymbol(CachedOpHandle handle, - SymbolHandle *out) { +int MXCachedOpGetOptimizedSymbol(CachedOpHandle handle, SymbolHandle* out) { auto s = new nnvm::Symbol(); API_BEGIN(); CachedOpPtr op = *static_cast(handle); - *s = op->GetOptimizedSymbol(); - *out = s; + *s = op->GetOptimizedSymbol(); + *out = s; API_END_HANDLE_ERROR(delete s); } int MXInvokeCachedOp(CachedOpHandle handle, int num_inputs, - NDArrayHandle *inputs, + NDArrayHandle* inputs, int default_dev_type, int default_dev_id, - int *num_outputs, - NDArrayHandle **outputs, - const int **out_stypes) { // outputs storage types - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + int* num_outputs, + NDArrayHandle** outputs, + const int** out_stypes) { // outputs storage types + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); CachedOpPtr op_shared = *static_cast(handle); @@ -227,18 +234,17 @@ int MXInvokeCachedOp(CachedOpHandle handle, ndoutputs.reserve(op->num_outputs()); if (*outputs == nullptr) { *num_outputs = op->num_outputs(); - for (int i = 0; i < *num_outputs; ++i) ndoutputs.push_back(new NDArray()); + for (int i = 0; i < *num_outputs; ++i) + ndoutputs.push_back(new NDArray()); } else { - CHECK_EQ(*num_outputs, op->num_outputs()) - << "CachedOp expects " << op->num_outputs() << " outputs, but " - << *num_outputs << " was given."; + CHECK_EQ(*num_outputs, op->num_outputs()) << "CachedOp expects " << op->num_outputs() + << " outputs, but " << *num_outputs << " was given."; for (int i = 0; i < *num_outputs; ++i) { ndoutputs.push_back(reinterpret_cast((*outputs)[i])); } } // construct default context - Context ctx = Context::Create(static_cast(default_dev_type), - default_dev_id); + Context ctx = Context::Create(static_cast(default_dev_type), default_dev_id); op->Forward(op_shared, ndinputs, ndoutputs, ctx); if (*outputs == nullptr) { @@ -311,9 +317,9 @@ int 
MXSetIsNumpyDefaultDtype(bool default_dtype, bool* prev) { } int MXAutogradMarkVariables(uint32_t num_var, - NDArrayHandle *var_handles, - uint32_t *reqs_array, - NDArrayHandle *grad_handles) { + NDArrayHandle* var_handles, + uint32_t* reqs_array, + NDArrayHandle* grad_handles) { API_BEGIN(); std::vector variables, gradients; std::vector grad_reqs; @@ -329,31 +335,37 @@ int MXAutogradMarkVariables(uint32_t num_var, API_END(); } -int MXAutogradComputeGradient(uint32_t num_output, - NDArrayHandle *output_handles) { +int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) { return MXAutogradBackward(num_output, output_handles, nullptr, 0); } int MXAutogradBackward(uint32_t num_output, - NDArrayHandle *output_handles, - NDArrayHandle *ograd_handles, + NDArrayHandle* output_handles, + NDArrayHandle* ograd_handles, int retain_graph) { - return MXAutogradBackwardEx(num_output, output_handles, ograd_handles, - 0, nullptr, retain_graph, false, true, - nullptr, nullptr); + return MXAutogradBackwardEx(num_output, + output_handles, + ograd_handles, + 0, + nullptr, + retain_graph, + false, + true, + nullptr, + nullptr); } int MXAutogradBackwardEx(uint32_t num_output, - NDArrayHandle *output_handles, - NDArrayHandle *ograd_handles, + NDArrayHandle* output_handles, + NDArrayHandle* ograd_handles, uint32_t num_variables, - NDArrayHandle *var_handles, + NDArrayHandle* var_handles, int retain_graph, int create_graph, int is_train, - NDArrayHandle **grad_handles, - int **grad_stypes) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + NDArrayHandle** grad_handles, + int** grad_stypes) { + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::vector outputs, ograds, variables; @@ -376,8 +388,8 @@ int MXAutogradBackwardEx(uint32_t num_output, variables.emplace_back(reinterpret_cast(var_handles[i])); } - auto grads = Imperative::Get()->Backward(outputs, ograds, variables, is_train, - retain_graph, create_graph); + auto grads = + Imperative::Get()->Backward(outputs, ograds, variables, is_train, retain_graph, create_graph); if (num_variables != 0) { ret->ret_handles.clear(); ret->out_types.clear(); @@ -388,16 +400,16 @@ int MXAutogradBackwardEx(uint32_t num_output, ret->out_types.push_back(i->storage_type()); } *grad_handles = dmlc::BeginPtr(ret->ret_handles); - *grad_stypes = dmlc::BeginPtr(ret->out_types); + *grad_stypes = dmlc::BeginPtr(ret->out_types); } API_END(); } -int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) { +int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle* out) { API_BEGIN(); - NDArray *head = reinterpret_cast(handle); - auto sym = new nnvm::Symbol(head->get_autograd_symbol()); - *out = reinterpret_cast(sym); + NDArray* head = reinterpret_cast(handle); + auto sym = new nnvm::Symbol(head->get_autograd_symbol()); + *out = reinterpret_cast(sym); API_END(); } @@ -406,57 +418,57 @@ int MXCachedOpRegisterOpHook(CachedOpHandle handle, bool monitor_all) { API_BEGIN(); CachedOpMonitorCallback callback_temp = nullptr; - std::function clbk; + std::function clbk; if (callback) { callback_temp = callback; - clbk = [callback_temp](const char *name, const char *opr_name, - void *handle) { + clbk = [callback_temp](const char* name, const char* opr_name, void* handle) { callback_temp(name, opr_name, handle); }; } else { - clbk = nullptr; + clbk = nullptr; } - CachedOpPtr op = *static_cast(handle); + CachedOpPtr op = *static_cast(handle); op->RegisterOpHook(clbk, monitor_all); API_END(); } -int 
MXNDArrayIsDeferredCompute(int *curr) { +int MXNDArrayIsDeferredCompute(int* curr) { API_BEGIN(); *curr = Imperative::Get()->is_deferred_compute(); API_END(); } -int MXNDArraySetIsDeferredCompute(int deferred_compute, int *prev) { +int MXNDArraySetIsDeferredCompute(int deferred_compute, int* prev) { API_BEGIN(); *prev = Imperative::Get()->set_is_deferred_compute(static_cast(deferred_compute)); API_END(); } -int MXNDArraySetDeferredComputeVariable(NDArrayHandle *arrays, SymbolHandle *variables, int num) { +int MXNDArraySetDeferredComputeVariable(NDArrayHandle* arrays, SymbolHandle* variables, int num) { API_BEGIN(); Imperative::Get()->SetDeferredComputeVariable(arrays, variables, num); API_END(); } -int MXNDArrayClearDeferredCompute(NDArrayHandle *arrays, int num) { +int MXNDArrayClearDeferredCompute(NDArrayHandle* arrays, int num) { API_BEGIN(); Imperative::Get()->DeferredComputeClear(arrays, num); API_END(); } -int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle *output_handles, int num_outputs, - SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles, + int num_outputs, + SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - std::vector outputs; + std::vector outputs; outputs.reserve(num_outputs); for (int i = 0; i < num_outputs; ++i) { - NDArray *array = reinterpret_cast(output_handles[i]); + NDArray* array = reinterpret_cast(output_handles[i]); outputs.emplace_back(array); } // Obtain Symbol - *s = Imperative::Get()->GetDeferredComputeSymbol(outputs); + *s = Imperative::Get()->GetDeferredComputeSymbol(outputs); *out = s; API_END_HANDLE_ERROR(delete s;); } diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc index d7768e199a64..c5f719c0b618 100644 --- a/src/c_api/c_api_profile.cc +++ b/src/c_api/c_api_profile.cc @@ -39,13 +39,12 @@ namespace mxnet { static profiler::ProfileDomain api_domain("MXNET_C_API"); static profiler::ProfileCounter api_call_counter("MXNet C API Calls", &api_domain); -static profiler::ProfileCounter api_concurrency_counter("MXNet C API Concurrency", - &api_domain); +static profiler::ProfileCounter api_concurrency_counter("MXNet C API Concurrency", &api_domain); /*! \brief Per-API-call timing data */ struct APICallTimingData { - const char *name_; - profiler::ProfileTask *task_; + const char* name_; + profiler::ProfileTask* task_; }; /*! 
@@ -64,12 +63,14 @@ class ProfilingThreadData { * \param domain Domain of the task * \return Pointer to the stored or created ProfileTask object */ - profiler::ProfileTask *profile_task(const char *name, profiler::ProfileDomain *domain) { + profiler::ProfileTask* profile_task(const char* name, profiler::ProfileDomain* domain) { // Per-thread so no lock necessary auto iter = tasks_.find(name); if (iter == tasks_.end()) { - iter = tasks_.emplace(std::make_pair( - name, std::make_unique(name, domain))).first; + iter = + tasks_ + .emplace(std::make_pair(name, std::make_unique(name, domain))) + .first; } return iter->second.get(); } @@ -90,15 +91,13 @@ static thread_local ProfilingThreadData thread_profiling_data; static MX_THREAD_LOCAL ProfilingThreadData thread_profiling_data; #endif -extern void on_enter_api(const char *function) { +extern void on_enter_api(const char* function) { if (profiler::Profiler::Get()->IsProfiling(profiler::Profiler::kAPI)) { if (!thread_profiling_data.ignore_call_) { ++api_call_counter; ++api_concurrency_counter; - APICallTimingData data = { - function - , thread_profiling_data.profile_task(function, &api_domain) - }; + APICallTimingData data = {function, + thread_profiling_data.profile_task(function, &api_domain)}; thread_profiling_data.calls_.push(data); data.task_->start(); } @@ -120,7 +119,7 @@ extern void on_exit_api() { * \brief Don't profile calls in this scope using RAII */ struct IgnoreProfileCallScope { - IgnoreProfileCallScope() { + IgnoreProfileCallScope() { DCHECK_EQ(thread_profiling_data.ignore_call_, false); thread_profiling_data.ignore_call_ = true; } @@ -149,24 +148,19 @@ struct PythonProfileObjects { std::mutex cs_frames_; std::mutex cs_events_; std::list> domains_; - std::unordered_map> - counters_; - std::unordered_map> - tasks_; - std::unordered_map> - frames_; - std::unordered_map> - events_; + std::unordered_map> + counters_; + std::unordered_map> tasks_; + std::unordered_map> + frames_; + std::unordered_map> + events_; }; static PythonProfileObjects python_profile_objects; -enum class ProfileProcess { - kWorker, kServer -}; +enum class ProfileProcess { kWorker, kServer }; -enum class PrintFormat { - table, json -}; +enum class PrintFormat { table, json }; struct ProfileConfigParam : public dmlc::Parameter { bool profile_all; @@ -181,38 +175,48 @@ struct ProfileConfigParam : public dmlc::Parameter { bool aggregate_stats; int profile_process; DMLC_DECLARE_PARAMETER(ProfileConfigParam) { - DMLC_DECLARE_FIELD(profile_all).set_default(false) - .describe("Profile all. Default is False."); - DMLC_DECLARE_FIELD(profile_symbolic).set_default(true) - .describe("Profile symbolic operators. Default is True."); - DMLC_DECLARE_FIELD(profile_imperative).set_default(true) - .describe("Profile imperative operators. Default is True."); - DMLC_DECLARE_FIELD(profile_memory).set_default(true) - .describe("Profile memory. Default is True."); - DMLC_DECLARE_FIELD(profile_api).set_default(true) - .describe("Profile C API. Default is True."); - DMLC_DECLARE_FIELD(filename).set_default("profile.json") - .describe("File name to write profiling info."); + DMLC_DECLARE_FIELD(profile_all).set_default(false).describe("Profile all. Default is False."); + DMLC_DECLARE_FIELD(profile_symbolic) + .set_default(true) + .describe("Profile symbolic operators. Default is True."); + DMLC_DECLARE_FIELD(profile_imperative) + .set_default(true) + .describe("Profile imperative operators. 
Default is True."); + DMLC_DECLARE_FIELD(profile_memory) + .set_default(true) + .describe("Profile memory. Default is True."); + DMLC_DECLARE_FIELD(profile_api).set_default(true).describe("Profile C API. Default is True."); + DMLC_DECLARE_FIELD(filename) + .set_default("profile.json") + .describe("File name to write profiling info."); #if MXNET_USE_CUDA - DMLC_DECLARE_FIELD(gpu_memory_profile_filename_prefix).set_default("gpu_memory_profile") - .describe("File name prefix to write GPU memory profile info."); + DMLC_DECLARE_FIELD(gpu_memory_profile_filename_prefix) + .set_default("gpu_memory_profile") + .describe("File name prefix to write GPU memory profile info."); #endif // MXNET_USE_CUDA - DMLC_DECLARE_FIELD(continuous_dump).set_default(true) - .describe("Periodically dump (and append) profiling data to file while running. " - "Default is True."); - DMLC_DECLARE_FIELD(dump_period).set_default(1.0f) - .describe("When continuous dump is enabled, the period between subsequent " - "profile info dumping."); - DMLC_DECLARE_FIELD(aggregate_stats).set_default(false) - .describe("Maintain aggregate stats, required for MXDumpAggregateStats. Note that " - "this can have a negative performance impact. Default is False."); + DMLC_DECLARE_FIELD(continuous_dump) + .set_default(true) + .describe( + "Periodically dump (and append) profiling data to file while running. " + "Default is True."); + DMLC_DECLARE_FIELD(dump_period) + .set_default(1.0f) + .describe( + "When continuous dump is enabled, the period between subsequent " + "profile info dumping."); + DMLC_DECLARE_FIELD(aggregate_stats) + .set_default(false) + .describe( + "Maintain aggregate stats, required for MXDumpAggregateStats. Note that " + "this can have a negative performance impact. Default is False."); DMLC_DECLARE_FIELD(profile_process) - .add_enum("worker", static_cast(ProfileProcess::kWorker)) - .add_enum("server", static_cast(ProfileProcess::kServer)) - .set_default(static_cast(ProfileProcess::kWorker)) - .describe("Specifies which process to profile: " - "worker: this is default. for single node training it should always be worker." - "server: for distributed training, this profiles server process"); + .add_enum("worker", static_cast(ProfileProcess::kWorker)) + .add_enum("server", static_cast(ProfileProcess::kServer)) + .set_default(static_cast(ProfileProcess::kWorker)) + .describe( + "Specifies which process to profile: " + "worker: this is default. for single node training it should always be worker." 
+ "server: for distributed training, this profiles server process"); } }; @@ -221,59 +225,70 @@ DMLC_REGISTER_PARAMETER(ProfileConfigParam); struct ProfileMarkerScopeParam : public dmlc::Parameter { int scope; DMLC_DECLARE_PARAMETER(ProfileMarkerScopeParam) { - DMLC_DECLARE_FIELD(scope).set_default(profiler::ProfileMarker::kProcess) - .add_enum("global", profiler::ProfileMarker::kGlobal) - .add_enum("process", profiler::ProfileMarker::kProcess) - .add_enum("thread", profiler::ProfileMarker::kThread) - .add_enum("task", profiler::ProfileMarker::kTask) - .add_enum("marker", profiler::ProfileMarker::kMarker) - .describe("Profile Instant-Marker scope."); + DMLC_DECLARE_FIELD(scope) + .set_default(profiler::ProfileMarker::kProcess) + .add_enum("global", profiler::ProfileMarker::kGlobal) + .add_enum("process", profiler::ProfileMarker::kProcess) + .add_enum("thread", profiler::ProfileMarker::kThread) + .add_enum("task", profiler::ProfileMarker::kTask) + .add_enum("marker", profiler::ProfileMarker::kMarker) + .describe("Profile Instant-Marker scope."); } }; DMLC_REGISTER_PARAMETER(ProfileMarkerScopeParam); -int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const char* const* vals, +int MXSetProcessProfilerConfig(int num_params, + const char* const* keys, + const char* const* vals, KVStoreHandle kvstoreHandle) { - mxnet::IgnoreProfileCallScope ignore; + mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - std::vector> kwargs; - kwargs.reserve(num_params); + std::vector> kwargs; + kwargs.reserve(num_params); + for (int i = 0; i < num_params; ++i) { + CHECK_NOTNULL(keys[i]); + CHECK_NOTNULL(vals[i]); + kwargs.emplace_back(std::make_pair(keys[i], vals[i])); + } + ProfileConfigParam param; + param.Init(kwargs); + if (static_cast(param.profile_process) == ProfileProcess::kServer) { + std::ostringstream os; for (int i = 0; i < num_params; ++i) { - CHECK_NOTNULL(keys[i]); - CHECK_NOTNULL(vals[i]); - kwargs.emplace_back(std::make_pair(keys[i], vals[i])); + // this will be sent to the server now, those configs shouldn't have profile server again + if (strcmp(keys[i], "profile_process") == 0) + continue; + os << keys[i] << ":" << vals[i]; + if (i != num_params - 1) + os << ","; } - ProfileConfigParam param; - param.Init(kwargs); - if (static_cast(param.profile_process) == ProfileProcess::kServer) { - std::ostringstream os; - for (int i = 0; i < num_params; ++i) { - // this will be sent to the server now, those configs shouldn't have profile server again - if (strcmp(keys[i], "profile_process") == 0) continue; - os << keys[i] << ":" << vals[i]; - if (i != num_params - 1) os << ","; - } - CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null"; - static_cast(kvstoreHandle)->SetServerProfilerCommand( - mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str()); - } else { - int mode = 0; - if (param.profile_api || param.profile_all) { mode |= profiler::Profiler::kAPI; } - if (param.profile_symbolic || param.profile_all) { mode |= profiler::Profiler::kSymbolic; } - if (param.profile_imperative || - param.profile_all) { mode |= profiler::Profiler::kImperative; } - if (param.profile_memory || param.profile_all) { mode |= profiler::Profiler::kMemory; } - profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode), - std::string(param.filename), - param.continuous_dump, - param.dump_period, - param.aggregate_stats); + CHECK(kvstoreHandle) << "KVStoreHandle passed to profiler is null"; + static_cast(kvstoreHandle) + 
->SetServerProfilerCommand(mxnet::KVStoreServerProfilerCommand::kSetConfig, os.str()); + } else { + int mode = 0; + if (param.profile_api || param.profile_all) { + mode |= profiler::Profiler::kAPI; + } + if (param.profile_symbolic || param.profile_all) { + mode |= profiler::Profiler::kSymbolic; + } + if (param.profile_imperative || param.profile_all) { + mode |= profiler::Profiler::kImperative; + } + if (param.profile_memory || param.profile_all) { + mode |= profiler::Profiler::kMemory; + } + profiler::Profiler::Get()->SetConfig(profiler::Profiler::ProfilerMode(mode), + std::string(param.filename), + param.continuous_dump, + param.dump_period, + param.aggregate_stats); #if MXNET_USE_CUDA - profiler::GpuDeviceStorageProfiler::Get()->SetConfig( - param.gpu_memory_profile_filename_prefix); + profiler::GpuDeviceStorageProfiler::Get()->SetConfig(param.gpu_memory_profile_filename_prefix); #endif // MXNET_USE_CUDA - } + } API_END(); } @@ -281,30 +296,33 @@ int MXSetProfilerConfig(int num_params, const char* const* keys, const char* con return MXSetProcessProfilerConfig(num_params, keys, vals, nullptr); } -int MXAggregateProfileStatsPrint(const char **out_str, int reset, int format, int sort_by, +int MXAggregateProfileStatsPrint(const char** out_str, + int reset, + int format, + int sort_by, int ascending) { - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); - CHECK_NOTNULL(out_str); - profiler::Profiler *profiler = profiler::Profiler::Get(); - if (profiler->IsEnableOutput()) { - // Register stats up until now - profiler->DumpProfile(false); - } - std::shared_ptr stats = profiler->GetAggregateStats(); - std::ostringstream os; - if (stats) { - if (static_cast(format) == PrintFormat::table) - stats->DumpTable(os, sort_by, ascending); - else if (static_cast(format) == PrintFormat::json) - stats->DumpJson(os, sort_by, ascending); - else - LOG(FATAL) << "Invalid value for parameter format"; - } - if (reset != 0) - stats->clear(); - ret->ret_str = os.str(); - *out_str = (ret->ret_str).c_str(); + CHECK_NOTNULL(out_str); + profiler::Profiler* profiler = profiler::Profiler::Get(); + if (profiler->IsEnableOutput()) { + // Register stats up until now + profiler->DumpProfile(false); + } + std::shared_ptr stats = profiler->GetAggregateStats(); + std::ostringstream os; + if (stats) { + if (static_cast(format) == PrintFormat::table) + stats->DumpTable(os, sort_by, ascending); + else if (static_cast(format) == PrintFormat::json) + stats->DumpJson(os, sort_by, ascending); + else + LOG(FATAL) << "Invalid value for parameter format"; + } + if (reset != 0) + stats->clear(); + ret->ret_str = os.str(); + *out_str = (ret->ret_str).c_str(); API_END(); } @@ -317,13 +335,13 @@ int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStor API_BEGIN(); if (static_cast(profile_process) == ProfileProcess::kServer) { CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; - static_cast(kvStoreHandle)->SetServerProfilerCommand( - mxnet::KVStoreServerProfilerCommand::kDump, - std::to_string(finished)); + static_cast(kvStoreHandle) + ->SetServerProfilerCommand(mxnet::KVStoreServerProfilerCommand::kDump, + std::to_string(finished)); } else { - profiler::Profiler *profiler = profiler::Profiler::Get(); + profiler::Profiler* profiler = profiler::Profiler::Get(); CHECK(profiler->IsEnableOutput()) - << "Profiler hasn't been run. Config and start profiler first"; + << "Profiler hasn't been run. 
Config and start profiler first"; profiler->DumpProfile(finished != 0); #if MXNET_USE_CUDA profiler::GpuDeviceStorageProfiler::Get()->DumpProfile(); @@ -348,9 +366,9 @@ int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvSt API_BEGIN(); if (static_cast(profile_process) == ProfileProcess::kServer) { CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; - static_cast(kvStoreHandle)->SetServerProfilerCommand( - mxnet::KVStoreServerProfilerCommand::kState, - std::to_string(state)); + static_cast(kvStoreHandle) + ->SetServerProfilerCommand(mxnet::KVStoreServerProfilerCommand::kState, + std::to_string(state)); } else { switch (state) { case profiler::Profiler::kNotRunning: @@ -365,134 +383,127 @@ int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvSt API_END(); } -int MXProfileCreateDomain(const char *domain, ProfileHandle *out) { +int MXProfileCreateDomain(const char* domain, ProfileHandle* out) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - auto dom = std::make_shared(domain); - { - std::unique_lock lock(python_profile_objects.cs_domains_); - python_profile_objects.domains_.push_back(dom); - } - *out = dom.get(); + auto dom = std::make_shared(domain); + { + std::unique_lock lock(python_profile_objects.cs_domains_); + python_profile_objects.domains_.push_back(dom); + } + *out = dom.get(); API_END(); } -int MXProfileCreateTask(ProfileHandle domain, - const char *task_name, - ProfileHandle *out) { +int MXProfileCreateTask(ProfileHandle domain, const char* task_name, ProfileHandle* out) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - auto ctr = - std::make_shared(task_name, - static_cast(domain)); - { - std::unique_lock lock(python_profile_objects.cs_tasks_); - python_profile_objects.tasks_.emplace(std::make_pair(ctr.get(), ctr)); - } - *out = ctr.get(); + auto ctr = std::make_shared(task_name, + static_cast(domain)); + { + std::unique_lock lock(python_profile_objects.cs_tasks_); + python_profile_objects.tasks_.emplace(std::make_pair(ctr.get(), ctr)); + } + *out = ctr.get(); API_END(); } -int MXProfileCreateFrame(ProfileHandle domain, - const char *frame_name, - ProfileHandle *out) { +int MXProfileCreateFrame(ProfileHandle domain, const char* frame_name, ProfileHandle* out) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - auto ctr = - std::make_shared(frame_name, - static_cast(domain)); - { - std::unique_lock lock(python_profile_objects.cs_frames_); - python_profile_objects.frames_.emplace(std::make_pair(ctr.get(), ctr)); - } - *out = ctr.get(); + auto ctr = std::make_shared( + frame_name, static_cast(domain)); + { + std::unique_lock lock(python_profile_objects.cs_frames_); + python_profile_objects.frames_.emplace(std::make_pair(ctr.get(), ctr)); + } + *out = ctr.get(); API_END(); } -int MXProfileCreateEvent(const char *event_name, ProfileHandle *out) { +int MXProfileCreateEvent(const char* event_name, ProfileHandle* out) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - auto ctr = - std::make_shared(event_name); - { - std::unique_lock lock(python_profile_objects.cs_events_); - python_profile_objects.events_.emplace(std::make_pair(ctr.get(), ctr)); - } - *out = ctr.get(); + auto ctr = std::make_shared(event_name); + { + std::unique_lock lock(python_profile_objects.cs_events_); + python_profile_objects.events_.emplace(std::make_pair(ctr.get(), ctr)); + } + *out = ctr.get(); API_END(); } int MXProfileDestroyHandle(ProfileHandle object_handle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - 
CHECK_NE(object_handle, static_cast(nullptr)) + CHECK_NE(object_handle, static_cast(nullptr)) << "Invalid NULL handle passed to MXProfileDestroyHandle"; - std::shared_ptr shared_object_ptr(nullptr); - { - auto object = static_cast(object_handle); - switch (object->type()) { - case profiler::kTask: { - auto p = static_cast(object_handle); - std::unique_lock lock(python_profile_objects.cs_tasks_); - auto iter = python_profile_objects.tasks_.find(p); - if (iter != python_profile_objects.tasks_.end()) { - shared_object_ptr = iter->second; - python_profile_objects.tasks_.erase(iter); - } - break; + std::shared_ptr shared_object_ptr(nullptr); + { + auto object = static_cast(object_handle); + switch (object->type()) { + case profiler::kTask: { + auto p = static_cast(object_handle); + std::unique_lock lock(python_profile_objects.cs_tasks_); + auto iter = python_profile_objects.tasks_.find(p); + if (iter != python_profile_objects.tasks_.end()) { + shared_object_ptr = iter->second; + python_profile_objects.tasks_.erase(iter); } - case profiler::kEvent: { - auto p = static_cast(object_handle); - std::unique_lock lock(python_profile_objects.cs_events_); - auto iter = python_profile_objects.events_.find(p); - if (iter != python_profile_objects.events_.end()) { - shared_object_ptr = iter->second; - python_profile_objects.events_.erase(iter); - } - break; + break; + } + case profiler::kEvent: { + auto p = static_cast(object_handle); + std::unique_lock lock(python_profile_objects.cs_events_); + auto iter = python_profile_objects.events_.find(p); + if (iter != python_profile_objects.events_.end()) { + shared_object_ptr = iter->second; + python_profile_objects.events_.erase(iter); } - case profiler::kFrame: { - auto p = static_cast(object_handle); - std::unique_lock lock(python_profile_objects.cs_frames_); - auto iter = python_profile_objects.frames_.find(p); - if (iter != python_profile_objects.frames_.end()) { - shared_object_ptr = iter->second; - python_profile_objects.frames_.erase(iter); - } - break; + break; + } + case profiler::kFrame: { + auto p = static_cast(object_handle); + std::unique_lock lock(python_profile_objects.cs_frames_); + auto iter = python_profile_objects.frames_.find(p); + if (iter != python_profile_objects.frames_.end()) { + shared_object_ptr = iter->second; + python_profile_objects.frames_.erase(iter); } - case profiler::kCounter: { - auto p = static_cast(object_handle); - std::unique_lock lock(python_profile_objects.cs_counters_); - auto iter = python_profile_objects.counters_.find(p); - if (iter != python_profile_objects.counters_.end()) { - shared_object_ptr = iter->second; - python_profile_objects.counters_.erase(iter); - } - break; + break; + } + case profiler::kCounter: { + auto p = static_cast(object_handle); + std::unique_lock lock(python_profile_objects.cs_counters_); + auto iter = python_profile_objects.counters_.find(p); + if (iter != python_profile_objects.counters_.end()) { + shared_object_ptr = iter->second; + python_profile_objects.counters_.erase(iter); } - case profiler::kDomain: - // Not destroyed - break; + break; } + case profiler::kDomain: + // Not destroyed + break; } - shared_object_ptr.reset(); // Destroy out of lock scope + } + shared_object_ptr.reset(); // Destroy out of lock scope API_END(); } int MXProfileDurationStart(ProfileHandle duration_handle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - CHECK_NOTNULL(duration_handle); - static_cast(duration_handle)->start(); + CHECK_NOTNULL(duration_handle); + static_cast(duration_handle)->start(); 
API_END(); } int MXProfileDurationStop(ProfileHandle duration_handle) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - CHECK_NOTNULL(duration_handle); - static_cast(duration_handle)->stop(); + CHECK_NOTNULL(duration_handle); + static_cast(duration_handle)->stop(); API_END(); } @@ -505,9 +516,9 @@ int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStore API_BEGIN(); if (static_cast(profile_process) == ProfileProcess::kServer) { CHECK(kvStoreHandle) << "Kvstore Handle passed to profiler is null"; - static_cast(kvStoreHandle)->SetServerProfilerCommand( - mxnet::KVStoreServerProfilerCommand::kPause, - std::to_string(paused)); + static_cast(kvStoreHandle) + ->SetServerProfilerCommand(mxnet::KVStoreServerProfilerCommand::kPause, + std::to_string(paused)); } else { if (paused) { profiler::vtune::vtune_pause(); @@ -520,48 +531,42 @@ int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStore API_END(); } -int MXProfileCreateCounter(ProfileHandle domain, - const char *counter_name, - ProfileHandle *out) { +int MXProfileCreateCounter(ProfileHandle domain, const char* counter_name, ProfileHandle* out) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - auto ctr = - std::make_shared(counter_name, - static_cast(domain)); - { - std::unique_lock lock(python_profile_objects.cs_counters_); - python_profile_objects.counters_.emplace(std::make_pair(ctr.get(), ctr)); - } - *out = ctr.get(); + auto ctr = std::make_shared( + counter_name, static_cast(domain)); + { + std::unique_lock lock(python_profile_objects.cs_counters_); + python_profile_objects.counters_.emplace(std::make_pair(ctr.get(), ctr)); + } + *out = ctr.get(); API_END(); } int MXProfileSetCounter(ProfileHandle counter_handle, uint64_t value) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - static_cast(counter_handle)->operator=(value); + static_cast(counter_handle)->operator=(value); API_END(); } int MXProfileAdjustCounter(ProfileHandle counter_handle, int64_t by_value) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - static_cast(counter_handle)->operator+=(by_value); + static_cast(counter_handle)->operator+=(by_value); API_END(); } -int MXProfileSetMarker(ProfileHandle domain, - const char *instant_marker_name, - const char *scope) { +int MXProfileSetMarker(ProfileHandle domain, const char* instant_marker_name, const char* scope) { mxnet::IgnoreProfileCallScope ignore; API_BEGIN(); - ProfileMarkerScopeParam param; - std::vector> kwargs = {{ "scope", scope }}; - param.Init(kwargs); - profiler::ProfileMarker marker(instant_marker_name, - static_cast(domain), - static_cast( - param.scope)); - marker.mark(); + ProfileMarkerScopeParam param; + std::vector> kwargs = {{"scope", scope}}; + param.Init(kwargs); + profiler::ProfileMarker marker(instant_marker_name, + static_cast(domain), + static_cast(param.scope)); + marker.mark(); API_END(); } diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 8990e8ac6b29..69b46e436e5e 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -39,36 +39,27 @@ namespace mxnet { namespace op { void RegisterLegacyOpProp(); void RegisterLegacyNDFunc(); -} -const std::vector kHiddenKeys = { - "ctx_group", - "lr_mult", - "wd_mult", - "force_mirroring", - "mirror_stage", - "profiler_scope" -}; -const std::vector kReplacedHiddenKeys = { - "__ctx_group__", - "__lr_mult__", - "__wd_mult__", - "__force_mirroring__", - "__mirror_stage__", - "__profiler_scope__" -}; -const char *kNamespaceSeparator = "$"; - +} // namespace 
op +const std::vector kHiddenKeys = + {"ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage", "profiler_scope"}; +const std::vector kReplacedHiddenKeys = {"__ctx_group__", + "__lr_mult__", + "__wd_mult__", + "__force_mirroring__", + "__mirror_stage__", + "__profiler_scope__"}; +const char* kNamespaceSeparator = "$"; DMLC_JSON_ENABLE_ANY(int, int); // convert nnvm symbol to a nnvm graph. -nnvm::Graph Symbol2Graph(const nnvm::Symbol &s) { +nnvm::Graph Symbol2Graph(const nnvm::Symbol& s) { nnvm::Graph g; - g.outputs = s.outputs; + g.outputs = s.outputs; g.attrs["mxnet_version"] = std::make_shared(static_cast(MXNET_VERSION)); if (Imperative::Get()->is_np_shape()) { - g.attrs["is_np_shape"] = std::make_shared( - static_cast(Imperative::Get()->is_np_shape())); + g.attrs["is_np_shape"] = + std::make_shared(static_cast(Imperative::Get()->is_np_shape())); } return g; } @@ -88,32 +79,30 @@ std::vector ReadOnlyArgIndices(const nnvm::IndexedGraph& idx) { // symbolic configuration generation API. // Redirect to NNVM's C API -int MXListAllOpNames(nn_uint *out_size, - const char ***out_array) { +int MXListAllOpNames(nn_uint* out_size, const char*** out_array) { mxnet::op::RegisterLegacyOpProp(); mxnet::op::RegisterLegacyNDFunc(); return NNListAllOpNames(out_size, out_array); } -int MXSymbolListAtomicSymbolCreators(uint32_t *out_size, - AtomicSymbolCreator **out_array) { +int MXSymbolListAtomicSymbolCreators(uint32_t* out_size, AtomicSymbolCreator** out_array) { mxnet::op::RegisterLegacyOpProp(); mxnet::op::RegisterLegacyNDFunc(); return NNListUniqueOps(out_size, out_array); } int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, - const char **name, - const char **description, - uint32_t *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **key_var_num_args, - const char **return_type) { + const char** name, + const char** description, + uint32_t* num_args, + const char*** arg_names, + const char*** arg_type_infos, + const char*** arg_descriptions, + const char** key_var_num_args, + const char** return_type) { static auto& map_key_var_args = nnvm::Op::GetAttr("key_var_num_args"); - const Op* op = static_cast(creator); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + const Op* op = static_cast(creator); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); ret->ret_str.resize(0); if (map_key_var_args.count(op) != 0) { @@ -121,24 +110,28 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, } else { *key_var_num_args = ret->ret_str.c_str(); } - return NNGetOpInfo( - creator, name, description, - num_args, arg_names, arg_type_infos, - arg_descriptions, return_type); + return NNGetOpInfo(creator, + name, + description, + num_args, + arg_names, + arg_type_infos, + arg_descriptions, + return_type); } int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, uint32_t num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); + const char** keys, + const char** vals, + SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); const nnvm::Op* op = static_cast(creator); std::unordered_map kwargs; for (nn_uint i = 0; i < num_param; ++i) { bool flag = false; - for (const auto &k : kHiddenKeys) { + for (const auto& k : kHiddenKeys) { std::string tmp(keys[i]); size_t pos = tmp.rfind(k); if (pos == 0) { @@ -149,62 +142,55 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, std::ostringstream os; os << "setting 
variable attributes with " << keys[i] << " is deprecated. " << "please instead use\nw = Variable(" << k << "=" << vals[i] << ")\n" - << "sym = YourSymbolName(" << tmp.substr(0, pos-1) << "=w)"; + << "sym = YourSymbolName(" << tmp.substr(0, pos - 1) << "=w)"; throw dmlc::Error(os.str()); } } if (!flag) kwargs.insert({std::string(keys[i]), std::string(vals[i])}); } - *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs)); + *s = nnvm::Symbol::CreateFunctor(op, std::move(kwargs)); *out = s; API_END_HANDLE_ERROR(delete s;); } -int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { +int MXSymbolCreateVariable(const char* name, SymbolHandle* out) { return NNSymbolCreateVariable(name, out); } -int MXSymbolCreateGroup(uint32_t num_symbols, - SymbolHandle *symbols, - SymbolHandle *out) { +int MXSymbolCreateGroup(uint32_t num_symbols, SymbolHandle* symbols, SymbolHandle* out) { return NNSymbolCreateGroup(num_symbols, symbols, out); } -int MXSymbolGetOutput(SymbolHandle symbol, - uint32_t index, - SymbolHandle *out) { +int MXSymbolGetOutput(SymbolHandle symbol, uint32_t index, SymbolHandle* out) { return NNSymbolGetOutput(symbol, index, out); } -int MXSymbolGetInputs(SymbolHandle symbol, - SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); - API_BEGIN(); - std::vector inputs = static_cast(symbol)->ListInputs( - nnvm::Symbol::ListInputOption(0)); - for (const nnvm::ObjectPtr &o : inputs) { - nnvm::NodeEntry e(o); - s->outputs.push_back(e); - } - *out = s; - API_END_HANDLE_ERROR(delete s); +int MXSymbolGetInputs(SymbolHandle symbol, SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); + API_BEGIN(); + std::vector inputs = + static_cast(symbol)->ListInputs(nnvm::Symbol::ListInputOption(0)); + for (const nnvm::ObjectPtr& o : inputs) { + nnvm::NodeEntry e(o); + s->outputs.push_back(e); + } + *out = s; + API_END_HANDLE_ERROR(delete s); } -int MXSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - *s = static_cast(symbol)->GetInternals(); + *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } -int MXSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - *s = static_cast(symbol)->GetChildren(); + *s = static_cast(symbol)->GetChildren(); *out = s; API_END_HANDLE_ERROR(delete s); } @@ -213,37 +199,32 @@ int MXSymbolFree(SymbolHandle symbol) { return NNSymbolFree(symbol); } -int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { +int MXSymbolCopy(SymbolHandle symbol, SymbolHandle* out) { return NNSymbolCopy(symbol, out); } -int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { +int MXSymbolPrint(SymbolHandle symbol, const char** out_str) { return NNSymbolPrint(symbol, out_str); } -int MXSymbolGetName(SymbolHandle symbol, - const char** out, - int* success) { +int MXSymbolGetName(SymbolHandle symbol, const char** out, int* success) { return NNSymbolGetAttr(symbol, "name", out, success); } -int MXSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int* success) { - nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) { + nnvm::Symbol* s = 
static_cast(symbol); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { - *out = (ret->ret_str).c_str(); + *out = (ret->ret_str).c_str(); *success = 1; } else { - *out = nullptr; + *out = nullptr; *success = 0; if (std::find(kHiddenKeys.begin(), kHiddenKeys.end(), key) != kHiddenKeys.end()) { std::string skey = "__" + std::string(key) + "__"; if (s->GetAttr(skey, &(ret->ret_str))) { - *out = (ret->ret_str).c_str(); + *out = (ret->ret_str).c_str(); *success = 1; } } @@ -251,14 +232,12 @@ int MXSymbolGetAttr(SymbolHandle symbol, API_END(); } -int MXSymbolSetAttr(SymbolHandle symbol, - const char* key, - const char* value) { - nnvm::Symbol *s = static_cast(symbol); +int MXSymbolSetAttr(SymbolHandle symbol, const char* key, const char* value) { + nnvm::Symbol* s = static_cast(symbol); API_BEGIN(); - std::vector > kwargs; + std::vector> kwargs; std::string skey(key), sval(value); - for (const auto &k : kHiddenKeys) { + for (const auto& k : kHiddenKeys) { size_t pos = skey.rfind(k); if (pos == 0 && k.length() == skey.length()) { skey = "__" + skey + "__"; @@ -267,7 +246,7 @@ int MXSymbolSetAttr(SymbolHandle symbol, std::ostringstream os; os << "setting variable attributes with " << key << " is deprecated. " << "please instead use\nw = Variable(" << k << "=" << value << ")\n" - << "sym = YourSymbolName(" << skey.substr(0, pos-1) << "=w)"; + << "sym = YourSymbolName(" << skey.substr(0, pos - 1) << "=w)"; throw dmlc::Error(os.str()); } } @@ -276,28 +255,25 @@ int MXSymbolSetAttr(SymbolHandle symbol, API_END(); } -int MXSymbolListAttr(SymbolHandle symbol, - uint32_t *out_size, - const char*** out) { - nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXSymbolListAttr(SymbolHandle symbol, uint32_t* out_size, const char*** out) { + nnvm::Symbol* s = static_cast(symbol); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); - std::vector > attr = - s->ListAttrsRecursive(); + std::vector> attr = s->ListAttrsRecursive(); std::vector& attr_list = ret->ret_vec_str; attr_list.clear(); for (const auto& tp : attr) { attr_list.emplace_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp)); attr_list.emplace_back(std::get<2>(tp)); - if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp)) - != kReplacedHiddenKeys.end()) { + if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), std::get<1>(tp)) != + kReplacedHiddenKeys.end()) { attr_list.push_back(std::get<0>(tp) + kNamespaceSeparator + std::get<1>(tp).substr(2, std::get<1>(tp).length() - 4)); attr_list.push_back(std::get<2>(tp)); } } - *out_size = attr_list.size()/2; + *out_size = attr_list.size() / 2; ret->ret_vec_charp.clear(); for (const auto& attr : attr_list) { ret->ret_vec_charp.push_back(attr.c_str()); @@ -306,11 +282,9 @@ int MXSymbolListAttr(SymbolHandle symbol, API_END(); } -int MXSymbolListAttrShallow(SymbolHandle symbol, - uint32_t *out_size, - const char*** out) { - nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXSymbolListAttrShallow(SymbolHandle symbol, uint32_t* out_size, const char*** out) { + nnvm::Symbol* s = static_cast(symbol); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(1)); // NOLINT(*) @@ -320,34 +294,31 @@ int MXSymbolListAttrShallow(SymbolHandle symbol, for (const auto& kv : attr) { 
attr_list.push_back(kv.first); attr_list.push_back(kv.second); - if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first) - != kReplacedHiddenKeys.end()) { + if (find(kReplacedHiddenKeys.begin(), kReplacedHiddenKeys.end(), kv.first) != + kReplacedHiddenKeys.end()) { attr_list.push_back(kv.first.substr(2, kv.first.length() - 4)); attr_list.push_back(kv.second); } } - *out_size = attr_list.size()/2; + *out_size = attr_list.size() / 2; ret->ret_vec_charp.clear(); - for (auto &attr : attr_list) { + for (auto& attr : attr_list) { ret->ret_vec_charp.push_back(attr.c_str()); } *out = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } -int MXSymbolListOutputs(SymbolHandle symbol, - uint32_t *out_size, - const char ***out_str_array) { +int MXSymbolListOutputs(SymbolHandle symbol, uint32_t* out_size, const char*** out_str_array) { return NNSymbolListOutputNames(symbol, out_size, out_str_array); } -int MXSymbolGetNumOutputs(SymbolHandle symbol, - uint32_t *output_count) { +int MXSymbolGetNumOutputs(SymbolHandle symbol, uint32_t* output_count) { return NNSymbolGetNumOutputs(symbol, output_count); } int MXSymbolCompose(SymbolHandle sym, - const char *name, + const char* name, uint32_t num_args, const char** keys, SymbolHandle* args) { @@ -355,62 +326,59 @@ int MXSymbolCompose(SymbolHandle sym, } // adapter functions that re-implements the functions. -int MXSymbolListArguments(SymbolHandle symbol, - uint32_t *out_size, - const char ***out_str_array) { +int MXSymbolListArguments(SymbolHandle symbol, uint32_t* out_size, const char*** out_str_array) { return NNSymbolListInputNames(symbol, 1, out_size, out_str_array); } int MXSymbolListAuxiliaryStates(SymbolHandle symbol, - uint32_t *out_size, - const char ***out_str_array) { + uint32_t* out_size, + const char*** out_str_array) { return NNSymbolListInputNames(symbol, 2, out_size, out_str_array); } -int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, - const char **out) { +int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, const char** out) { API_BEGIN(); - Op *e = static_cast(creator); - *out = e->name.c_str(); + Op* e = static_cast(creator); + *out = e->name.c_str(); API_END(); } namespace mxnet { -extern std::vector GetInputSymbols(const nnvm::Symbol &sym); -extern bool CutGraphInputs(const std::vector &input_entries, - bool skip_var, std::vector *orig_entries); +extern std::vector GetInputSymbols(const nnvm::Symbol& sym); +extern bool CutGraphInputs(const std::vector& input_entries, + bool skip_var, + std::vector* orig_entries); -} +} // namespace mxnet -int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **input_arr, int *input_size) { +int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle** input_arr, int* input_size) { API_BEGIN(); - nnvm::Symbol *s = static_cast(sym); - std::vector input_syms = mxnet::GetInputSymbols(*s); - *input_size = input_syms.size(); + nnvm::Symbol* s = static_cast(sym); + std::vector input_syms = mxnet::GetInputSymbols(*s); + *input_size = input_syms.size(); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); - for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); + for (int i = 0; i < *input_size; ++i) + ret->ret_handles.push_back(input_syms[i]); *input_arr = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); API_END_HANDLE_ERROR(); } -int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, - int 
*input_size) { +int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle** input_symbols, int* input_size) { // Given a graph, we want to fetch the nodes that have been marked as part of // a subgraph. API_BEGIN(); - nnvm::Symbol *s = static_cast(sym); + nnvm::Symbol* s = static_cast(sym); const std::string subg_attr = "__subgraph_name__"; - auto out_node = s->outputs[0].node; - auto it = out_node->attrs.dict.find(subg_attr); + auto out_node = s->outputs[0].node; + auto it = out_node->attrs.dict.find(subg_attr); if (it != out_node->attrs.dict.end()) { - const std::string &subg_name = it->second; - std::vector input_entries; - DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries] - (nnvm::ObjectPtr n) { + const std::string& subg_name = it->second; + std::vector input_entries; + DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries](nnvm::ObjectPtr n) { // If the node itself isn't in the subgraph, we ignore it. auto it = n->attrs.dict.find(subg_attr); if (it == n->attrs.dict.end() || it->second != subg_name) @@ -419,7 +387,7 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, // We search for nodes whose node entries aren't in the subgraph. for (size_t j = 0; j < n->inputs.size(); j++) { auto in_node = n->inputs[j].node; - auto it = in_node->attrs.dict.find(subg_attr); + auto it = in_node->attrs.dict.find(subg_attr); if (it == in_node->attrs.dict.end() || it->second != subg_name) input_entries.push_back(&n->inputs[j]); } @@ -427,17 +395,18 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, std::vector orig_entries; CutGraphInputs(input_entries, false, &orig_entries); - std::vector input_syms(orig_entries.size()); + std::vector input_syms(orig_entries.size()); for (size_t i = 0; i < input_syms.size(); i++) { input_syms[i] = new nnvm::Symbol(); input_syms[i]->outputs.push_back(orig_entries[i]); } *input_size = input_syms.size(); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); ret->ret_handles.clear(); ret->ret_handles.reserve(*input_size); - for (int i = 0; i < *input_size; ++i) ret->ret_handles.push_back(input_syms[i]); + for (int i = 0; i < *input_size; ++i) + ret->ret_handles.push_back(input_syms[i]); *input_symbols = reinterpret_cast(dmlc::BeginPtr(ret->ret_handles)); } else { *input_size = 0; @@ -446,15 +415,14 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, API_END_HANDLE_ERROR(); } - /*! * \brief Convert shape attr in graph nodes to comply with NumPy semantics for * legacy models (before 1.6.0) if global flag is_np_shape has been turned on, * i.e., use -1 to indicate unknown number of dimensions and unknown dimension sizes. 
*/ void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) { - if (Imperative::Get()->is_np_shape() - && (!g->HasAttr("is_np_shape") || !g->GetAttr("is_np_shape"))) { + if (Imperative::Get()->is_np_shape() && + (!g->HasAttr("is_np_shape") || !g->GetAttr("is_np_shape"))) { DFSVisit(g->outputs, [](nnvm::ObjectPtr n) { if (n->is_variable()) { auto it = n->attrs.dict.find("__shape__"); @@ -472,46 +440,46 @@ void ConvertShapeAttrToNumPyCompatible(nnvm::Graph* g) { } } -int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXSymbolCreateFromFile(const char* fname, SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); std::unique_ptr fi(dmlc::Stream::Create(fname, "r")); dmlc::istream is(fi.get()); nnvm::Graph g; g.attrs["json"] = std::make_shared( - std::string(std::istreambuf_iterator(is), std::istreambuf_iterator())); + std::string(std::istreambuf_iterator(is), std::istreambuf_iterator())); g = nnvm::ApplyPass(g, "LoadLegacyJSON"); ConvertShapeAttrToNumPyCompatible(&g); s->outputs = g.outputs; - *out = s; + *out = s; is.set_stream(nullptr); API_END_HANDLE_ERROR(delete s); } -int MXSymbolCreateFromJSON(const char *json, SymbolHandle *out) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXSymbolCreateFromJSON(const char* json, SymbolHandle* out) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Graph g; g.attrs["json"] = std::make_shared(std::string(json)); - g = nnvm::ApplyPass(g, "LoadLegacyJSON"); + g = nnvm::ApplyPass(g, "LoadLegacyJSON"); ConvertShapeAttrToNumPyCompatible(&g); s->outputs = g.outputs; - *out = s; + *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolRemoveAmpCast(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *source = static_cast(sym_handle); - *s = source->Copy(); - s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs; - *ret_sym_handle = s; + nnvm::Symbol* source = static_cast(sym_handle); + *s = source->Copy(); + s->outputs = nnvm::ApplyPass(Symbol2Graph(*s), "RemoveAmpCast").outputs; + *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } -int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { - nnvm::Symbol *s = static_cast(symbol); +int MXSymbolSaveToFile(SymbolHandle symbol, const char* fname) { + nnvm::Symbol* s = static_cast(symbol); API_BEGIN(); std::unique_ptr fo(dmlc::Stream::Create(fname, "w")); dmlc::ostream os(fo.get()); @@ -521,29 +489,28 @@ int MXSymbolSaveToFile(SymbolHandle symbol, const char *fname) { API_END(); } -int MXSymbolSaveToJSON(SymbolHandle symbol, const char **out_json) { - nnvm::Symbol *s = static_cast(symbol); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); +int MXSymbolSaveToJSON(SymbolHandle symbol, const char** out_json) { + nnvm::Symbol* s = static_cast(symbol); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); ret->ret_str = nnvm::pass::SaveJSON(Symbol2Graph(*s)); - *out_json = ret->ret_str.c_str(); + *out_json = ret->ret_str.c_str(); API_END(); } namespace mxnet { -template -void MatchArguments( - const nnvm::IndexedGraph& idx, - const std::unordered_map& known_arg_attrs, - std::vector* arg_attrs, - const char* source) { +template +void MatchArguments(const nnvm::IndexedGraph& idx, + const std::unordered_map& known_arg_attrs, + std::vector* arg_attrs, + const char* source) { auto& arg_nodes = idx.input_nodes(); CHECK_EQ(arg_attrs->size(), arg_nodes.size()); size_t nmatched = 0; for (size_t i = 0; i 
< arg_nodes.size(); ++i) { const std::string& name = idx[arg_nodes[i]].source->attrs.name; - auto it = known_arg_attrs.find(name); + auto it = known_arg_attrs.find(name); if (it != known_arg_attrs.end()) { arg_attrs->at(i) = it->second; ++nmatched; @@ -561,9 +528,7 @@ void MatchArguments( for (const auto& kv : known_arg_attrs) { const std::string& key = kv.first; if (keys.count(key) == 0) { - LOG(FATAL) << source - << "Keyword argument name " << key << " not found." - << msg.str(); + LOG(FATAL) << source << "Keyword argument name " << key << " not found." << msg.str(); } } } @@ -571,7 +536,7 @@ void MatchArguments( } // namespace mxnet -template +template inline void SymbolInferShape(const char** keys, uint32_t num_args, const dtype* arg_shape_data, @@ -591,7 +556,7 @@ inline void SymbolInferShape(const char** keys, nnvm::Graph g = Symbol2Graph(*s); mxnet::ShapeVector arg_shapes(g.indexed_graph().input_nodes().size(), mxnet::TShape()); if (keys == nullptr && num_args != 0) { - std::vector < uint32_t > read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); + std::vector read_only_args = mxnet::ReadOnlyArgIndices(g.indexed_graph()); CHECK_LE(num_args, read_only_args.size()); for (uint32_t i = 0; i < num_args; ++i) { arg_shapes[read_only_args[i]] = mxnet::ShapeTypeCast(arg_shape_data + arg_ind_ptr[i], @@ -630,9 +595,9 @@ inline void SymbolInferShape(const char** keys, &(ret->aux_shape_ndim_ex), &(ret->aux_shape_data_ex), &(ret->aux_shape_buffer_ex)); - *in_shape_size = static_cast(ret->arg_shapes.size()); - *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); - *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); + *in_shape_size = static_cast(ret->arg_shapes.size()); + *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim_ex); + *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data_ex); *out_shape_size = static_cast(ret->out_shapes.size()); *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim_ex); *out_shape_data = dmlc::BeginPtr(ret->out_shape_data_ex); @@ -667,37 +632,37 @@ inline void SymbolInferShape(const char** keys, int MXSymbolInferShape(SymbolHandle sym, uint32_t num_args, const char** keys, - const uint32_t *arg_ind_ptr, - const int *arg_shape_data, - uint32_t *in_shape_size, - const int **in_shape_ndim, - const int ***in_shape_data, - uint32_t *out_shape_size, - const int **out_shape_ndim, - const int ***out_shape_data, - uint32_t *aux_shape_size, - const int **aux_shape_ndim, - const int ***aux_shape_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + const uint32_t* arg_ind_ptr, + const int* arg_shape_data, + uint32_t* in_shape_size, + const int** in_shape_ndim, + const int*** in_shape_data, + uint32_t* out_shape_size, + const int** out_shape_ndim, + const int*** out_shape_data, + uint32_t* aux_shape_size, + const int** aux_shape_ndim, + const int*** aux_shape_data, + int* complete) { + nnvm::Symbol* s = static_cast(sym); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); SymbolInferShape(keys, - num_args, - arg_shape_data, - arg_ind_ptr, - in_shape_ndim, - in_shape_data, - out_shape_ndim, - out_shape_data, - aux_shape_ndim, - aux_shape_data, - s, - ret, - in_shape_size, - out_shape_size, - aux_shape_size, - complete); + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); API_END(); } @@ 
-725,37 +690,37 @@ int MXSymbolInferShape(SymbolHandle sym, int MXSymbolInferShape64(SymbolHandle sym, uint32_t num_args, const char** keys, - const int64_t *arg_ind_ptr, - const int64_t *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const int64_t ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const int64_t ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const int64_t ***aux_shape_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + const int64_t* arg_ind_ptr, + const int64_t* arg_shape_data, + size_t* in_shape_size, + const int** in_shape_ndim, + const int64_t*** in_shape_data, + size_t* out_shape_size, + const int** out_shape_ndim, + const int64_t*** out_shape_data, + size_t* aux_shape_size, + const int** aux_shape_ndim, + const int64_t*** aux_shape_data, + int* complete) { + nnvm::Symbol* s = static_cast(sym); + MXAPIThreadLocalEntry* ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); SymbolInferShape(keys, - num_args, - arg_shape_data, - arg_ind_ptr, - in_shape_ndim, - in_shape_data, - out_shape_ndim, - out_shape_data, - aux_shape_ndim, - aux_shape_data, - s, - ret, - in_shape_size, - out_shape_size, - aux_shape_size, - complete); + num_args, + arg_shape_data, + arg_ind_ptr, + in_shape_ndim, + in_shape_data, + out_shape_ndim, + out_shape_data, + aux_shape_ndim, + aux_shape_data, + s, + ret, + in_shape_size, + out_shape_size, + aux_shape_size, + complete); API_END(); } @@ -783,25 +748,34 @@ int MXSymbolInferShape64(SymbolHandle sym, int MXSymbolInferShapePartial(SymbolHandle sym, uint32_t num_args, const char** keys, - const uint32_t *arg_ind_ptr, - const int *arg_shape_data, - uint32_t *in_shape_size, - const int **in_shape_ndim, - const int ***in_shape_data, - uint32_t *out_shape_size, - const int **out_shape_ndim, - const int ***out_shape_data, - uint32_t *aux_shape_size, - const int **aux_shape_ndim, - const int ***aux_shape_data, - int *complete) { - int succ = 0; + const uint32_t* arg_ind_ptr, + const int* arg_shape_data, + uint32_t* in_shape_size, + const int** in_shape_ndim, + const int*** in_shape_data, + uint32_t* out_shape_size, + const int** out_shape_ndim, + const int*** out_shape_data, + uint32_t* aux_shape_size, + const int** aux_shape_ndim, + const int*** aux_shape_data, + int* complete) { + int succ = 0; *complete = 1; - return MXSymbolInferShape(sym, num_args, keys, - arg_ind_ptr, arg_shape_data, - in_shape_size, in_shape_ndim, in_shape_data, - out_shape_size, out_shape_ndim, out_shape_data, - aux_shape_size, aux_shape_ndim, aux_shape_data, + return MXSymbolInferShape(sym, + num_args, + keys, + arg_ind_ptr, + arg_shape_data, + in_shape_size, + in_shape_ndim, + in_shape_data, + out_shape_size, + out_shape_ndim, + out_shape_data, + aux_shape_size, + aux_shape_ndim, + aux_shape_data, &succ); } @@ -829,41 +803,50 @@ int MXSymbolInferShapePartial(SymbolHandle sym, int MXSymbolInferShapePartial64(SymbolHandle sym, uint32_t num_args, const char** keys, - const int64_t *arg_ind_ptr, - const int64_t *arg_shape_data, - size_t *in_shape_size, - const int **in_shape_ndim, - const int64_t ***in_shape_data, - size_t *out_shape_size, - const int **out_shape_ndim, - const int64_t ***out_shape_data, - size_t *aux_shape_size, - const int **aux_shape_ndim, - const int64_t ***aux_shape_data, - int *complete) { - int succ = 0; + const int64_t* arg_ind_ptr, + const int64_t* arg_shape_data, + size_t* in_shape_size, + const int** 
in_shape_ndim, + const int64_t*** in_shape_data, + size_t* out_shape_size, + const int** out_shape_ndim, + const int64_t*** out_shape_data, + size_t* aux_shape_size, + const int** aux_shape_ndim, + const int64_t*** aux_shape_data, + int* complete) { + int succ = 0; *complete = 1; - return MXSymbolInferShape64(sym, num_args, keys, - arg_ind_ptr, arg_shape_data, - in_shape_size, in_shape_ndim, in_shape_data, - out_shape_size, out_shape_ndim, out_shape_data, - aux_shape_size, aux_shape_ndim, aux_shape_data, + return MXSymbolInferShape64(sym, + num_args, + keys, + arg_ind_ptr, + arg_shape_data, + in_shape_size, + in_shape_ndim, + in_shape_data, + out_shape_size, + out_shape_ndim, + out_shape_data, + aux_shape_size, + aux_shape_ndim, + aux_shape_data, &succ); } int MXSymbolInferType(SymbolHandle sym, uint32_t num_args, const char** keys, - const int *arg_type_data, - uint32_t *in_type_size, - const int **in_type_data, - uint32_t *out_type_size, - const int **out_type_data, - uint32_t *aux_type_size, - const int **aux_type_data, - int *complete) { - nnvm::Symbol *s = static_cast(sym); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + const int* arg_type_data, + uint32_t* in_type_size, + const int** in_type_data, + uint32_t* out_type_size, + const int** out_type_data, + uint32_t* aux_type_size, + const int** aux_type_data, + int* complete) { + nnvm::Symbol* s = static_cast(sym); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); API_BEGIN(); nnvm::Graph g = Symbol2Graph(*s); nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); @@ -883,38 +866,46 @@ int MXSymbolInferType(SymbolHandle sym, g = mxnet::exec::InferType(std::move(g), std::move(arg_types), "__dtype__"); // copy back - CopyAttr(g.indexed_graph(), g.GetAttr("dtype"), - &(ret->arg_types), &(ret->out_types), &(ret->aux_types)); - - *in_type_size = static_cast(ret->arg_types.size()); - *in_type_data = dmlc::BeginPtr(ret->arg_types); + CopyAttr(g.indexed_graph(), + g.GetAttr("dtype"), + &(ret->arg_types), + &(ret->out_types), + &(ret->aux_types)); + + *in_type_size = static_cast(ret->arg_types.size()); + *in_type_data = dmlc::BeginPtr(ret->arg_types); *out_type_size = static_cast(ret->out_types.size()); *out_type_data = dmlc::BeginPtr(ret->out_types); *aux_type_size = static_cast(ret->aux_types.size()); *aux_type_data = dmlc::BeginPtr(ret->aux_types); - *complete = (g.GetAttr("dtype_num_unknown_nodes") == 0); + *complete = (g.GetAttr("dtype_num_unknown_nodes") == 0); API_END(); } int MXSymbolInferTypePartial(SymbolHandle sym, uint32_t num_args, const char** keys, - const int *arg_type_data, - uint32_t *in_type_size, - const int **in_type_data, - uint32_t *out_type_size, - const int **out_type_data, - uint32_t *aux_type_size, - const int **aux_type_data, - int *complete) { - int succ = 0; + const int* arg_type_data, + uint32_t* in_type_size, + const int** in_type_data, + uint32_t* out_type_size, + const int** out_type_data, + uint32_t* aux_type_size, + const int** aux_type_data, + int* complete) { + int succ = 0; *complete = 1; - return MXSymbolInferType(sym, num_args, keys, - arg_type_data, - in_type_size, in_type_data, - out_type_size, out_type_data, - aux_type_size, aux_type_data, - &succ); + return MXSymbolInferType(sym, + num_args, + keys, + arg_type_data, + in_type_size, + in_type_data, + out_type_size, + out_type_data, + aux_type_size, + aux_type_data, + &succ); } int MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHandle* out) { @@ -924,25 +915,25 @@ int 
MXSymbolGrad(SymbolHandle sym, uint32_t num_wrt, const char** wrt, SymbolHan } int MXQuantizeSymbol(SymbolHandle sym_handle, - SymbolHandle *ret_sym_handle, + SymbolHandle* ret_sym_handle, const int* dev_type, const uint32_t num_excluded_sym_names, - const char **excluded_sym_names, + const char** excluded_sym_names, const uint32_t num_excluded_op_names, - const char **excluded_op_names, + const char** excluded_op_names, const uint32_t num_offline, - const char **offline_params, - const char *quantized_dtype, + const char** offline_params, + const char* quantized_dtype, const bool calib_quantize, - const char *quantize_mode, - const char *quantize_granularity, + const char* quantize_mode, + const char* quantize_granularity, mx_uint* out_num_calib_names, - const char ***out_calib_names) { - nnvm::Symbol *s = new nnvm::Symbol(); + const char*** out_calib_names) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *sym = static_cast(sym_handle); - nnvm::Graph g = Symbol2Graph(*sym); - int target_dev = *dev_type; + nnvm::Symbol* sym = static_cast(sym_handle); + nnvm::Graph g = Symbol2Graph(*sym); + int target_dev = *dev_type; std::unordered_set excluded_node_names; for (size_t i = 0; i < num_excluded_sym_names; ++i) { excluded_node_names.emplace(excluded_sym_names[i]); @@ -958,39 +949,38 @@ int MXQuantizeSymbol(SymbolHandle sym_handle, std::string quantized_type(quantized_dtype); std::string quantized_mode(quantize_mode); std::string quantized_granularity(quantize_granularity); - g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); - g.attrs["excluded_ops"] = std::make_shared(std::move(excluded_op)); - g.attrs["offline_params"] = std::make_shared(std::move(offline)); - g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); - g.attrs["target_ctx"] = std::make_shared(target_dev); - g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); + g.attrs["excluded_nodes"] = std::make_shared(std::move(excluded_node_names)); + g.attrs["excluded_ops"] = std::make_shared(std::move(excluded_op)); + g.attrs["offline_params"] = std::make_shared(std::move(offline)); + g.attrs["quantized_dtype"] = std::make_shared(std::move(quantized_type)); + g.attrs["target_ctx"] = std::make_shared(target_dev); + g.attrs["quantize_mode"] = std::make_shared(std::move(quantized_mode)); g.attrs["quantize_granularity"] = std::make_shared(std::move(quantized_granularity)); - g = ApplyPass(std::move(g), "QuantizeGraph"); - const auto& calib_nodes = g.GetAttr>("calib_nodes"); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); - ret->ret_vec_str = calib_nodes; - *out_num_calib_names = ret->ret_vec_str.size(); + g = ApplyPass(std::move(g), "QuantizeGraph"); + const auto& calib_nodes = g.GetAttr>("calib_nodes"); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); + ret->ret_vec_str = calib_nodes; + *out_num_calib_names = ret->ret_vec_str.size(); ret->ret_vec_charp.clear(); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); - for (const auto &str : ret->ret_vec_str) { + for (const auto& str : ret->ret_vec_str) { ret->ret_vec_charp.push_back(str.c_str()); } *out_calib_names = dmlc::BeginPtr(ret->ret_vec_charp); - s->outputs = g.outputs; - *ret_sym_handle = s; + s->outputs = g.outputs; + *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } // helper function to add mapping of node_name -> dtype map // for the given indexed graph and inferred_dtypes -static void _SetInputDTypes( - const nnvm::IndexedGraph& idx, - const 
nnvm::DTypeVector& inferred_dtypes, - std::unordered_map* node_name_dtype_map, - std::unordered_map* node_without_dtype_map) { +static void _SetInputDTypes(const nnvm::IndexedGraph& idx, + const nnvm::DTypeVector& inferred_dtypes, + std::unordered_map* node_name_dtype_map, + std::unordered_map* node_without_dtype_map) { const std::string dtype_keyword = "__dtype__"; for (uint32_t nid : idx.input_nodes()) { - const auto& node = idx[nid].source; + const auto& node = idx[nid].source; const auto& node_with_dtype = node->attrs.dict.find(dtype_keyword); // input nodes classified into nodes_with_dtype, nodes_without_dtype // This classification required because if param_names not provided @@ -1001,15 +991,13 @@ static void _SetInputDTypes( if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { (*node_name_dtype_map)[node->attrs.name] = 0; } else { - (*node_name_dtype_map)[node->attrs.name] = - inferred_dtypes[idx.entry_id(nid, 0)]; + (*node_name_dtype_map)[node->attrs.name] = inferred_dtypes[idx.entry_id(nid, 0)]; } } else { if (inferred_dtypes[idx.entry_id(nid, 0)] == -1) { (*node_without_dtype_map)[node->attrs.name] = 0; } else { - (*node_without_dtype_map)[node->attrs.name] = - inferred_dtypes[idx.entry_id(nid, 0)]; + (*node_without_dtype_map)[node->attrs.name] = inferred_dtypes[idx.entry_id(nid, 0)]; } } } @@ -1022,43 +1010,39 @@ static void _SetInputDTypes( // a prior dtype set. // args is a const_reference vector of ObjectPtrs. ObjectPtrs are immutable but // the Nodes they are pointing will be mutated in this function -static void _UpdateSymDTypeAttrs( - const std::unordered_map& node_name_dtype_map, - const std::unordered_map& node_without_dtype_map, - const std::unordered_set& model_params, - const std::vector& args) { +static void _UpdateSymDTypeAttrs(const std::unordered_map& node_name_dtype_map, + const std::unordered_map& node_without_dtype_map, + const std::unordered_set& model_params, + const std::vector& args) { const std::string dtype_keyword = "__dtype__"; // Update args to have the right dtype attrs if (model_params.size() > 0) { // if model params provided, set dtype only for model params - for (const auto & arg : args) { + for (const auto& arg : args) { const std::string& node_name = arg->attrs.name; - auto it_model_params = model_params.find(node_name); - auto it_with_dtype = node_name_dtype_map.find(node_name); - auto it_without_dtype = node_without_dtype_map.find(node_name); + auto it_model_params = model_params.find(node_name); + auto it_with_dtype = node_name_dtype_map.find(node_name); + auto it_without_dtype = node_without_dtype_map.find(node_name); if (it_model_params != model_params.end()) { // need to update __dtype__ attribute if already set, else set it if (it_with_dtype != node_name_dtype_map.end()) { - arg->attrs.dict[dtype_keyword] = - std::to_string(it_with_dtype->second); + arg->attrs.dict[dtype_keyword] = std::to_string(it_with_dtype->second); } else { CHECK(it_without_dtype != node_without_dtype_map.end()) << "make sure all nodes without dtype have properly been added " "in node_without_dtype_map"; - arg->attrs.dict[dtype_keyword] = - std::to_string(it_without_dtype->second); + arg->attrs.dict[dtype_keyword] = std::to_string(it_without_dtype->second); } } } } else { // if model params not provided, update __dtype__ for all inputs, // which already had it set, don't touch the rest - for (const auto & arg : args) { + for (const auto& arg : args) { auto it = node_name_dtype_map.find(arg->attrs.name); if (it != node_name_dtype_map.end()) { - if 
(arg->attrs.dict.find(dtype_keyword) != - arg->attrs.dict.end()) { + if (arg->attrs.dict.find(dtype_keyword) != arg->attrs.dict.end()) { arg->attrs.dict[dtype_keyword] = std::to_string(it->second); } } @@ -1067,9 +1051,9 @@ static void _UpdateSymDTypeAttrs( } int MXReducePrecisionSymbol(SymbolHandle sym_handle, - SymbolHandle *ret_sym_handle, + SymbolHandle* ret_sym_handle, uint32_t num_args, - const int *arg_type_data, + const int* arg_type_data, uint32_t num_ind_ptr, const int* ind_ptr, const int* target_dtype, @@ -1080,19 +1064,19 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle, const uint32_t num_conditional_fp32_op_names, const uint32_t num_excluded_symbols, const uint32_t num_model_params, - const char **target_dtype_op_names, - const char **fp32_op_names, - const char **widest_dtype_op_names, - const char **conditional_fp32_op_names, - const char **excluded_symbols, - const char **param_names, - const char **param_vals, - const char **model_param_names, - const char **arg_names) { - nnvm::Symbol *result_sym = new nnvm::Symbol(); + const char** target_dtype_op_names, + const char** fp32_op_names, + const char** widest_dtype_op_names, + const char** conditional_fp32_op_names, + const char** excluded_symbols, + const char** param_names, + const char** param_vals, + const char** model_param_names, + const char** arg_names) { + nnvm::Symbol* result_sym = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *sym = static_cast(sym_handle); - nnvm::Graph g = Symbol2Graph(*sym); + nnvm::Symbol* sym = static_cast(sym_handle); + nnvm::Graph g = Symbol2Graph(*sym); std::unordered_set target_dtype_ops; std::unordered_set fp32_ops; std::unordered_set widest_dtype_ops; @@ -1101,9 +1085,8 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle, // conditional_fp32_ops contains the mapping of op_name -> (map of param_name -> param_values) // which need to be conditionally selected to be casted to FP32 - std::unordered_map>> conditional_fp32_ops; + std::unordered_map>> + conditional_fp32_ops; int target_dt = *target_dtype; for (size_t i = 0; i < num_target_dtype_op_names; ++i) { @@ -1124,8 +1107,8 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle, for (size_t i = 0; i < num_ind_ptr - 1; ++i) { for (int j = ind_ptr[i]; j < ind_ptr[i + 1]; ++j) { - conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]] - .emplace_back(std::string(param_vals[j])); + conditional_fp32_ops[conditional_fp32_op_names[i]][param_names[i]].emplace_back( + std::string(param_vals[j])); } } @@ -1133,51 +1116,45 @@ int MXReducePrecisionSymbol(SymbolHandle sym_handle, std::unordered_map node_name_dtype_map, node_without_dtype_map; nnvm::DTypeVector arg_types(g.indexed_graph().input_nodes().size(), -1); for (uint32_t i = 0; i < num_args; ++i) { - kwargs[arg_names[i]] = arg_type_data[i]; + kwargs[arg_names[i]] = arg_type_data[i]; node_name_dtype_map[arg_names[i]] = arg_type_data[i]; } mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); - g.attrs["target_dtype_ops"] = - std::make_shared(std::move(target_dtype_ops)); - g.attrs["fp32_ops"] = std::make_shared(std::move(fp32_ops)); - g.attrs["widest_dtype_ops"] = - std::make_shared(std::move(widest_dtype_ops)); - g.attrs["conditional_fp32_ops"] = - std::make_shared(std::move(conditional_fp32_ops)); - g.attrs["excluded_syms"] = - std::make_shared(std::move(excluded_syms)); - g.attrs["target_dtype"] = std::make_shared(target_dt); - g.attrs["data_name_types"] = std::make_shared(kwargs); + g.attrs["target_dtype_ops"] = 
std::make_shared(std::move(target_dtype_ops)); + g.attrs["fp32_ops"] = std::make_shared(std::move(fp32_ops)); + g.attrs["widest_dtype_ops"] = std::make_shared(std::move(widest_dtype_ops)); + g.attrs["conditional_fp32_ops"] = std::make_shared(std::move(conditional_fp32_ops)); + g.attrs["excluded_syms"] = std::make_shared(std::move(excluded_syms)); + g.attrs["target_dtype"] = std::make_shared(target_dt); + g.attrs["data_name_types"] = std::make_shared(kwargs); g.attrs["cast_optional_params"] = std::make_shared(cast_optional_params); g = ApplyPass(std::move(g), "ReducePrecision"); // Need to run type inference since it is possible that inferred // type of some inputs has changed g = mxnet::exec::InferType(std::move(g), std::move(arg_types), ""); - const nnvm::DTypeVector &inferred_dtypes = - g.GetAttr("dtype"); + const nnvm::DTypeVector& inferred_dtypes = g.GetAttr("dtype"); g.attrs["inferred_dtypes"] = std::make_shared(inferred_dtypes); - g.attrs["target_dtype"] = std::make_shared(target_dt); + g.attrs["target_dtype"] = std::make_shared(target_dt); if (cast_optional_params) { g = ApplyPass(std::move(g), "AMPInferUnknown"); - const nnvm::DTypeVector &inferred_dtype_result = + const nnvm::DTypeVector& inferred_dtype_result = g.GetAttr("inferred_dtype_result"); - const nnvm::IndexedGraph &idx = g.indexed_graph(); + const nnvm::IndexedGraph& idx = g.indexed_graph(); // set node name -> input dtype mapping using infer dtype _SetInputDTypes(idx, inferred_dtype_result, &node_name_dtype_map, &node_without_dtype_map); } else { - const nnvm::IndexedGraph &idx = g.indexed_graph(); + const nnvm::IndexedGraph& idx = g.indexed_graph(); // set node name -> input dtype mapping using infer dtype _SetInputDTypes(idx, inferred_dtypes, &node_name_dtype_map, &node_without_dtype_map); } - - result_sym->outputs = g.outputs; - *ret_sym_handle = result_sym; - nnvm::Symbol *ret_sym = static_cast(*ret_sym_handle); + result_sym->outputs = g.outputs; + *ret_sym_handle = result_sym; + nnvm::Symbol* ret_sym = static_cast(*ret_sym_handle); const std::vector& args = ret_sym->ListInputs(nnvm::Symbol::kAll); // update symbol dtype attrs using the node name -> dtype mapping, if dtype is already set @@ -1196,25 +1173,26 @@ int MXSetCalibTableToQuantizedSymbol(SymbolHandle qsym_handle, nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); nnvm::Symbol* sym = static_cast(qsym_handle); - nnvm::Graph g = Symbol2Graph(*sym); + nnvm::Graph g = Symbol2Graph(*sym); std::unordered_map> calib_table; for (size_t i = 0; i < num_layers; ++i) { calib_table.emplace(layer_names[i], std::make_pair(min_ranges[i], max_ranges[i])); } g.attrs["calib_table"] = std::make_shared(std::move(calib_table)); - g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); - s->outputs = g.outputs; - *ret_qsym_handle = s; + g = ApplyPass(std::move(g), "SetCalibTableToQuantizedGraph"); + s->outputs = g.outputs; + *ret_qsym_handle = s; API_END_HANDLE_ERROR(delete s); } -int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name, - SymbolHandle *ret_sym_handle) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXGenBackendSubgraph(SymbolHandle sym_handle, + const char* backend_name, + SymbolHandle* ret_sym_handle) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *sym = static_cast(sym_handle); - *s = sym->Copy(); - auto backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name); + nnvm::Symbol* sym = static_cast(sym_handle); + *s = sym->Copy(); + auto backend = 
mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { if (property->HasAttr("disable") && property->GetAttr("disable") == true) { @@ -1228,7 +1206,7 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name, nnvm::Graph g = Symbol2Graph(*s); property->SetAttr("graph", g); g.attrs["subgraph_property"] = std::make_shared(property); - g = ApplyPass(std::move(g), "BuildSubgraph"); + g = ApplyPass(std::move(g), "BuildSubgraph"); property->RemoveAttr("graph"); g.attrs.erase("subgraph_property"); s->outputs = g.outputs; @@ -1237,22 +1215,21 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend_name, API_END_HANDLE_ERROR(delete s); } -int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle *ret_sym_handle) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle* ret_sym_handle) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *source = static_cast(sym_handle); + nnvm::Symbol* source = static_cast(sym_handle); CHECK_GE(source->outputs.size(), 1) << "Input symbol does not have outputs."; - const auto &node = source->outputs[0].node; - for (const auto &other_node : source->outputs) { + const auto& node = source->outputs[0].node; + for (const auto& other_node : source->outputs) { if (node.get() != other_node.node.get()) { - LOG(FATAL) - << "Generating atomic symbol from other symbol only works for nongrouped symbol."; + LOG(FATAL) << "Generating atomic symbol from other symbol only works for nongrouped symbol."; } } - const auto *op = node->op(); + const auto* op = node->op(); const auto attrs = source->ListAttrs(nnvm::Symbol::ListAttrOption::kShallow); - *s = nnvm::Symbol::CreateFunctor(op, attrs); - *ret_sym_handle = s; + *s = nnvm::Symbol::CreateFunctor(op, attrs); + *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } @@ -1260,8 +1237,8 @@ int MXShallowCopySymbol(SymbolHandle src, SymbolHandle* out) { nnvm::Symbol* out_sym = new nnvm::Symbol; API_BEGIN(); nnvm::Symbol* src_sym = static_cast(src); - *out_sym = *src_sym; - *out = out_sym; + *out_sym = *src_sym; + *out = out_sym; API_END_HANDLE_ERROR(delete out_sym); } @@ -1294,27 +1271,27 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, NDArrayHandle** new_aux_handle, char*** new_aux_names_handle) { // create copy of input symbol - nnvm::Symbol *s = new nnvm::Symbol(); + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); - nnvm::Symbol *sym = static_cast(sym_handle); - *s = sym->Copy(); + nnvm::Symbol* sym = static_cast(sym_handle); + *s = sym->Copy(); // create a data structure from pointer array std::unordered_map options_map; for (mx_uint i = 0; i < num_options; ++i) options_map.emplace(keys[i], vals[i]); - NDArray ***new_args_ptr = reinterpret_cast(new_args_handle); - NDArray ***new_aux_ptr = reinterpret_cast(new_aux_handle); - NDArray **in_args_ptr = reinterpret_cast(in_args_handle); - NDArray **in_aux_ptr = reinterpret_cast(in_aux_handle); + NDArray*** new_args_ptr = reinterpret_cast(new_args_handle); + NDArray*** new_aux_ptr = reinterpret_cast(new_aux_handle); + NDArray** in_args_ptr = reinterpret_cast(in_args_handle); + NDArray** in_aux_ptr = reinterpret_cast(in_aux_handle); auto init_graph = [&](auto s) { - nnvm::Graph g = Symbol2Graph(*s); - const auto& indexed_graph = g.indexed_graph(); - const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); + nnvm::Graph g = 
Symbol2Graph(*s); + const auto& indexed_graph = g.indexed_graph(); + const auto& mutable_nodes = indexed_graph.mutable_input_nodes(); std::vector input_names = s->ListInputNames(nnvm::Symbol::kAll); - size_t num_forward_inputs = input_names.size(); + size_t num_forward_inputs = input_names.size(); if (args_len || aux_len) { if (!skip_infer) { @@ -1327,8 +1304,8 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, std::unordered_map input_shape_map(num_input_shapes); for (uint32_t i = 0; i < num_input_shapes; ++i) { input_shape_map.emplace(input_shape_names[i], - mxnet::TShape(input_shape_data + input_shape_idx[i], - input_shape_data + input_shape_idx[i+1])); + mxnet::TShape(input_shape_data + input_shape_idx[i], + input_shape_data + input_shape_idx[i + 1])); } std::unordered_map input_dtype_map(num_input_dtypes); for (uint32_t i = 0; i < num_input_dtypes; ++i) { @@ -1345,23 +1322,23 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, const uint32_t nid = indexed_graph.input_nodes().at(i); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_len) - << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; + << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for"; if (in_aux_ptr[aux_top] != nullptr) { - const auto &in_arg = *(in_aux_ptr[aux_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); + const auto& in_arg = *(in_aux_ptr[aux_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); } aux_top++; } else { auto name = input_names[i]; CHECK_LT(args_top, args_len) - << "Cannot find arg '" << name << "' in provided args to optimize_for"; + << "Cannot find arg '" << name << "' in provided args to optimize_for"; if (in_args_ptr[args_top] != nullptr) { - const auto &in_arg = *(in_args_ptr[args_top]); - arg_shapes[i] = in_arg.shape(); - arg_dtypes[i] = in_arg.dtype(); - arg_stypes[i] = in_arg.storage_type(); + const auto& in_arg = *(in_args_ptr[args_top]); + arg_shapes[i] = in_arg.shape(); + arg_dtypes[i] = in_arg.dtype(); + arg_stypes[i] = in_arg.storage_type(); } else { // input_names[i] is not in args but can be in the optional // shape/type/stype attribute dicts. 
@@ -1383,7 +1360,7 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, } g.attrs["context"] = std::make_shared( - exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); + exec::ContextVector(indexed_graph.num_nodes(), default_ctx)); // infer shapes g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__"); @@ -1394,22 +1371,22 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, } // set args/aux as attributes on graph so that subgraph property can use them std::vector arg_names = s->ListInputNames(nnvm::Symbol::kReadOnlyArgs); - g.attrs["in_args"] = std::make_shared(in_args_ptr); - g.attrs["in_arg_names"] = std::make_shared(arg_names); + g.attrs["in_args"] = std::make_shared(in_args_ptr); + g.attrs["in_arg_names"] = std::make_shared(arg_names); std::vector aux_names = s->ListInputNames(nnvm::Symbol::kAuxiliaryStates); - g.attrs["in_aux"] = std::make_shared(in_aux_ptr); - g.attrs["in_aux_names"] = std::make_shared(aux_names); + g.attrs["in_aux"] = std::make_shared(in_aux_ptr); + g.attrs["in_aux_names"] = std::make_shared(aux_names); } else { // args/aux were not specified, so set nullptr/empty-lists - NDArray **in_args_ptr = static_cast(nullptr); + NDArray** in_args_ptr = static_cast(nullptr); std::vector arg_names; - g.attrs["in_args"] = std::make_shared(in_args_ptr); + g.attrs["in_args"] = std::make_shared(in_args_ptr); g.attrs["in_arg_names"] = std::make_shared(arg_names); - NDArray **in_aux_ptr = static_cast(nullptr); + NDArray** in_aux_ptr = static_cast(nullptr); std::vector aux_names; - g.attrs["in_aux"] = std::make_shared(in_aux_ptr); + g.attrs["in_aux"] = std::make_shared(in_aux_ptr); g.attrs["in_aux_names"] = std::make_shared(aux_names); } @@ -1422,27 +1399,27 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, if (mxnet::op::SubgraphBackendRegistry::Get()->backend_map_.count(backend_name) > 0) { // use subgraph backend - const auto backend = mxnet::op::SubgraphBackendRegistry - ::Get()->GetSubgraphBackend(backend_name); + const auto backend = + mxnet::op::SubgraphBackendRegistry ::Get()->GetSubgraphBackend(backend_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { nnvm::Graph g = init_graph(s); property->PrePartition(g, options_map); g.attrs["subgraph_property"] = std::make_shared(property); - g = ApplyPass(std::move(g), "BuildSubgraph"); + g = ApplyPass(std::move(g), "BuildSubgraph"); g.attrs.erase("subgraph_property"); property->PostPartition(g); s->outputs = g.outputs; } } else if (dmlc::Registry::Find(backend_name) != nullptr) { // use graph pass - nnvm::Graph g = init_graph(s); + nnvm::Graph g = init_graph(s); g.attrs["options_map"] = std::make_shared(options_map); - g.attrs["pass_name"] = std::make_shared(backend_name); - g = ApplyPass(std::move(g), backend_name); + g.attrs["pass_name"] = std::make_shared(backend_name); + g = ApplyPass(std::move(g), backend_name); - std::vector new_args = g.GetAttr>("new_args"); - std::vector new_aux = g.GetAttr>("new_aux"); + std::vector new_args = g.GetAttr>("new_args"); + std::vector new_aux = g.GetAttr>("new_aux"); std::vector new_arg_names = g.GetAttr>("new_arg_names"); std::vector new_aux_names = g.GetAttr>("new_aux_names"); g.attrs.erase("new_args"); @@ -1453,12 +1430,12 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, NDArray** new_arg_arr = new NDArray*[new_arg_names.size()]; NDArray** new_aux_arr = new NDArray*[new_aux_names.size()]; - char** new_arg_cstr = new char*[new_arg_names.size()]; - char** new_aux_cstr = new 
char*[new_aux_names.size()]; + char** new_arg_cstr = new char*[new_arg_names.size()]; + char** new_aux_cstr = new char*[new_aux_names.size()]; for (unsigned i = 0; i < new_arg_names.size(); i++) { new_arg_arr[i] = new_args[i]; std::string& s = new_arg_names[i]; - char* tmp = new char[s.length()+1]; + char* tmp = new char[s.length() + 1]; s.copy(tmp, s.length()); tmp[s.length()] = '\0'; new_arg_cstr[i] = tmp; @@ -1466,17 +1443,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, for (unsigned i = 0; i < new_aux_names.size(); i++) { new_aux_arr[i] = new_aux[i]; std::string& s = new_aux_names[i]; - char* tmp = new char[s.length()+1]; + char* tmp = new char[s.length() + 1]; s.copy(tmp, s.length()); tmp[s.length()] = '\0'; new_aux_cstr[i] = tmp; } - *new_args_cnt = new_arg_names.size(); - *new_aux_cnt = new_aux_names.size(); + *new_args_cnt = new_arg_names.size(); + *new_aux_cnt = new_aux_names.size(); *new_arg_names_handle = new_arg_cstr; *new_aux_names_handle = new_aux_cstr; - *new_args_ptr = new_arg_arr; - *new_aux_ptr = new_aux_arr; + *new_args_ptr = new_arg_arr; + *new_aux_ptr = new_aux_arr; } else { // cannot find graph pass or subgraph backend registered in this name LOG(ERROR) << "Error optimizing for backend '" << backend_name << "' cannot be found"; @@ -1486,18 +1463,18 @@ int MXOptimizeForBackend(SymbolHandle sym_handle, API_END_HANDLE_ERROR(delete s); } -int MXCheckDynamicShapeOp(SymbolHandle sym_handle, - bool* has_dynamic_shape) { - nnvm::Symbol *s = new nnvm::Symbol(); +int MXCheckDynamicShapeOp(SymbolHandle sym_handle, bool* has_dynamic_shape) { + nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); *has_dynamic_shape = false; // traverse the symbol and check if any dynamic shape is present - nnvm::Symbol *sym = static_cast(sym_handle); - *s = sym->Copy(); - nnvm::Graph g = Symbol2Graph(*s); + nnvm::Symbol* sym = static_cast(sym_handle); + *s = sym->Copy(); + nnvm::Graph g = Symbol2Graph(*s); const auto& infershape = nnvm::Op::GetAttr("FInferShape"); DFSVisit(g.outputs, [infershape, has_dynamic_shape](const nnvm::ObjectPtr n) { - if (*has_dynamic_shape) return; + if (*has_dynamic_shape) + return; if (!n->is_variable() && !infershape.count(n->op())) { *has_dynamic_shape = true; return; diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc index ac691234fcf0..9295c13b84a0 100644 --- a/src/c_api/c_api_test.cc +++ b/src/c_api/c_api_test.cc @@ -29,10 +29,10 @@ #include "../common/cuda/rtc.h" int MXBuildSubgraphByOpNames(SymbolHandle sym_handle, - const char* prop_name, - const uint32_t num_ops, - const char** op_names, - SymbolHandle* ret_sym_handle) { + const char* prop_name, + const uint32_t num_ops, + const char** op_names, + SymbolHandle* ret_sym_handle) { nnvm::Symbol* s = new nnvm::Symbol(); API_BEGIN(); std::unordered_set op_name_set; @@ -40,10 +40,9 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle, op_name_set.emplace(op_names[i]); } nnvm::Symbol* sym = static_cast(sym_handle); - *s = sym->Copy(); + *s = sym->Copy(); if (!op_name_set.empty()) { - auto& backend = - mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); + auto& backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); LOG(INFO) << "Subgraph backend " << backend->GetName() << " is activated."; const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto property : subgraph_prop_list) { @@ -52,7 +51,7 @@ int MXBuildSubgraphByOpNames(SymbolHandle sym_handle, property->SetAttr("graph", g); property->SetAttr("op_names", op_name_set); 
g.attrs["subgraph_property"] = std::make_shared(property); - g = nnvm::ApplyPass(std::move(g), "BuildSubgraph"); + g = nnvm::ApplyPass(std::move(g), "BuildSubgraph"); property->RemoveAttr("graph"); g.attrs.erase("subgraph_property"); s->outputs = g.outputs; @@ -75,15 +74,14 @@ int MXSetSubgraphPropertyOpNames(const char* prop_name, } int MXSetSubgraphPropertyOpNamesV2(const char* prop_name, - const uint32_t num_ops, - const char** op_names) { + const uint32_t num_ops, + const char** op_names) { API_BEGIN(); std::unordered_set op_name_set; for (size_t i = 0; i < num_ops; ++i) { op_name_set.emplace(op_names[i]); } - auto& backend = - mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); + auto& backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto& property : subgraph_prop_list) { property->SetAttr("op_names", op_name_set); @@ -99,8 +97,7 @@ int MXRemoveSubgraphPropertyOpNames(const char* prop_name) { int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) { API_BEGIN(); - auto& backend = - mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); + auto& backend = mxnet::op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(prop_name); const auto& subgraph_prop_list = backend->GetSubgraphProperties(); for (auto& property : subgraph_prop_list) { property->RemoveAttr("op_names"); @@ -108,15 +105,13 @@ int MXRemoveSubgraphPropertyOpNamesV2(const char* prop_name) { API_END(); } -int MXGetEnv(const char* name, - const char** value) { +int MXGetEnv(const char* name, const char** value) { API_BEGIN(); *value = getenv(name); API_END(); } -int MXSetEnv(const char* name, - const char* value) { +int MXSetEnv(const char* name, const char* value) { API_BEGIN(); #ifdef _WIN32 auto value_arg = (value == nullptr) ? 
"" : value; @@ -130,7 +125,7 @@ int MXSetEnv(const char* name, API_END(); } -int MXGetMaxSupportedArch(uint32_t *max_arch) { +int MXGetMaxSupportedArch(uint32_t* max_arch) { API_BEGIN(); #if MXNET_USE_CUDA *max_arch = static_cast(mxnet::common::cuda::rtc::GetMaxSupportedArch()); diff --git a/src/common/cuda/rtc.cc b/src/common/cuda/rtc.cc index 2294feaa9e2f..45bccea718fc 100644 --- a/src/common/cuda/rtc.cc +++ b/src/common/cuda/rtc.cc @@ -42,18 +42,26 @@ #include "rtc/reducer-inl.h" #include "utils.h" -typedef CUresult (*cuDeviceGetPtr) (CUdevice* device, int ordinal); -typedef CUresult (*cuDevicePrimaryCtxRetainPtr) (CUcontext* pctx, CUdevice dev); -typedef CUresult (*cuModuleLoadDataExPtr) (CUmodule* module, const void* image, - unsigned int numOptions, CUjit_option* options, void** optionValues); -typedef CUresult (*cuModuleGetFunctionPtr) (CUfunction* hfunc, CUmodule hmod, - const char* name); -typedef CUresult (*cuLaunchKernelPtr) (CUfunction f, unsigned int gridDimX, - unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, - void** extra); -typedef CUresult (*cuGetErrorStringPtr) (CUresult error, const char** pStr); +typedef CUresult (*cuDeviceGetPtr)(CUdevice* device, int ordinal); +typedef CUresult (*cuDevicePrimaryCtxRetainPtr)(CUcontext* pctx, CUdevice dev); +typedef CUresult (*cuModuleLoadDataExPtr)(CUmodule* module, + const void* image, + unsigned int numOptions, + CUjit_option* options, + void** optionValues); +typedef CUresult (*cuModuleGetFunctionPtr)(CUfunction* hfunc, CUmodule hmod, const char* name); +typedef CUresult (*cuLaunchKernelPtr)(CUfunction f, + unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + CUstream hStream, + void** kernelParams, + void** extra); +typedef CUresult (*cuGetErrorStringPtr)(CUresult error, const char** pStr); namespace mxnet { namespace common { @@ -61,9 +69,9 @@ namespace cuda { namespace rtc { #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - const char cuda_lib_name[] = "nvcuda.dll"; +const char cuda_lib_name[] = "nvcuda.dll"; #else - const char cuda_lib_name[] = "libcuda.so.1"; +const char cuda_lib_name[] = "libcuda.so.1"; #endif std::mutex lock; @@ -131,8 +139,8 @@ std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) { const auto getSize = use_cubin ? nvrtcGetCUBINSize : nvrtcGetPTXSize; const auto getFunc = use_cubin ? 
nvrtcGetCUBIN : nvrtcGetPTX;
 #else
-  const auto getSize = nvrtcGetPTXSize;
-  const auto getFunc = nvrtcGetPTX;
+  const auto getSize = nvrtcGetPTXSize;
+  const auto getFunc = nvrtcGetPTX;
 #endif
   size_t ptx_size_including_null;
   NVRTC_CALL(getSize(program, &ptx_size_including_null));
@@ -145,8 +153,7 @@ std::string GetCompiledCode(nvrtcProgram program, bool use_cubin) {
 std::tuple GetArchString(const int sm_arch) {
   const int sm_arch_as_used = std::min(sm_arch, GetMaxSupportedArch());
   // Always use PTX for CUDA <= 11.0
-  const bool known_arch = (CUDA_VERSION > 11000) &&
-                          (sm_arch == sm_arch_as_used);
+  const bool known_arch = (CUDA_VERSION > 11000) && (sm_arch == sm_arch_as_used);
   if (known_arch) {
     return {known_arch, "sm_" + std::to_string(sm_arch_as_used)};
   } else {
@@ -156,9 +163,9 @@ std::tuple GetArchString(const int sm_arch) {
 }  // namespace
-CUfunction get_function(const std::string &parameters,
-                        const std::string &kernel_name,
-                        const std::string &code,
+CUfunction get_function(const std::string& parameters,
+                        const std::string& kernel_name,
+                        const std::string& code,
                         int dev_id) {
   constexpr int CACHESIZE_WARN_THRESHOLD = 10000;
   std::lock_guard l(lock);
@@ -182,19 +189,11 @@ CUfunction get_function(const std::string &parameters,
   if (kinfo.ptx.size() == 0) {
     // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name.
     static std::string common_header =
-        std::string(fp16_support_string) + "\n" +
-        type_support_string + "\n" +
-        util_string + "\n" +
-        limits + "\n" +
-        special_functions_definitions + '\n' +
-        vectorization_support_string + "\n" +
-        function_definitions_util + "\n" +
-        function_definitions_binary + "\n" +
-        function_definitions_unary + "\n" +
-        backward_function_definitions + "\n" +
-        grad_function_definitions + "\n" +
-        reducer + "\n" +
-        logic_reducer + "\n";
+        std::string(fp16_support_string) + "\n" + type_support_string + "\n" + util_string + "\n" +
+        limits + "\n" + special_functions_definitions + '\n' + vectorization_support_string + "\n" +
+        function_definitions_util + "\n" + function_definitions_binary + "\n" +
+        function_definitions_unary + "\n" + backward_function_definitions + "\n" +
+        grad_function_definitions + "\n" + reducer + "\n" + logic_reducer + "\n";
     std::string code_with_header = common_header + parameters + code;
     // If verbose mode, output kernel source, though not including the common header
     if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
@@ -207,25 +206,27 @@ CUfunction get_function(const std::string &parameters,
        << ". Set MXNET_RTC_SIZE_WARNING=0 to quiet this warning.";
   }
   nvrtcProgram program;
-  NVRTC_CALL(nvrtcCreateProgram(&program,  // prog
-                                &code_with_header[0],  // buffer
-                                (kernel_name + "_kernel.cu").c_str(),  // name
-                                0,  // num headers
-                                nullptr,  // headers
-                                nullptr));  // include names
-  const auto [use_cubin, gpu_arch] = GetArchString(sm_arch);  // NOLINT(*)
-  std::string gpu_arch_arg = "--gpu-architecture=" + gpu_arch;
-  const char *opts[] = {gpu_arch_arg.c_str(),
+  NVRTC_CALL(nvrtcCreateProgram(&program,  // prog
+                                &code_with_header[0],  // buffer
+                                (kernel_name + "_kernel.cu").c_str(),  // name
+                                0,  // num headers
+                                nullptr,  // headers
+                                nullptr));  // include names
+  const auto [use_cubin, gpu_arch] = GetArchString(sm_arch);  // NOLINT(*)
+  std::string gpu_arch_arg = "--gpu-architecture=" + gpu_arch;
+  const char* opts[] = {
+      gpu_arch_arg.c_str(),
 #if NDEBUG == 0
-      "-G",
+      "-G",
 #endif
-      "--std=c++14"};
+      "--std=c++14"
+  };
   const std::string& kernel_name_demangled = kernel_name;
   NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
-  nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
+  nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
                                                   sizeof(opts) / sizeof(opts[0]),  // num options
-                                                  opts);  // options
+                                                  opts);  // options
   static const std::string dump_file = "mxnet_rtc_debug_code.log";
   if (compileResult != NVRTC_SUCCESS) {
     std::ofstream f(dump_file);
@@ -238,10 +239,8 @@ CUfunction get_function(const std::string &parameters,
        << GetCompileLog(program);
     kinfo.ptx = GetCompiledCode(program, use_cubin);
-    const char *mangled_name;
-    NVRTC_CALL(nvrtcGetLoweredName(program,
-                                   kernel_name_demangled.c_str(),
-                                   &mangled_name));
+    const char* mangled_name;
+    NVRTC_CALL(nvrtcGetLoweredName(program, kernel_name_demangled.c_str(), &mangled_name));
     kinfo.mangled_name = mangled_name;
     // Destroy the program.
NVRTC_CALL(nvrtcDestroyProgram(&program)); @@ -257,7 +256,7 @@ CUfunction get_function(const std::string ¶meters, cuDeviceGetPtr device_get_ptr = get_func(cuda_lib_handle, "cuDeviceGet"); CUDA_DRIVER_CALL((*device_get_ptr)(&cu_device, dev_id)); cuDevicePrimaryCtxRetainPtr device_primary_ctx_retain_ptr = - get_func(cuda_lib_handle, "cuDevicePrimaryCtxRetain"); + get_func(cuda_lib_handle, "cuDevicePrimaryCtxRetain"); CUDA_DRIVER_CALL((*device_primary_ctx_retain_ptr)(&context, cu_device)); // Jit-compile ptx for the driver's current context @@ -265,25 +264,24 @@ CUfunction get_function(const std::string ¶meters, #if NDEBUG == 0 intptr_t debug_info = 1; - intptr_t line_info = 1; + intptr_t line_info = 1; #else intptr_t debug_info = 0; - intptr_t line_info = 0; + intptr_t line_info = 0; #endif CUjit_option jit_opts[] = {CU_JIT_GENERATE_DEBUG_INFO, CU_JIT_GENERATE_LINE_INFO}; - void* jit_opt_values[] = {reinterpret_cast(debug_info), + void* jit_opt_values[] = {reinterpret_cast(debug_info), reinterpret_cast(line_info)}; cuModuleLoadDataExPtr module_load_data_ex_ptr = - get_func(cuda_lib_handle, "cuModuleLoadDataEx"); - CUDA_DRIVER_CALL((*module_load_data_ex_ptr)(&module, kinfo.ptx.c_str(), 2, - jit_opts, jit_opt_values)); + get_func(cuda_lib_handle, "cuModuleLoadDataEx"); + CUDA_DRIVER_CALL( + (*module_load_data_ex_ptr)(&module, kinfo.ptx.c_str(), 2, jit_opts, jit_opt_values)); cuModuleGetFunctionPtr module_get_function_ptr = - get_func(cuda_lib_handle, "cuModuleGetFunction"); - CUDA_DRIVER_CALL((*module_get_function_ptr)(&kinfo.functions[dev_id], - module, - kinfo.mangled_name.c_str())); + get_func(cuda_lib_handle, "cuModuleGetFunction"); + CUDA_DRIVER_CALL( + (*module_get_function_ptr)(&kinfo.functions[dev_id], module, kinfo.mangled_name.c_str())); } return kinfo.functions[dev_id]; } @@ -292,32 +290,33 @@ void launch(CUfunction function, const dim3 grid_dim, const dim3 block_dim, unsigned int shared_mem_bytes, - mshadow::Stream *stream, - std::vector *args) { - CHECK(args->size() != 0) << - "Empty argument list passed to a kernel."; + mshadow::Stream* stream, + std::vector* args) { + CHECK(args->size() != 0) << "Empty argument list passed to a kernel."; void* cuda_lib_handle = LibraryInitializer::Get()->lib_load(cuda_lib_name); cuLaunchKernelPtr launch_kernel_ptr = - get_func(cuda_lib_handle, "cuLaunchKernel"); + get_func(cuda_lib_handle, "cuLaunchKernel"); CUresult err = (*launch_kernel_ptr)(function, // function to launch - grid_dim.x, grid_dim.y, grid_dim.z, // grid dim - block_dim.x, block_dim.y, block_dim.z, // block dim - shared_mem_bytes, // shared memory - mshadow::Stream::GetStream(stream), // stream - const_cast(args->data()), // arguments - nullptr); // ); + grid_dim.x, + grid_dim.y, + grid_dim.z, // grid dim + block_dim.x, + block_dim.y, + block_dim.z, // block dim + shared_mem_bytes, // shared memory + mshadow::Stream::GetStream(stream), // stream + const_cast(args->data()), // arguments + nullptr); // ); if (err != CUDA_SUCCESS) { const char* error_string; cuGetErrorStringPtr get_error_string_ptr = - get_func(cuda_lib_handle, "cuGetErrorString"); + get_func(cuda_lib_handle, "cuGetErrorString"); (*get_error_string_ptr)(err, &error_string); - LOG(FATAL) << "cuLaunchKernel failed: " - << err << " " << error_string << ": " + LOG(FATAL) << "cuLaunchKernel failed: " << err << " " << error_string << ": " << reinterpret_cast(function) << " " << "(" << grid_dim.x << ", " << grid_dim.y << ", " << grid_dim.z << ") " << "(" << block_dim.x << ", " << block_dim.y << ", " << block_dim.z << ") " 
- << shared_mem_bytes << " " - << args->size(); + << shared_mem_bytes << " " << args->size(); } } diff --git a/src/common/cuda/rtc.h b/src/common/cuda/rtc.h index 8c36aa161927..449fb9f0a3ad 100644 --- a/src/common/cuda/rtc.h +++ b/src/common/cuda/rtc.h @@ -64,9 +64,9 @@ extern std::mutex lock; * \param code used for compilation of the kernel if not found in cache * \param dev_id id of the device which the kernel will be launched on */ -CUfunction get_function(const std::string ¶meters, - const std::string &kernel_name, - const std::string &code, +CUfunction get_function(const std::string& parameters, + const std::string& kernel_name, + const std::string& code, int dev_id); /*! \brief Launch a GPU kernel. @@ -81,8 +81,8 @@ void launch(CUfunction function, const dim3 grid_dim, const dim3 block_dim, unsigned int shared_mem_bytes, - mshadow::Stream *stream, - std::vector *args); + mshadow::Stream* stream, + std::vector* args); } // namespace rtc } // namespace cuda diff --git a/src/common/cuda/rtc/reducer-inl.h b/src/common/cuda/rtc/reducer-inl.h index f5b70d832594..b25cc55e5c7e 100644 --- a/src/common/cuda/rtc/reducer-inl.h +++ b/src/common/cuda/rtc/reducer-inl.h @@ -27,7 +27,6 @@ namespace common { namespace cuda { namespace rtc { - const char reducer[] = R"code( namespace red { @@ -617,4 +616,3 @@ struct argmin { #endif // MXNET_USE_CUDA #endif // MXNET_COMMON_CUDA_RTC_REDUCER_INL_H_ - diff --git a/src/common/cuda/rtc/util-inl.h b/src/common/cuda/rtc/util-inl.h index bafa8cf3f7e5..f294aa0ef2eb 100644 --- a/src/common/cuda/rtc/util-inl.h +++ b/src/common/cuda/rtc/util-inl.h @@ -48,11 +48,11 @@ static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B"); )code" #if MSHADOW_INT64_TENSOR_SIZE == 1 -"typedef int64 index_t;\n" + "typedef int64 index_t;\n" #else -"typedef int32 index_t;\n" + "typedef int32 index_t;\n" #endif -R"code( + R"code( // bool and int8 need to be accumulated in index_t // but bool needs to be treated in the special way // for ops like bitwise_not diff --git a/src/common/cuda/rtc/vectorization-inl.h b/src/common/cuda/rtc/vectorization-inl.h index 96205fceab3e..f5feab5fe2eb 100644 --- a/src/common/cuda/rtc/vectorization-inl.h +++ b/src/common/cuda/rtc/vectorization-inl.h @@ -265,20 +265,22 @@ class VectorizedStorer : public VectorizedAccessor { namespace { -inline index_t get_num_aligned_elements(const void *ptr, const index_t lead_dim, - const int nvec, const int size) { +inline index_t get_num_aligned_elements(const void* ptr, + const index_t lead_dim, + const int nvec, + const int size) { size_t ptr_as_number = reinterpret_cast(ptr); - int alignment = (ptr_as_number % (nvec * size)) / size; + int alignment = (ptr_as_number % (nvec * size)) / size; return (lead_dim + alignment + nvec - 1) / nvec; } enum class Alignment { - SAME_ALIGNED, // All tensors aligned + SAME_ALIGNED, // All tensors aligned SAME_UNALIGNED, // All tensors have the same misalignment - DIFFERENT // Tensors have different alignment + DIFFERENT // Tensors have different alignment }; -inline int CalcAlignment(const void *ptr, const int size) { +inline int CalcAlignment(const void* ptr, const int size) { size_t ptr_as_number = reinterpret_cast(ptr); return ptr_as_number % size; } @@ -292,18 +294,19 @@ inline int CalcAlignment(const void *ptr, const int size) { \param outputs Outputs of the operator. 
*/ template -Alignment CheckAlignment(const Params& params, const index_t lead_dim, - const index_t other_dim, const int nvec, - const std::vector &inputs, - const std::vector &outputs) { +Alignment CheckAlignment(const Params& params, + const index_t lead_dim, + const index_t other_dim, + const int nvec, + const std::vector& inputs, + const std::vector& outputs) { using namespace common; int align = -1; size_t i = 0; - for (const void *ptr : params.inputs) { + for (const void* ptr : params.inputs) { if (ptr != nullptr) { - int new_align = CalcAlignment(ptr, - mshadow_type_info(inputs[i].type_flag_).size * nvec); + int new_align = CalcAlignment(ptr, mshadow_type_info(inputs[i].type_flag_).size * nvec); if (align == -1) { align = new_align; } else { @@ -316,10 +319,9 @@ Alignment CheckAlignment(const Params& params, const index_t lead_dim, } i = 0; - for (const void *ptr : params.outputs) { + for (const void* ptr : params.outputs) { if (ptr != nullptr) { - int new_align = CalcAlignment(ptr, - mshadow_type_info(outputs[i].type_flag_).size * nvec); + int new_align = CalcAlignment(ptr, mshadow_type_info(outputs[i].type_flag_).size * nvec); if (align == -1) { align = new_align; } else { @@ -331,13 +333,11 @@ Alignment CheckAlignment(const Params& params, const index_t lead_dim, ++i; } - if ((other_dim != 1) && - (lead_dim % nvec != 0)) { + if ((other_dim != 1) && (lead_dim % nvec != 0)) { return Alignment::DIFFERENT; } - if ((align == 0) && - (lead_dim % nvec == 0)) { + if ((align == 0) && (lead_dim % nvec == 0)) { return Alignment::SAME_ALIGNED; } else { return Alignment::SAME_UNALIGNED; @@ -366,24 +366,23 @@ constexpr int vectorized_kernel_thread_num = 512; * Default is 0. */ template -void VectorizedKernelRTCLauncher(const std::string ¶meters, - const std::string &kernel_name, - const std::string &code, +void VectorizedKernelRTCLauncher(const std::string& parameters, + const std::string& kernel_name, + const std::string& code, int nvec, const index_t lead_dim, const index_t other_dim, - mshadow::Stream *s, + mshadow::Stream* s, const Params params, - const std::vector &inputs, - const std::vector &outputs, + const std::vector& inputs, + const std::vector& outputs, const int dev_id, const int lead_input_num = 0, - const index_t blocks = 0) { + const index_t blocks = 0) { const index_t N = lead_dim * other_dim; - nvec = std::min(nvec, 4); // Use at most 4-wide vectors + nvec = std::min(nvec, 4); // Use at most 4-wide vectors if (N != 0) { - auto align = CheckAlignment(params, lead_dim, other_dim, - nvec, inputs, outputs); + auto align = CheckAlignment(params, lead_dim, other_dim, nvec, inputs, outputs); std::string kernel_builder; kernel_builder.reserve(2560); @@ -413,21 +412,24 @@ void VectorizedKernelRTCLauncher(const std::string ¶meters, switch (align) { case Alignment::SAME_ALIGNED: - kernel_builder += "const bool aligned = true;\n" - "const int nvec = "; + kernel_builder += + "const bool aligned = true;\n" + "const int nvec = "; kernel_builder += std::to_string(nvec); kernel_builder += ";\n"; break; case Alignment::SAME_UNALIGNED: - kernel_builder += "const bool aligned = false;\n" - "const int nvec = "; + kernel_builder += + "const bool aligned = false;\n" + "const int nvec = "; kernel_builder += std::to_string(nvec); kernel_builder += ";\n"; break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize - kernel_builder += "const bool aligned = true;\n" - "const int nvec = 1;\n"; + kernel_builder += + "const bool aligned = true;\n" + "const int nvec = 
1;\n"; nvec = 1; break; } @@ -435,36 +437,33 @@ void VectorizedKernelRTCLauncher(const std::string ¶meters, kernel_builder += parameters; - index_t num_aligned_elements = get_num_aligned_elements( - params.inputs[lead_input_num], - lead_dim, nvec, - common::mshadow_type_info( - inputs[lead_input_num].type_flag_).size); + index_t num_aligned_elements = + get_num_aligned_elements(params.inputs[lead_input_num], + lead_dim, + nvec, + common::mshadow_type_info(inputs[lead_input_num].type_flag_).size); constexpr int threads = vectorized_kernel_thread_num; index_t num_blocks; if (blocks != 0) { num_blocks = blocks; } else { - size_t num_elements = other_dim * num_aligned_elements; - num_blocks = (num_elements + threads - 1) / threads; + size_t num_elements = other_dim * num_aligned_elements; + num_blocks = (num_elements + threads - 1) / threads; constexpr int max_blocks = 65535; - num_blocks = std::min(static_cast(num_blocks), max_blocks); + num_blocks = std::min(static_cast(num_blocks), max_blocks); } - std::vector args = {¶ms, &lead_dim, &other_dim, - &N, &num_aligned_elements}; - auto function = common::cuda::rtc::get_function(kernel_builder, - kernel_name, - code, - dev_id); + std::vector args = {¶ms, &lead_dim, &other_dim, &N, &num_aligned_elements}; + auto function = common::cuda::rtc::get_function(kernel_builder, kernel_name, code, dev_id); common::cuda::rtc::launch(function, {static_cast(num_blocks), 1, 1}, {static_cast(threads), 1, 1}, - 0, s, &args); + 0, + s, + &args); } } - } // namespace rtc } // namespace cuda } // namespace common diff --git a/src/common/cuda/utils.cc b/src/common/cuda/utils.cc index 7aa936dc9d4d..d04097bf1c57 100644 --- a/src/common/cuda/utils.cc +++ b/src/common/cuda/utils.cc @@ -53,13 +53,13 @@ int get_load_type(size_t N) { int get_rows_per_block(size_t row_size, int num_threads_per_block) { const int warp_size = 32; CHECK(IsPower2(num_threads_per_block)) - << "Number of threads in a block must be power of 2 to use get_rows_per_block function"; + << "Number of threads in a block must be power of 2 to use get_rows_per_block function"; // How many read instructions should 1 thread at least do - const int read_instructions = 2; + const int read_instructions = 2; const int desired_num_threads_per_row = (row_size + read_instructions - 1) / read_instructions; - int desired_num_warps_per_row = (desired_num_threads_per_row + warp_size - 1) / warp_size; - int actual_num_warps_per_row = std::min(desired_num_warps_per_row, - num_threads_per_block / warp_size); + int desired_num_warps_per_row = (desired_num_threads_per_row + warp_size - 1) / warp_size; + int actual_num_warps_per_row = + std::min(desired_num_warps_per_row, num_threads_per_block / warp_size); // actual number of warps needs to be power of 2 actual_num_warps_per_row = RoundToPower2(actual_num_warps_per_row); return num_threads_per_block / (warp_size * actual_num_warps_per_row); diff --git a/src/common/cuda/utils.h b/src/common/cuda/utils.h index a203ba55a773..63df34bf07d2 100644 --- a/src/common/cuda/utils.h +++ b/src/common/cuda/utils.h @@ -41,14 +41,21 @@ #define __shared__ inline void __syncthreads() {} inline void __threadfence_block() {} -template inline T __clz(const T val) { return val; } -struct __cuda_fake_struct { int x; int y; int z; }; +template +inline T __clz(const T val) { + return val; +} +struct __cuda_fake_struct { + int x; + int y; + int z; +}; extern __cuda_fake_struct blockDim; extern __cuda_fake_struct threadIdx; extern __cuda_fake_struct blockIdx; #endif -#define QUOTE(x) #x +#define 
QUOTE(x) #x #define QUOTEVALUE(x) QUOTE(x) #if MXNET_USE_CUDA @@ -98,11 +105,10 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for CUDA errors after invocation of the expression. */ -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ } /*! @@ -111,10 +117,10 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for cuBLAS errors after invocation of the expression. */ -#define CUBLAS_CALL(func) \ - { \ - cublasStatus_t e = (func); \ - CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ +#define CUBLAS_CALL(func) \ + { \ + cublasStatus_t e = (func); \ + CHECK_EQ(e, CUBLAS_STATUS_SUCCESS) \ << "cuBLAS: " << mxnet::common::cuda::CublasGetErrorString(e); \ } @@ -124,10 +130,10 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for cuSolver errors after invocation of the expression. */ -#define CUSOLVER_CALL(func) \ - { \ - cusolverStatus_t e = (func); \ - CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ +#define CUSOLVER_CALL(func) \ + { \ + cusolverStatus_t e = (func); \ + CHECK_EQ(e, CUSOLVER_STATUS_SUCCESS) \ << "cuSolver: " << mxnet::common::cuda::CusolverGetErrorString(e); \ } @@ -137,10 +143,10 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for cuRAND errors after invocation of the expression. */ -#define CURAND_CALL(func) \ - { \ - curandStatus_t e = (func); \ - CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ +#define CURAND_CALL(func) \ + { \ + curandStatus_t e = (func); \ + CHECK_EQ(e, CURAND_STATUS_SUCCESS) \ << "cuRAND: " << mxnet::common::cuda::CurandGetErrorString(e); \ } @@ -150,12 +156,10 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for NVRTC errors after invocation of the expression. */ -#define NVRTC_CALL(x) \ - { \ - nvrtcResult result = x; \ - CHECK_EQ(result, NVRTC_SUCCESS) \ - << #x " failed with error " \ - << nvrtcGetErrorString(result); \ +#define NVRTC_CALL(x) \ + { \ + nvrtcResult result = x; \ + CHECK_EQ(result, NVRTC_SUCCESS) << #x " failed with error " << nvrtcGetErrorString(result); \ } /*! @@ -164,20 +168,19 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for CUDA driver errors after invocation of the expression. */ -#define CUDA_DRIVER_CALL(func) \ - { \ - CUresult e = (func); \ - if (e != CUDA_SUCCESS) { \ - char const * err_msg = nullptr; \ - if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ - LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ - } else { \ - LOG(FATAL) << "CUDA Driver: " << e << " " << err_msg; \ - } \ - } \ +#define CUDA_DRIVER_CALL(func) \ + { \ + CUresult e = (func); \ + if (e != CUDA_SUCCESS) { \ + char const* err_msg = nullptr; \ + if (cuGetErrorString(e, &err_msg) == CUDA_ERROR_INVALID_VALUE) { \ + LOG(FATAL) << "CUDA Driver: Unknown error " << e; \ + } else { \ + LOG(FATAL) << "CUDA Driver: " << e << " " << err_msg; \ + } \ + } \ } - #if MXNET_USE_NVML /*! * \brief Protected NVML call. @@ -185,17 +188,15 @@ inline __device__ bool __is_supported_cuda_architecture() { * * It checks for NVML errors after invocation of the expression. 
*/ -#define NVML_CALL(func) \ - { \ - nvmlReturn_t result = (func); \ - CHECK_EQ(result, NVML_SUCCESS) \ - << #func " failed with error " \ - << nvmlErrorString(result); \ +#define NVML_CALL(func) \ + { \ + nvmlReturn_t result = (func); \ + CHECK_EQ(result, NVML_SUCCESS) << #func " failed with error " << nvmlErrorString(result); \ } #endif // MXNET_USE_NVML #if !defined(_MSC_VER) -#define CUDA_UNROLL _Pragma("unroll") +#define CUDA_UNROLL _Pragma("unroll") #define CUDA_NOUNROLL _Pragma("nounroll") #else #define CUDA_UNROLL @@ -209,7 +210,7 @@ namespace cuda { /*! * \brief Converts between C++ datatypes and enums/constants needed by cuBLAS. */ -template +template struct CublasType; // With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own @@ -218,7 +219,7 @@ struct CublasType; // call cublasGemmEx(), burdening the class with the legacy type values // was not needed. -template<> +template <> struct CublasType { static const int kFlag = mshadow::kFloat32; #if CUDA_VERSION >= 8000 @@ -228,7 +229,7 @@ struct CublasType { static const float one; static const float zero; }; -template<> +template <> struct CublasType { static const int kFlag = mshadow::kFloat64; #if CUDA_VERSION >= 8000 @@ -238,7 +239,7 @@ struct CublasType { static const double one; static const double zero; }; -template<> +template <> struct CublasType { static const int kFlag = mshadow::kFloat16; #if CUDA_VERSION >= 8000 @@ -248,24 +249,24 @@ struct CublasType { static const mshadow::half::half_t one; static const mshadow::half::half_t zero; }; -template<> +template <> struct CublasType { static const int kFlag = mshadow::kUint8; #if CUDA_VERSION >= 8000 static const cudaDataType_t kCudaFlag = CUDA_R_8I; #endif typedef uint8_t ScaleType; - static const uint8_t one = 1; + static const uint8_t one = 1; static const uint8_t zero = 0; }; -template<> +template <> struct CublasType { static const int kFlag = mshadow::kInt32; #if CUDA_VERSION >= 8000 static const cudaDataType_t kCudaFlag = CUDA_R_32I; #endif typedef int32_t ScaleType; - static const int32_t one = 1; + static const int32_t one = 1; static const int32_t zero = 0; }; @@ -276,26 +277,26 @@ struct CublasType { */ inline const char* CublasGetErrorString(cublasStatus_t error) { switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - default: - break; + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return 
"CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + default: + break; } return "Unknown cuBLAS status"; } @@ -318,24 +319,24 @@ inline cublasOperation_t CublasTransposeOp(bool transpose) { */ inline const char* CusolverGetErrorString(cusolverStatus_t error) { switch (error) { - case CUSOLVER_STATUS_SUCCESS: - return "CUSOLVER_STATUS_SUCCESS"; - case CUSOLVER_STATUS_NOT_INITIALIZED: - return "CUSOLVER_STATUS_NOT_INITIALIZED"; - case CUSOLVER_STATUS_ALLOC_FAILED: - return "CUSOLVER_STATUS_ALLOC_FAILED"; - case CUSOLVER_STATUS_INVALID_VALUE: - return "CUSOLVER_STATUS_INVALID_VALUE"; - case CUSOLVER_STATUS_ARCH_MISMATCH: - return "CUSOLVER_STATUS_ARCH_MISMATCH"; - case CUSOLVER_STATUS_EXECUTION_FAILED: - return "CUSOLVER_STATUS_EXECUTION_FAILED"; - case CUSOLVER_STATUS_INTERNAL_ERROR: - return "CUSOLVER_STATUS_INTERNAL_ERROR"; - case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: - return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; - default: - break; + case CUSOLVER_STATUS_SUCCESS: + return "CUSOLVER_STATUS_SUCCESS"; + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "CUSOLVER_STATUS_NOT_INITIALIZED"; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "CUSOLVER_STATUS_ALLOC_FAILED"; + case CUSOLVER_STATUS_INVALID_VALUE: + return "CUSOLVER_STATUS_INVALID_VALUE"; + case CUSOLVER_STATUS_ARCH_MISMATCH: + return "CUSOLVER_STATUS_ARCH_MISMATCH"; + case CUSOLVER_STATUS_EXECUTION_FAILED: + return "CUSOLVER_STATUS_EXECUTION_FAILED"; + case CUSOLVER_STATUS_INTERNAL_ERROR: + return "CUSOLVER_STATUS_INTERNAL_ERROR"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + default: + break; } return "Unknown cuSOLVER status"; } @@ -347,53 +348,51 @@ inline const char* CusolverGetErrorString(cusolverStatus_t error) { */ inline const char* CurandGetErrorString(curandStatus_t status) { switch (status) { - case CURAND_STATUS_SUCCESS: - return "CURAND_STATUS_SUCCESS"; - case CURAND_STATUS_VERSION_MISMATCH: - return "CURAND_STATUS_VERSION_MISMATCH"; - case CURAND_STATUS_NOT_INITIALIZED: - return "CURAND_STATUS_NOT_INITIALIZED"; - case CURAND_STATUS_ALLOCATION_FAILED: - return "CURAND_STATUS_ALLOCATION_FAILED"; - case CURAND_STATUS_TYPE_ERROR: - return "CURAND_STATUS_TYPE_ERROR"; - case CURAND_STATUS_OUT_OF_RANGE: - return "CURAND_STATUS_OUT_OF_RANGE"; - case CURAND_STATUS_LENGTH_NOT_MULTIPLE: - return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; - case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: - return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; - case CURAND_STATUS_LAUNCH_FAILURE: - return "CURAND_STATUS_LAUNCH_FAILURE"; - case CURAND_STATUS_PREEXISTING_FAILURE: - return "CURAND_STATUS_PREEXISTING_FAILURE"; - case CURAND_STATUS_INITIALIZATION_FAILED: - return "CURAND_STATUS_INITIALIZATION_FAILED"; - case CURAND_STATUS_ARCH_MISMATCH: - return "CURAND_STATUS_ARCH_MISMATCH"; - case CURAND_STATUS_INTERNAL_ERROR: - return "CURAND_STATUS_INTERNAL_ERROR"; + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case 
CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; } return "Unknown cuRAND status"; } template inline DType __device__ CudaMax(DType a, DType b) { - return a > b ? a : b; + return a > b ? a : b; } template inline DType __device__ CudaMin(DType a, DType b) { - return a < b ? a : b; + return a < b ? a : b; } class DeviceStore { public: /*! \brief default constructor- only optionally restores previous device */ - explicit DeviceStore(int requested_device = -1, bool restore = true) : - restore_device_(-1), - current_device_(requested_device), - restore_(restore) { + explicit DeviceStore(int requested_device = -1, bool restore = true) + : restore_device_(-1), current_device_(requested_device), restore_(restore) { if (restore_) CUDA_CALL(cudaGetDevice(&restore_device_)); if (requested_device != restore_device_) { @@ -402,9 +401,7 @@ class DeviceStore { } ~DeviceStore() { - if (restore_ && - current_device_ != restore_device_ && - current_device_ != -1 && + if (restore_ && current_device_ != restore_device_ && current_device_ != -1 && restore_device_ != -1) CUDA_CALL(cudaSetDevice(restore_device_)); } @@ -462,8 +459,10 @@ constexpr size_t kMaxNumGpus = 64; * \param attr_name A string representation of the attribute, for error messages. * \return the gpu's attribute value. */ -inline int cudaAttributeLookup(int device_id, std::vector *cached_values, - cudaDeviceAttr attr, const char *attr_name) { +inline int cudaAttributeLookup(int device_id, + std::vector* cached_values, + cudaDeviceAttr attr, + const char* attr_name) { if (device_id < 0 || device_id >= static_cast(cached_values->size())) { LOG(FATAL) << attr_name << "(device_id) called with invalid id: " << device_id; } else if ((*cached_values)[device_id] < 0) { @@ -481,8 +480,8 @@ inline int cudaAttributeLookup(int device_id, std::vector *cached_value */ inline int ComputeCapabilityMajor(int device_id) { static std::vector capability_major(kMaxNumGpus, -1); - return cudaAttributeLookup(device_id, &capability_major, - cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor"); + return cudaAttributeLookup( + device_id, &capability_major, cudaDevAttrComputeCapabilityMajor, "ComputeCapabilityMajor"); } /*! @@ -492,8 +491,8 @@ inline int ComputeCapabilityMajor(int device_id) { */ inline int ComputeCapabilityMinor(int device_id) { static std::vector capability_minor(kMaxNumGpus, -1); - return cudaAttributeLookup(device_id, &capability_minor, - cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor"); + return cudaAttributeLookup( + device_id, &capability_minor, cudaDevAttrComputeCapabilityMinor, "ComputeCapabilityMinor"); } /*! @@ -514,8 +513,8 @@ inline int SMArch(int device_id) { */ inline int MultiprocessorCount(int device_id) { static std::vector sm_counts(kMaxNumGpus, -1); - return cudaAttributeLookup(device_id, &sm_counts, - cudaDevAttrMultiProcessorCount, "MultiprocessorCount"); + return cudaAttributeLookup( + device_id, &sm_counts, cudaDevAttrMultiProcessorCount, "MultiprocessorCount"); } /*! 
@@ -525,7 +524,8 @@ inline int MultiprocessorCount(int device_id) { */ inline int MaxSharedMemoryPerMultiprocessor(int device_id) { static std::vector max_smem_per_mutiprocessor(kMaxNumGpus, -1); - return cudaAttributeLookup(device_id, &max_smem_per_mutiprocessor, + return cudaAttributeLookup(device_id, + &max_smem_per_mutiprocessor, cudaDevAttrMaxSharedMemoryPerMultiprocessor, "MaxSharedMemoryPerMultiprocessor"); } @@ -537,8 +537,8 @@ inline int MaxSharedMemoryPerMultiprocessor(int device_id) { */ inline bool SupportsCooperativeLaunch(int device_id) { static std::vector coop_launch(kMaxNumGpus, -1); - return cudaAttributeLookup(device_id, &coop_launch, - cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch"); + return cudaAttributeLookup( + device_id, &coop_launch, cudaDevAttrCooperativeLaunch, "SupportsCooperativeLaunch"); } /*! @@ -566,8 +566,7 @@ inline bool SupportsFloat16Compute(int device_id) { */ inline bool SupportsTensorCore(int device_id) { // Volta (sm_70) supports TensorCore algos - return device_id >= 0 && - ComputeCapabilityMajor(device_id) >=7; + return device_id >= 0 && ComputeCapabilityMajor(device_id) >= 7; } // The policy if the user hasn't set the environment variable MXNET_CUDA_ALLOW_TENSOR_CORE @@ -582,12 +581,12 @@ inline bool GetEnvAllowTensorCore() { // separately in each compilation unit. Not ideal, but cleaner than creating a // cuda_utils.cc solely to have a single instance and initialization. static bool allow_tensor_core = false; - static bool is_set = false; + static bool is_set = false; if (!is_set) { // Use of optional here permits: "0", "1", "true" and "false" to all be legal. bool default_value = MXNET_CUDA_ALLOW_TENSOR_CORE_DEFAULT; - allow_tensor_core = dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", - dmlc::optional(default_value)).value(); + allow_tensor_core = + dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", dmlc::optional(default_value)).value(); is_set = true; } return allow_tensor_core; @@ -630,19 +629,22 @@ inline cublasMath_t SetCublasMathMode(cublasHandle_t blas_handle, cublasMath_t n static_assert(CUDNN_PATCHLEVEL < 100 && CUDNN_MINOR < 10, "CUDNN_VERSION_AS_STRING macro assumptions violated."); #if CUDNN_PATCHLEVEL >= 10 -#define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \ - QUOTEVALUE(CUDNN_MINOR) \ - QUOTEVALUE(CUDNN_PATCHLEVEL) +#define CUDNN_VERSION_AS_STRING \ + QUOTEVALUE(CUDNN_MAJOR) \ + QUOTEVALUE(CUDNN_MINOR) \ + QUOTEVALUE(CUDNN_PATCHLEVEL) #else -#define CUDNN_VERSION_AS_STRING QUOTEVALUE(CUDNN_MAJOR) \ - QUOTEVALUE(CUDNN_MINOR) \ - "0" QUOTEVALUE(CUDNN_PATCHLEVEL) +#define CUDNN_VERSION_AS_STRING \ + QUOTEVALUE(CUDNN_MAJOR) \ + QUOTEVALUE(CUDNN_MINOR) \ + "0" QUOTEVALUE(CUDNN_PATCHLEVEL) #endif -#define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \ - static_assert(CUDNN_VERSION >= min_version, "Compiled-against cuDNN version " \ - CUDNN_VERSION_AS_STRING " is too old, please upgrade system to version " \ - QUOTEVALUE(min_version) " or later.") +#define STATIC_ASSERT_CUDNN_VERSION_GE(min_version) \ + static_assert( \ + CUDNN_VERSION >= min_version, \ + "Compiled-against cuDNN version " CUDNN_VERSION_AS_STRING \ + " is too old, please upgrade system to version " QUOTEVALUE(min_version) " or later.") #define CUDNN_CALL(func) \ { \ @@ -703,17 +705,16 @@ inline int MaxBackwardDataAlgos(cudnnHandle_t cudnn_handle) { // Overload atomicAdd to work for floats on all architectures #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 // From CUDA Programming Guide -static inline __device__ void atomicAdd(double *address, double val) { - 
unsigned long long* address_as_ull = // NOLINT(*) - reinterpret_cast(address); // NOLINT(*) - unsigned long long old = *address_as_ull; // NOLINT(*) - unsigned long long assumed; // NOLINT(*) +static inline __device__ void atomicAdd(double* address, double val) { + unsigned long long* address_as_ull = // NOLINT(*) + reinterpret_cast(address); // NOLINT(*) + unsigned long long old = *address_as_ull; // NOLINT(*) + unsigned long long assumed; // NOLINT(*) do { assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + - __longlong_as_double(assumed))); + old = atomicCAS( + address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) } while (assumed != old); @@ -724,68 +725,65 @@ static inline __device__ void atomicAdd(double *address, double val) { // Taken from: // https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh #ifdef __CUDACC__ -static inline __device__ void atomicAdd(mshadow::half::half_t *address, - mshadow::half::half_t val) { - unsigned int *address_as_ui = - reinterpret_cast(reinterpret_cast(address) - - (reinterpret_cast(address) & 2)); +static inline __device__ void atomicAdd(mshadow::half::half_t* address, mshadow::half::half_t val) { + unsigned int* address_as_ui = reinterpret_cast( + reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; unsigned int assumed; do { assumed = old; mshadow::half::half_t hsum; - hsum.half_ = - reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); + hsum.half_ = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); hsum += val; - old = reinterpret_cast(address) & 2 - ? (old & 0xffff) | (hsum.half_ << 16) - : (old & 0xffff0000) | hsum.half_; + old = reinterpret_cast(address) & 2 ? 
(old & 0xffff) | (hsum.half_ << 16) + : (old & 0xffff0000) | hsum.half_; old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); } -static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) { - unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3)); - unsigned int old = *address_as_ui; - unsigned int shift = (((size_t)address & 0x3) << 3); +static inline __device__ void atomicAdd(uint8_t* address, uint8_t val) { + unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3)); + unsigned int old = *address_as_ui; + unsigned int shift = (((size_t)address & 0x3) << 3); unsigned int sum; unsigned int assumed; do { assumed = old; - sum = val + static_cast((old >> shift) & 0xff); - old = (old & ~(0x000000ff << shift)) | (sum << shift); - old = atomicCAS(address_as_ui, assumed, old); + sum = val + static_cast((old >> shift) & 0xff); + old = (old & ~(0x000000ff << shift)) | (sum << shift); + old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); } -static inline __device__ void atomicAdd(int8_t *address, int8_t val) { - unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3)); - unsigned int old = *address_as_ui; - unsigned int shift = (((size_t)address & 0x3) << 3); +static inline __device__ void atomicAdd(int8_t* address, int8_t val) { + unsigned int* address_as_ui = (unsigned int*)(address - ((size_t)address & 0x3)); + unsigned int old = *address_as_ui; + unsigned int shift = (((size_t)address & 0x3) << 3); unsigned int sum; unsigned int assumed; do { assumed = old; - sum = val + static_cast((old >> shift) & 0xff); - old = (old & ~(0x000000ff << shift)) | (sum << shift); - old = atomicCAS(address_as_ui, assumed, old); + sum = val + static_cast((old >> shift) & 0xff); + old = (old & ~(0x000000ff << shift)) | (sum << shift); + old = atomicCAS(address_as_ui, assumed, old); } while (assumed != old); } // Overload atomicAdd to work for signed int64 on all architectures -static inline __device__ void atomicAdd(int64_t *address, int64_t val) { - atomicAdd(reinterpret_cast(address), static_cast(val)); // NOLINT +static inline __device__ void atomicAdd(int64_t* address, int64_t val) { + atomicAdd(reinterpret_cast(address), // NOLINT + static_cast(val)); // NOLINT } template __device__ inline DType ldg(const DType* address) { #if __CUDA_ARCH__ >= 350 - return __ldg(address); + return __ldg(address); #else - return *address; + return *address; #endif } @@ -806,7 +804,8 @@ template __device__ inline T warp_reduce(T value, OP redfun) { #pragma unroll for (int i = warp_size / 2; i >= 1; i /= 2) { - if (NVALUES > i) value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); + if (NVALUES > i) + value = redfun(value, __shfl_down_sync(0xffffffff, value, i)); } return value; } @@ -824,7 +823,8 @@ __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, float v = static_cast(value); #pragma unroll for (int i = warp_size / 2; i >= 1; i /= 2) { - if (NValues > i) v = redfun(v, __shfl_down_sync(0xffffffff, v, i)); + if (NValues > i) + v = redfun(v, __shfl_down_sync(0xffffffff, v, i)); } return mshadow::half::half_t(v); } @@ -843,12 +843,11 @@ __device__ inline mshadow::half::half_t warp_reduce(mshadow::half::half_t value, */ template __device__ inline T reduce(const T& value, OP redfun) { - static_assert(NTHREADS <= warp_size * warp_size, - "Number of threads too large for reduction"); + static_assert(NTHREADS <= warp_size * warp_size, "Number of threads too large for 
reduction"); __shared__ T scratch[NTHREADS / warp_size]; const int thread_idx_in_warp = threadIdx.x % warp_size; - const int warp_id = threadIdx.x / warp_size; - const T my_val = warp_reduce(value, redfun); + const int warp_id = threadIdx.x / warp_size; + const T my_val = warp_reduce(value, redfun); if (thread_idx_in_warp == 0) { scratch[warp_id] = my_val; } @@ -856,7 +855,7 @@ __device__ inline T reduce(const T& value, OP redfun) { T ret = 0; if (warp_id == 0) { const T prev_val = threadIdx.x < (NTHREADS / warp_size) ? scratch[threadIdx.x] : 0; - const T my_val = warp_reduce(prev_val, redfun); + const T my_val = warp_reduce(prev_val, redfun); if (all_reduce) { scratch[threadIdx.x] = my_val; } else { diff --git a/src/common/exec_utils.cc b/src/common/exec_utils.cc index 601d1c0b6d96..bbc11e12a708 100644 --- a/src/common/exec_utils.cc +++ b/src/common/exec_utils.cc @@ -30,26 +30,26 @@ namespace mxnet { namespace common { -void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables) { +void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables) { using nnvm::Node; - using nnvm::ObjectPtr; using nnvm::NodeEntry; + using nnvm::ObjectPtr; std::unordered_map old_new; // use DFSVisit to copy all the nodes DFSVisit(src.outputs, [&old_new, copy_variables](const ObjectPtr& node) { - ObjectPtr np; - if (copy_variables || !node->is_variable()) { - np = Node::Create(); - np->attrs = node->attrs; - } else { - np = node; - } - old_new[node.get()] = std::move(np); - }); + ObjectPtr np; + if (copy_variables || !node->is_variable()) { + np = Node::Create(); + np->attrs = node->attrs; + } else { + np = node; + } + old_new[node.get()] = std::move(np); + }); // connect nodes of new graph - for (const auto &kv : old_new) { + for (const auto& kv : old_new) { for (const NodeEntry& e : kv.first->inputs) { - Node *ptr = e.node.get(); + Node* ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } for (const ObjectPtr& p : kv.first->control_deps) { @@ -57,15 +57,15 @@ void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables) { } } // set the head - for (const NodeEntry &e : src.outputs) { + for (const NodeEntry& e : src.outputs) { (*dst).outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); } } -bool CheckForInputNameDuplicates(const nnvm::IndexedGraph &idx) { +bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) { std::unordered_set names; for (const auto& nid : idx.input_nodes()) { - const std::string &name = idx[nid].source->attrs.name; + const std::string& name = idx[nid].source->attrs.name; if (names.count(name)) { LOG(WARNING) << "Variable name " << name << " is used more than once!"; return false; diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h index 80936a916b4d..ec2aa7cb6975 100644 --- a/src/common/exec_utils.h +++ b/src/common/exec_utils.h @@ -37,10 +37,10 @@ namespace mxnet { namespace common { #if MXNET_USE_ONEDNN == 1 - // We have to make sure it's default storage and default layout. -#define DEFAULT_DATA(x) x.IsDefaultData() +// We have to make sure it's default storage and default layout. 
+#define DEFAULT_DATA(x) x.IsDefaultData() #else -#define DEFAULT_DATA(x) (x.storage_type() == kDefaultStorage) +#define DEFAULT_DATA(x) (x.storage_type() == kDefaultStorage) #endif /* @@ -57,18 +57,18 @@ namespace common { * \return true if any source NDArray need to cast storage */ inline bool SetupDefaultBlobsIn(const std::vector& src, - const std::vector *bufs, - std::vector *blobs, - std::vector *temp_src, - std::vector *temp_dst, - std::unordered_map *idx_map) { + const std::vector* bufs, + std::vector* blobs, + std::vector* temp_src, + std::vector* temp_dst, + std::unordered_map* idx_map) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = src[i]; if (!DEFAULT_DATA(nd)) { (*idx_map)[i] = temp_dst->size(); - NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), - true, nd.dtype()); + NDArray temp = + bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); #if MXNET_USE_ONEDNN == 1 CHECK(temp.IsDefaultData()); #endif @@ -84,11 +84,11 @@ inline bool SetupDefaultBlobsIn(const std::vector& src, } inline bool SetupDefaultBlobsOut(const std::vector& src, - const std::vector *bufs, - std::vector *req, - std::vector *blobs, - std::vector *temp_src, - std::vector *temp_dst) { + const std::vector* bufs, + std::vector* req, + std::vector* blobs, + std::vector* temp_src, + std::vector* temp_dst) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = src[i]; @@ -100,7 +100,7 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, // the input array and the output array are no longer the same array. // we should change the request type. req->at(i) = kWriteTo; - // We have to make sure it's default storage and default layout. + // We have to make sure it's default storage and default layout. #endif if (!DEFAULT_DATA(nd)) { #if MXNET_USE_ONEDNN == 1 @@ -108,14 +108,14 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, if (bufs != nullptr) { temp = bufs->at(i); } else if (kAddTo == req->at(i)) { - temp = nd.IsMKLDNNData()? nd.Reorder2Default() : nd; + temp = nd.IsMKLDNNData() ? nd.Reorder2Default() : nd; } else { temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); } CHECK(temp.IsDefaultData()); #else - NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), - true, nd.dtype()); + NDArray temp = + bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); #endif temp_src->emplace_back(nd); temp_dst->emplace_back(temp); @@ -135,25 +135,23 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, * function also records the indices of non-default source NDArrays and the indices of * their corresponding temporary NDArrays in the temp array. 
*/ -inline void SetupDefaultBlobsInOut(const std::vector &ndinputs, - const std::vector &ndoutputs, - const std::vector *in_bufs, - const std::vector *out_bufs, - std::vector *req, - std::vector *input_blobs, - std::vector *output_blobs, - std::vector *pre_temp_src, - std::vector *pre_temp_dst, - std::vector *post_temp_src, - std::vector *post_temp_dst, - std::unordered_map *in_temp_idx_map, - const std::vector &mutate_idx) { +inline void SetupDefaultBlobsInOut(const std::vector& ndinputs, + const std::vector& ndoutputs, + const std::vector* in_bufs, + const std::vector* out_bufs, + std::vector* req, + std::vector* input_blobs, + std::vector* output_blobs, + std::vector* pre_temp_src, + std::vector* pre_temp_dst, + std::vector* post_temp_src, + std::vector* post_temp_dst, + std::unordered_map* in_temp_idx_map, + const std::vector& mutate_idx) { // populate input blobs - SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, - in_temp_idx_map); + SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map); // populate output blobs - SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, - post_temp_src); + SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, post_temp_src); // add mutable inputs to post temp list for (const auto idx : mutate_idx) { auto map_iter = in_temp_idx_map->find(idx); @@ -193,22 +191,25 @@ inline void CastNonDefaultStorage(const std::vector& src, * types to the same type of one of the inputs or outputs. */ inline bool SameType(const nnvm::NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr) { + std::vector* iattr, + std::vector* oattr) { int def_v = -1; for (int v : *oattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } if (def_v == -1) { for (int v : *iattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } } - if (def_v == -1) return false; + if (def_v == -1) + return false; for (int& v : *oattr) { v = def_v; } @@ -218,7 +219,6 @@ inline bool SameType(const nnvm::NodeAttrs& attrs, return true; } - /*! \brief The default storage type inference function, which assigns all undefined * storage types to kDefaultStorage. If all of input and output storage types * are kDefaultStorage, DispatchMode::kFCompute is assigned to dispatch_mode. Otherwise, @@ -227,16 +227,20 @@ inline bool SameType(const nnvm::NodeAttrs& attrs, inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *iattr, - std::vector *oattr) { + std::vector* iattr, + std::vector* oattr) { bool fallback = false; for (int& v : *oattr) { - if (v == -1) v = kDefaultStorage; - if (v != kDefaultStorage) fallback = true; + if (v == -1) + v = kDefaultStorage; + if (v != kDefaultStorage) + fallback = true; } for (int& v : *iattr) { - if (v == -1) v = kDefaultStorage; - if (v != kDefaultStorage) fallback = true; + if (v == -1) + v = kDefaultStorage; + if (v != kDefaultStorage) + fallback = true; } if (*dispatch_mode == DispatchMode::kUndefined) { if (fallback) { @@ -282,15 +286,15 @@ inline std::string storage_str(int storage_id) { ... 
*/ inline void LogMemoryPlan(const nnvm::Graph& g) { - const auto &idx = g.indexed_graph(); + const auto& idx = g.indexed_graph(); const auto& vshape = g.GetAttr("shape"); - const auto& vtype = g.GetAttr("dtype"); + const auto& vtype = g.GetAttr("dtype"); // find node range uint32_t node_start = 0, node_end = idx.num_nodes(); if (g.attrs.count("node_range")) { const auto& range = g.GetAttr >("node_range"); - node_start = range.first; - node_end = range.second; + node_start = range.first; + node_end = range.second; } for (uint32_t nid = node_start; nid < node_end; ++nid) { const auto& inode = idx[nid]; @@ -299,16 +303,14 @@ inline void LogMemoryPlan(const nnvm::Graph& g) { } else { LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; for (const auto& e : inode.inputs) { - auto eid = idx.entry_id(e); + auto eid = idx.entry_id(e); size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024; - LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " (" - << kilo_bytes << " KB)"; + LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " (" << kilo_bytes << " KB)"; } for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { - uint32_t eid = idx.entry_id(nid, index); + uint32_t eid = idx.entry_id(nid, index); size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024; - LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " (" - << kilo_bytes << " KB)"; + LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " (" << kilo_bytes << " KB)"; } } } @@ -340,22 +342,22 @@ inline void LogMemoryPlan(const nnvm::Graph& g) { ... */ inline void LogInferStorage(const nnvm::Graph& g) { - const auto &idx = g.indexed_graph(); - const auto& vstorage_type = g.GetAttr("storage_type"); + const auto& idx = g.indexed_graph(); + const auto& vstorage_type = g.GetAttr("storage_type"); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); uint32_t node_start = 0, node_end = idx.num_nodes(); if (g.attrs.count("node_range")) { const auto& range = g.GetAttr >("node_range"); - node_start = range.first; - node_end = range.second; + node_start = range.first; + node_end = range.second; } for (uint32_t nid = node_start; nid < node_end; ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) { LOG(INFO) << "node " << nid << " var"; } else { - LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name - << ": " << dispatch_mode_string(dispatch_modes[nid]); + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name << ": " + << dispatch_mode_string(dispatch_modes[nid]); for (const auto& e : inode.inputs) { auto eid = idx.entry_id(e); LOG(INFO) << "\t\tinput " << eid << ": " << stype_string(vstorage_type[eid]); @@ -432,39 +434,39 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, const std::vector& grad_req_types, size_t num_forward_inputs, size_t num_forward_outputs) { - const auto& idx = g.indexed_graph(); + const auto& idx = g.indexed_graph(); const auto& mutable_nodes = idx.mutable_input_nodes(); // default use default context. if (ctx_map.size() == 0) { - g.attrs["context"] = std::make_shared( - exec::ContextVector(idx.num_nodes(), default_ctx)); + g.attrs["context"] = + std::make_shared(exec::ContextVector(idx.num_nodes(), default_ctx)); for (const auto& x : in_arg_ctxes) { - CHECK(x == default_ctx) - << "Input array is in " << x << " while binding with ctx=" << default_ctx - << ". 
All arguments must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; + CHECK(x == default_ctx) << "Input array is in " << x + << " while binding with ctx=" << default_ctx + << ". All arguments must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; } for (const auto& x : arg_grad_ctxes) { - CHECK(x == default_ctx) - << "Gradient array is in " << x << " while binding with ctx=" - << default_ctx << ". All gradients must be in global context (" << default_ctx - << ") unless group2ctx is specified for cross-device graph."; + CHECK(x == default_ctx) << "Gradient array is in " << x + << " while binding with ctx=" << default_ctx + << ". All gradients must be in global context (" << default_ctx + << ") unless group2ctx is specified for cross-device graph."; } return g; } // otherwise, use context assignment. - std::map ctx2id; // map ctx to device id - std::vector ctx_list; // index is device id + std::map ctx2id; // map ctx to device id + std::vector ctx_list; // index is device id nnvm::DeviceVector device(idx.num_nodes(), -1); // index is node id - nnvm::DeviceAssignMap device_map; // map arg name to device id + nnvm::DeviceAssignMap device_map; // map arg name to device id // loop through the user input ctx_map and // populate maps and lists - for (auto &kv : ctx_map) { + for (auto& kv : ctx_map) { if (ctx2id.count(kv.second) == 0) { // if context has no device id, create one ctx2id[kv.second] = static_cast(ctx_list.size()); // assign device id to ctx - ctx_list.push_back(kv.second); // save ctx to the list + ctx_list.push_back(kv.second); // save ctx to the list } // assign device id to to the arg name with the corresponding ctx device_map[kv.first] = ctx2id.at(kv.second); @@ -487,7 +489,7 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, } if (ctx2id.count(ctx) == 0) { // if the current ctx is not in the map of ctx and device id ctx2id[ctx] = static_cast(ctx_list.size()); // assign the current ctx with device id - ctx_list.push_back(ctx); // save the current ctx in the list + ctx_list.push_back(ctx); // save the current ctx in the list } device[nid] = ctx2id.at(ctx); // assign device id to the current node } @@ -500,9 +502,10 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, CHECK_GE(grad_req_types.size(), g.outputs.size() - num_forward_outputs) << "insufficient number of grad_reqs"; for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + while (grad_req_types[arg_grad_offset] == kNullOp) + ++arg_grad_offset; const uint32_t nid = idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; + Context ctx = arg_grad_ctxes[arg_grad_offset]; if (ctx2id.count(ctx) == 0) { ctx2id[ctx] = static_cast(ctx_list.size()); ctx_list.push_back(ctx); @@ -516,7 +519,7 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, } g.attrs["device"] = std::make_shared(std::move(device)); - g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); + g = nnvm::pass::PlaceDevice(g, "__ctx_group__", device_map, "_CrossDeviceCopy"); const auto& assigned_devices = g.GetAttr("device"); exec::ContextVector vcontext; @@ -531,17 +534,17 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, // after device planning, we should check again // if the assigned device of gradient node // corresponds to storage of grads - auto &new_idx = g.indexed_graph(); + auto& new_idx = g.indexed_graph(); 
arg_grad_offset = 0; for (size_t i = num_forward_outputs; i < g.outputs.size(); ++i, ++arg_grad_offset) { - while (grad_req_types[arg_grad_offset] == kNullOp) ++arg_grad_offset; + while (grad_req_types[arg_grad_offset] == kNullOp) + ++arg_grad_offset; const uint32_t nid = new_idx.outputs()[i].node_id; - Context ctx = arg_grad_ctxes[arg_grad_offset]; - CHECK(ctx == vcontext[nid]) - << "Trying to save gradient to " << ctx - << " while its source node \"" << new_idx[nid].source->attrs.name - << "\" computes it on " << vcontext[nid] - << ". Check your ctx in NDArray allocation."; + Context ctx = arg_grad_ctxes[arg_grad_offset]; + CHECK(ctx == vcontext[nid]) << "Trying to save gradient to " << ctx + << " while its source node \"" << new_idx[nid].source->attrs.name + << "\" computes it on " << vcontext[nid] + << ". Check your ctx in NDArray allocation."; } g.attrs["context"] = std::make_shared(std::move(vcontext)); @@ -556,7 +559,7 @@ inline nnvm::Graph AssignContext(nnvm::Graph g, * \param copy_variable whether to copy or reuse Variable nodes from the * source graph */ -void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables); +void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables); /*! * \brief Check whether graph contains any duplicated names in its inputs. @@ -565,9 +568,8 @@ void CopyGraph(nnvm::Graph *dst, const nnvm::Graph &src, bool copy_variables); * * \return true if there are no duplicates, false otherwise */ -bool CheckForInputNameDuplicates(const nnvm::IndexedGraph &idx); +bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx); } // namespace common } // namespace mxnet #endif // MXNET_COMMON_EXEC_UTILS_H_ - diff --git a/src/common/lazy_alloc_array.h b/src/common/lazy_alloc_array.h index 0fd5acd63d59..f6b4dac2d87d 100644 --- a/src/common/lazy_alloc_array.h +++ b/src/common/lazy_alloc_array.h @@ -36,7 +36,7 @@ namespace mxnet { namespace common { -template +template class LazyAllocArray { public: LazyAllocArray(); @@ -46,23 +46,22 @@ class LazyAllocArray { * \param index the array index position * \param creator a lambda function to create new element when needed. */ - template + template inline std::shared_ptr Get(int index, FCreate creator); /*! * \brief for each not null element of the array, call fvisit * \param fvisit a function of (size_t, TElem*) */ - template + template inline void ForEach(FVisit fvisit); /*! \brief clear all the allocated elements in array */ inline void Clear(); private: - template + template class unique_unlock { public: - explicit unique_unlock(std::unique_lock *lock) - : lock_(lock) { + explicit unique_unlock(std::unique_lock* lock) : lock_(lock) { if (lock_) { lock_->unlock(); } @@ -72,8 +71,9 @@ class LazyAllocArray { lock_->lock(); } } + private: - std::unique_lock *lock_; + std::unique_lock* lock_; }; /*! 
\brief the initial size of the array */ @@ -88,14 +88,12 @@ class LazyAllocArray { std::atomic is_clearing_; }; -template -inline LazyAllocArray::LazyAllocArray() - : is_clearing_(false) { -} +template +inline LazyAllocArray::LazyAllocArray() : is_clearing_(false) {} // implementations -template -template +template +template inline std::shared_ptr LazyAllocArray::Get(int index, FCreate creator) { CHECK_GE(index, 0); size_t idx = static_cast(index); @@ -135,7 +133,7 @@ inline std::shared_ptr LazyAllocArray::Get(int index, FCreate crea return nullptr; } -template +template inline void LazyAllocArray::Clear() { std::unique_lock lock(create_mutex_); is_clearing_.store(true); @@ -144,13 +142,13 @@ inline void LazyAllocArray::Clear() { // any growth which might happen when create_mutex_ is unlocked for (size_t i = 0; i < head_.size(); ++i) { std::shared_ptr p = head_[i]; - head_[i] = std::shared_ptr(nullptr); + head_[i] = std::shared_ptr(nullptr); unique_unlock unlocker(&lock); p = std::shared_ptr(nullptr); } for (size_t i = 0; i < more_.size(); ++i) { std::shared_ptr p = more_[i]; - more_[i] = std::shared_ptr(nullptr); + more_[i] = std::shared_ptr(nullptr); unique_unlock unlocker(&lock); p = std::shared_ptr(nullptr); } @@ -158,8 +156,8 @@ inline void LazyAllocArray::Clear() { is_clearing_.store(false); } -template -template +template +template inline void LazyAllocArray::ForEach(FVisit fvisit) { std::lock_guard lock(create_mutex_); for (size_t i = 0; i < head_.size(); ++i) { diff --git a/src/common/object_pool.h b/src/common/object_pool.h index f0a651182431..f822604ce912 100644 --- a/src/common/object_pool.h +++ b/src/common/object_pool.h @@ -150,7 +150,7 @@ T* ObjectPool::New(Args&&... args) { if (head_->next == nullptr) { AllocateChunk(); } - ret = head_; + ret = head_; head_ = head_->next; } return new (static_cast(ret)) T(std::forward(args)...); @@ -163,7 +163,7 @@ void ObjectPool::Delete(T* ptr) { { std::lock_guard lock{m_}; linked_list_ptr->next = head_; - head_ = linked_list_ptr; + head_ = linked_list_ptr; } } @@ -199,12 +199,12 @@ void ObjectPool::AllocateChunk() { #endif allocated_.emplace_back(new_chunk_ptr); auto new_chunk = static_cast(new_chunk_ptr); - auto size = kPageSize / sizeof(LinkedList); + auto size = kPageSize / sizeof(LinkedList); for (std::size_t i = 0; i < size - 1; ++i) { new_chunk[i].next = &new_chunk[i + 1]; } new_chunk[size - 1].next = head_; - head_ = new_chunk; + head_ = new_chunk; } template diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu index 8f7b95985d02..caced2b95b00 100644 --- a/src/common/random_generator.cu +++ b/src/common/random_generator.cu @@ -31,46 +31,43 @@ namespace mxnet { namespace common { namespace random { -template<> +template <> const int RandGenerator::kMinNumRandomPerThread = 64; -template<> +template <> const int RandGenerator::kNumRandomStates = 32768; -__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t *states_, +__global__ void rand_generator_seed_kernel(curandStatePhilox4_32_10_t* states_, const int size, uint32_t seed) { int id = blockIdx.x * blockDim.x + threadIdx.x; - if (id < size) curand_init(seed, id, 0, states_ + id); + if (id < size) + curand_init(seed, id, 0, states_ + id); } -template<> -void RandGenerator::Seed(mshadow::Stream *s, uint32_t seed) { +template <> +void RandGenerator::Seed(mshadow::Stream* s, uint32_t seed) { using namespace mshadow::cuda; - int ngrid = std::min(kMaxGridNum, - (RandGenerator::kNumRandomStates + kBaseThreadNum - 1) / - kBaseThreadNum); - 
rand_generator_seed_kernel - <<::GetStream(s)>>>( - states_, - RandGenerator::kNumRandomStates, - seed); + int ngrid = + std::min(kMaxGridNum, + (RandGenerator::kNumRandomStates + kBaseThreadNum - 1) / kBaseThreadNum); + rand_generator_seed_kernel<<::GetStream(s)>>>( + states_, RandGenerator::kNumRandomStates, seed); MSHADOW_CUDA_POST_KERNEL_CHECK(rand_generator_seed_kernel); s->Wait(); } -template<> -void RandGenerator::AllocState(RandGenerator *inst) { - CUDA_CALL(cudaMalloc(&inst->states_, - kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); +template <> +void RandGenerator::AllocState(RandGenerator* inst) { + CUDA_CALL(cudaMalloc(&inst->states_, kNumRandomStates * sizeof(curandStatePhilox4_32_10_t))); } -template<> -void RandGenerator::FreeState(RandGenerator *inst) { +template <> +void RandGenerator::FreeState(RandGenerator* inst) { CUDA_CALL(cudaFree(inst->states_)); } -template<> +template <> void* RandGenerator::GetStates() { return static_cast(states_); } diff --git a/src/common/rtc.cc b/src/common/rtc.cc index ece9c0566acd..683cc7c8faf0 100644 --- a/src/common/rtc.cc +++ b/src/common/rtc.cc @@ -28,12 +28,12 @@ namespace mxnet { namespace rtc { -CudaModule::Chunk::Chunk( - const char* source, - const std::vector& options, - const std::vector& exports) { +CudaModule::Chunk::Chunk(const char* source, + const std::vector& options, + const std::vector& exports) { NVRTC_CALL(nvrtcCreateProgram(&prog_, source, "source.cu", 0, nullptr, nullptr)); - for (const auto& i : exports) exports_.insert(i); + for (const auto& i : exports) + exports_.insert(i); #if CUDA_VERSION >= 8000 for (const auto& func : exports) { NVRTC_CALL(nvrtcAddNameExpression(prog_, func.c_str())); @@ -45,7 +45,8 @@ CudaModule::Chunk::Chunk( << "with extern \"C\" instead."; #endif std::vector c_options; - for (const auto& i : options) c_options.push_back(i.c_str()); + for (const auto& i : options) + c_options.push_back(i.c_str()); nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data()); if (compile_res != NVRTC_SUCCESS) { size_t err_size; @@ -81,7 +82,6 @@ CudaModule::Chunk::Chunk( } } - CudaModule::Chunk::~Chunk() { for (const auto& kv : mod_) { CUDA_DRIVER_CALL(cuModuleUnload(kv.second)); @@ -89,12 +89,8 @@ CudaModule::Chunk::~Chunk() { NVRTC_CALL(nvrtcDestroyProgram(&prog_)); } - -CUfunction CudaModule::Chunk::GetFunction( - const std::string& mangled_name, - const Context& ctx) { - CHECK_EQ(ctx.dev_mask(), Context::kGPU) - << "CUDA Runtime compilation only supports Nvidia GPU."; +CUfunction CudaModule::Chunk::GetFunction(const std::string& mangled_name, const Context& ctx) { + CHECK_EQ(ctx.dev_mask(), Context::kGPU) << "CUDA Runtime compilation only supports Nvidia GPU."; auto iter = mod_.find(ctx.dev_id); mxnet::common::cuda::DeviceStore device_store; CUmodule module; @@ -117,13 +113,12 @@ CUfunction CudaModule::Chunk::GetFunction( return function; } - -std::shared_ptr CudaModule::GetKernel( - const std::string& name, const std::vector& signature) { +std::shared_ptr CudaModule::GetKernel(const std::string& name, + const std::vector& signature) { std::string mangled_name = name; #if CUDA_VERSION >= 8000 if (ptr_->exports_.count(name)) { - const char * c_mangled_name; + const char* c_mangled_name; NVRTC_CALL(nvrtcGetLoweredName(ptr_->prog_, name.c_str(), &c_mangled_name)); mangled_name = c_mangled_name; } @@ -131,23 +126,23 @@ std::shared_ptr CudaModule::GetKernel( return std::shared_ptr(new Kernel(ptr_, mangled_name, signature)); } - -CudaModule::Kernel::Kernel( - const 
std::shared_ptr& mod, - const std::string& mangled_name, - const std::vector& signature) - : mangled_name_(mangled_name), signature_(signature), mod_(mod) { -} - -void CudaModule::Kernel::Launch( - const Context& ctx, const std::vector& args, - uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z, - uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z, - uint32_t shared_mem) { - CHECK_EQ(ctx.dev_mask(), Context::kGPU) - << "CUDA Runtime compilation only supports Nvidia GPU."; - - auto mod = mod_; +CudaModule::Kernel::Kernel(const std::shared_ptr& mod, + const std::string& mangled_name, + const std::vector& signature) + : mangled_name_(mangled_name), signature_(signature), mod_(mod) {} + +void CudaModule::Kernel::Launch(const Context& ctx, + const std::vector& args, + uint32_t grid_dim_x, + uint32_t grid_dim_y, + uint32_t grid_dim_z, + uint32_t block_dim_x, + uint32_t block_dim_y, + uint32_t block_dim_z, + uint32_t shared_mem) { + CHECK_EQ(ctx.dev_mask(), Context::kGPU) << "CUDA Runtime compilation only supports Nvidia GPU."; + + auto mod = mod_; auto arg_types = signature(); CUfunction function; @@ -155,13 +150,14 @@ void CudaModule::Kernel::Launch( if (iter != func_.end()) { function = iter->second; } else { - function = mod_->GetFunction(mangled_name_, ctx); + function = mod_->GetFunction(mangled_name_, ctx); func_[ctx.dev_id] = function; } std::vector read_vars, write_vars; for (size_t i = 0; i < arg_types.size(); ++i) { - if (!arg_types[i].is_ndarray) continue; + if (!arg_types[i].is_ndarray) + continue; const auto& array = dmlc::get(args[i]); CHECK_EQ(array.dtype(), arg_types[i].dtype) << "The i-th argument is expected to be an NDArray of " @@ -175,33 +171,52 @@ void CudaModule::Kernel::Launch( } Engine::Get()->PushSync( - [function, mod, args, arg_types, grid_dim_x, grid_dim_y, grid_dim_z, - block_dim_x, block_dim_y, block_dim_z, shared_mem](RunContext rctx) { - std::vector p_args; - for (size_t i = 0; i < arg_types.size(); ++i) { - if (arg_types[i].is_ndarray) { - const auto& array = dmlc::get(args[i]); - p_args.push_back(reinterpret_cast(const_cast(&array.data().dptr_))); - } else { - MSHADOW_TYPE_SWITCH(arg_types[i].dtype, DType, { - const auto& number = dmlc::get(args[i]); - p_args.push_back(const_cast(&number)); - }); - } - } - - mshadow::Stream *s = rctx.get_stream(); - CUDA_DRIVER_CALL(cuLaunchKernel( - function, grid_dim_x, grid_dim_y, grid_dim_z, - block_dim_x, block_dim_y, block_dim_z, - shared_mem, s->stream_, - p_args.data(), nullptr)); - CUDA_CALL(cudaStreamSynchronize(s->stream_)); - }, ctx, read_vars, write_vars, FnProperty::kNormal, 0, - mangled_name_.c_str()); + [function, + mod, + args, + arg_types, + grid_dim_x, + grid_dim_y, + grid_dim_z, + block_dim_x, + block_dim_y, + block_dim_z, + shared_mem](RunContext rctx) { + std::vector p_args; + for (size_t i = 0; i < arg_types.size(); ++i) { + if (arg_types[i].is_ndarray) { + const auto& array = dmlc::get(args[i]); + p_args.push_back(reinterpret_cast(const_cast(&array.data().dptr_))); + } else { + MSHADOW_TYPE_SWITCH(arg_types[i].dtype, DType, { + const auto& number = dmlc::get(args[i]); + p_args.push_back(const_cast(&number)); + }); + } + } + + mshadow::Stream* s = rctx.get_stream(); + CUDA_DRIVER_CALL(cuLaunchKernel(function, + grid_dim_x, + grid_dim_y, + grid_dim_z, + block_dim_x, + block_dim_y, + block_dim_z, + shared_mem, + s->stream_, + p_args.data(), + nullptr)); + CUDA_CALL(cudaStreamSynchronize(s->stream_)); + }, + ctx, + read_vars, + write_vars, + FnProperty::kNormal, + 0, + 
mangled_name_.c_str()); } - } // namespace rtc } // namespace mxnet diff --git a/src/common/static_array.h b/src/common/static_array.h index 8d51967b172d..238560e1916e 100644 --- a/src/common/static_array.h +++ b/src/common/static_array.h @@ -36,7 +36,7 @@ namespace common { * \tparam T element type of the array, must be copyable between CPU and GPU * \tparam num number of elements in the array */ -template +template struct StaticArray { static const int kNum = num; @@ -47,7 +47,7 @@ struct StaticArray { /*! \brief constructor, fill in the array with the input value */ MSHADOW_XINLINE StaticArray(const T& val) { - #pragma unroll +#pragma unroll for (int i = 0; i < num; ++i) { this->array_[i] = val; } @@ -55,7 +55,7 @@ struct StaticArray { /*! \brief constuctor */ MSHADOW_XINLINE StaticArray(const StaticArray& sa) { - #pragma unroll +#pragma unroll for (int i = 0; i < num; ++i) { this->array_[i] = sa[i]; } diff --git a/src/common/tensor_inspector.h b/src/common/tensor_inspector.h index d1ff595f5c91..aa7feba80ebe 100644 --- a/src/common/tensor_inspector.h +++ b/src/common/tensor_inspector.h @@ -71,27 +71,27 @@ struct InspectorManager { enum CheckerType { NegativeChecker, // check if is negative PositiveChecker, // check if is positive - ZeroChecker, // check if is zero - NaNChecker, // check if is NaN, will always return false if DType is not a float type - InfChecker, // check if is infinity, will always return false if DType is not a float type + ZeroChecker, // check if is zero + NaNChecker, // check if is NaN, will always return false if DType is not a float type + InfChecker, // check if is infinity, will always return false if DType is not a float type PositiveInfChecker, // check if is positive infinity, // will always return false if DType is not a float type NegativeInfChecker, // check if is nagative infinity, // will always return false if DType is not a float type - FiniteChecker, // check if is finite, will always return false if DType is not a float type - NormalChecker, // check if is neither infinity nor NaN - AbnormalChecker, // chekck if is infinity or nan + FiniteChecker, // check if is finite, will always return false if DType is not a float type + NormalChecker, // check if is neither infinity nor NaN + AbnormalChecker, // chekck if is infinity or nan }; /** - * _______ _____ _ - * |__ __| |_ _| | | - * | | ___ _ __ ___ ___ _ __| | _ __ ___ _ __ ___ ___| |_ ___ _ __ + * _______ _____ _ + * |__ __| |_ _| | | + * | | ___ _ __ ___ ___ _ __| | _ __ ___ _ __ ___ ___| |_ ___ _ __ * | |/ _ \ '_ \/ __|/ _ \| '__| | | '_ \/ __| '_ \ / _ \/ __| __/ _ \| '__| - * | | __/ | | \__ \ (_) | | _| |_| | | \__ \ |_) | __/ (__| || (_) | | - * |_|\___|_| |_|___/\___/|_||_____|_| |_|___/ .__/ \___|\___|\__\___/|_| - * | | - * |_| + * | | __/ | | \__ \ (_) | | _| |_| | | \__ \ |_) | __/ (__| || (_) | | + * |_|\___|_| |_|___/\___/|_||_____|_| |_|___/ .__/ \___|\___|\__\___/|_| + * | | + * |_| */ /*! @@ -103,12 +103,12 @@ enum CheckerType { class TensorInspector { private: /*! - * \brief generate the tensor info, including data type and shape + * \brief generate the tensor info, including data type and shape * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to */ - template + template void tensor_info_to_string(StreamType* os) { const int dimension = tb_.ndim(); *os << "<" << infer_type_string(typeid(DType)) << " Tensor "; @@ -120,13 +120,13 @@ class TensorInspector { } /*! 
- * \brief output the tensor info, including data type and shape + * \brief output the tensor info, including data type and shape * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to * \param shape the shape of the tensor */ - template + template void tensor_info_to_string(StreamType* os, const std::vector& shape) { const int dimension = shape.size(); *os << "<" << infer_type_string(typeid(DType)) << " Tensor "; @@ -138,17 +138,16 @@ class TensorInspector { } /*! - * \brief output the tensor in a structured format + * \brief output the tensor in a structured format * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to */ - template + template void to_string_helper(StreamType* os) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) - .to_string_helper(os); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_).to_string_helper(os); return; } #endif // MXNET_USE_CUDA @@ -167,8 +166,8 @@ class TensorInspector { n += (i % off == 0); } if (n) { - *os << std::string(n, ']') << ", " << std::string(n, '['); - } else { + *os << std::string(n, ']') << ", " << std::string(n, '['); + } else { *os << ", "; } *os << tb_.dptr()[i]; @@ -178,13 +177,13 @@ class TensorInspector { } /*! - * \brief output the tensor in a structured format + * \brief output the tensor in a structured format * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to * \param dptr the data pointer */ - template + template void to_string_helper(StreamType* os, const DType* dptr) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -198,14 +197,14 @@ class TensorInspector { } /*! - * \brief output a part of the tensor in a structed format + * \brief output a part of the tensor in a structed format * \tparam DType the data type * \tparam StreamType the type of the stream object * \param os stream object to output to * \param sub_shape the sub-shape of the desired part of the tensor * \param offset the position of the first value of the desired part of the tensor */ - template + template void to_string_helper(StreamType* os, const std::vector& sub_shape, index_t offset) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -235,8 +234,8 @@ class TensorInspector { n += (i % off == 0); } if (n) { - *os << std::string(n, ']') << ", " << std::string(n, '['); - } else { + *os << std::string(n, ']') << ", " << std::string(n, '['); + } else { *os << ", "; } *os << dptr[i]; @@ -246,16 +245,17 @@ class TensorInspector { } /*! 
- * \brief helper function to calculate the sub_shape and offset for the desired part of the tensor, - * given its coordinates in the original tensor - * \param pos the coordinates of the desired part of the tensor - * \param sub_shape the sub-shape of the desired part of the tensor; calculated here - * \param offset the position of the first value of the desired part of the tensor; calculated here + * \brief helper function to calculate the sub_shape and offset for the desired part of the + * tensor, given its coordinates in the original tensor \param pos the coordinates of the desired + * part of the tensor \param sub_shape the sub-shape of the desired part of the tensor; calculated + * here \param offset the position of the first value of the desired part of the tensor; + * calculated here */ - void print_locator(const std::vector& pos, std::vector* sub_shape, - index_t* offset) { + void print_locator(const std::vector& pos, + std::vector* sub_shape, + index_t* offset) { const int dimension = tb_.ndim(); - const int sub_dim = dimension - pos.size(); + const int sub_dim = dimension - pos.size(); sub_shape->resize(sub_dim); index_t multiple = 1; for (size_t i = pos.size(), j = 0; i < static_cast(dimension); ++i, ++j) { @@ -263,7 +263,7 @@ class TensorInspector { multiple *= tb_.shape_[i]; } index_t sum = 0; - index_t m = 1; + index_t m = 1; for (index_t i = pos.size() - 1; i >= 0; --i) { sum += pos[i] * m; m *= tb_.shape_[i]; @@ -272,9 +272,8 @@ class TensorInspector { } /*! - * \brief parse the coordinate of the desired part of the tensor, given a string that represents that - * coordinate - * \param pos the coordinates of the desired part of the tensor, calculated here + * \brief parse the coordinate of the desired part of the tensor, given a string that represents + * that coordinate \param pos the coordinates of the desired part of the tensor, calculated here * \param str the string that represents the coordinate */ bool parse_position(std::vector* pos, const std::string& str) { @@ -303,7 +302,7 @@ class TensorInspector { * \tparam DType the data type * \param tag the name given to this call */ - template + template void interactive_print_helper(std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { @@ -317,16 +316,17 @@ class TensorInspector { while (!InspectorManager::get()->interactive_print_skip_all_) { std::cout << "----------Interactive Print----------" << std::endl; if (tag != "") { - std::cout << "Tag: " << tag << " Visit: " << - InspectorManager::get()->interactive_print_tag_counter_[tag] << std::endl; + std::cout << "Tag: " << tag + << " Visit: " << InspectorManager::get()->interactive_print_tag_counter_[tag] + << std::endl; } tensor_info_to_string(&std::cout); - std::cout << "To print a part of the tensor, " << - "please specify a position, seperated by \",\"" << std::endl; - std::cout << "\"e\" for the entire tensor, " << - "\"d\" to dump value to file, " << - "\"b\" to break, " << - "\"s\" to skip all: "; + std::cout << "To print a part of the tensor, " + << "please specify a position, seperated by \",\"" << std::endl; + std::cout << "\"e\" for the entire tensor, " + << "\"d\" to dump value to file, " + << "\"b\" to break, " + << "\"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -367,106 +367,84 @@ class TensorInspector { * \tparam DType the data type * \param ct the type of the checker */ - template + template std::function get_checker(CheckerType ct) { switch (ct) { case NegativeChecker: - return [] (DType x) { - return 
x < 0; - }; + return [](DType x) { return x < 0; }; case PositiveChecker: - return [] (DType x) { - return x > 0; - }; + return [](DType x) { return x > 0; }; case ZeroChecker: - return [] (DType x) { - return x == 0; - }; + return [](DType x) { return x == 0; }; case NaNChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x != x; - }; + return [](DType x) { return x != x; }; } else { - LOG(WARNING) << "NaNChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "NaNChecker only applies to float types. " + << "Lambda will always return false."; } break; case InfChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f; - }; + return [](DType x) { return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f; }; } else { - LOG(WARNING) << "InfChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "InfChecker only applies to float types. " + << "Lambda will always return false."; } break; case PositiveInfChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x == (DType)1.0 / 0.0f; - }; + return [](DType x) { return x == (DType)1.0 / 0.0f; }; } else { - LOG(WARNING) << "PositiveInfChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "PositiveInfChecker only applies to float types. " + << "Lambda will always return false."; } break; case NegativeInfChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x == -(DType)1.0 / 0.0f; - }; + return [](DType x) { return x == -(DType)1.0 / 0.0f; }; } else { - LOG(WARNING) << "NegativeInfChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "NegativeInfChecker only applies to float types. " + << "Lambda will always return false."; } break; case FiniteChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f; - }; + return [](DType x) { return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f; }; } else { - LOG(WARNING) << "FiniteChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "FiniteChecker only applies to float types. " + << "Lambda will always return false."; } break; case NormalChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f && - x == x; - }; + return + [](DType x) { return x != (DType)1.0 / 0.0f && x != -(DType)1.0 / 0.0f && x == x; }; } else { - LOG(WARNING) << "NormalChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "NormalChecker only applies to float types. " + << "Lambda will always return false."; } break; case AbnormalChecker: if (std::is_same::value || std::is_same::value || std::is_same::value) { - return [] (DType x) { - return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f || - x != x; - }; + return + [](DType x) { return x == (DType)1.0 / 0.0f || x == -(DType)1.0 / 0.0f || x != x; }; } else { - LOG(WARNING) << "AbnormalChecker only applies to float types. " << - "Lambda will always return false."; + LOG(WARNING) << "AbnormalChecker only applies to float types. 
" + << "Lambda will always return false."; } break; default: - return [] (DType x) { - return false; - }; + return [](DType x) { return false; }; } - return [] (DType x) {return false;}; + return [](DType x) { return false; }; } /*! @@ -493,9 +471,11 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - template + template void check_value_helper(std::vector>* ret, - const std::function& checker, bool interactive, std::string tag) { + const std::function& checker, + bool interactive, + std::string tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { return TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) @@ -506,13 +486,13 @@ class TensorInspector { std::stringstream ss; ss << "["; bool first_pass = true; - for (index_t i = 0; i (tb_.shape_.Size()); ++i) { + for (index_t i = 0; i < static_cast(tb_.shape_.Size()); ++i) { if (checker(tb_.dptr()[i])) { ++count; if (!first_pass) { - ss << ", "; + ss << ", "; } - first_pass = false; + first_pass = false; std::vector coords = index_to_coordinates(i); ss << "(" << coords[0]; for (size_t i = 1; i < coords.size(); ++i) { @@ -525,21 +505,22 @@ class TensorInspector { ss << "]" << std::endl; if (interactive) { std::lock_guard lock(InspectorManager::get()->mutex_); - InspectorManager::get()->check_value_tag_counter_[tag] += 1; + InspectorManager::get()->check_value_tag_counter_[tag] += 1; while (!InspectorManager::get()->check_value_skip_all_) { std::cout << "----------Value Check----------" << std::endl; tensor_info_to_string(&std::cout); if (tag != "") { - std::cout << "Tag: " << tag << " Visit: " << - InspectorManager::get()->check_value_tag_counter_[tag] << std::endl; + std::cout << "Tag: " << tag + << " Visit: " << InspectorManager::get()->check_value_tag_counter_[tag] + << std::endl; } std::cout << count << " value(s) found." << std::endl; - std::cout << "To print a part of the tensor," << - " please specify a position, seperated by \",\"" << std::endl; - std::cout << "\"e\" for the entire tensor, " << - "\"p\" to print the coordinates of the values found, " << - "\"b\" to break, " << - "\"s\" to skip all: "; + std::cout << "To print a part of the tensor," + << " please specify a position, seperated by \",\"" << std::endl; + std::cout << "\"e\" for the entire tensor, " + << "\"p\" to print the coordinates of the values found, " + << "\"b\" to break, " + << "\"s\" to skip all: "; std::string str; std::cin >> str; if (str == "b") { @@ -569,30 +550,42 @@ class TensorInspector { /*! * \brief infer the python type, given the c++ type - * \tparam ti the type info + * \tparam ti the type info */ inline char infer_type(const std::type_info& ti) { - if (ti == typeid(float)) return 'f'; - else if (ti == typeid(double)) return 'f'; - else if (ti == typeid(mshadow::half::half_t) ) return 'f'; - else if (ti == typeid(uint8_t)) return 'u'; - else if (ti == typeid(int32_t)) return 'i'; - else if (ti == typeid(int64_t)) return 'i'; + if (ti == typeid(float)) + return 'f'; + else if (ti == typeid(double)) + return 'f'; + else if (ti == typeid(mshadow::half::half_t)) + return 'f'; + else if (ti == typeid(uint8_t)) + return 'u'; + else if (ti == typeid(int32_t)) + return 'i'; + else if (ti == typeid(int64_t)) + return 'i'; else return '?'; } /*! 
* \brief infer the python type, given the c++ type - * \tparam ti the type info + * \tparam ti the type info */ inline std::string infer_type_string(const std::type_info& ti) { - if (ti == typeid(float)) return "float"; - else if (ti == typeid(double)) return "double"; - else if (ti == typeid(mshadow::half::half_t) ) return "mshasow::half::half_t"; - else if (ti == typeid(uint8_t)) return "uint8_t"; - else if (ti == typeid(int32_t)) return "int32_t"; - else if (ti == typeid(int64_t)) return "int64_t"; + if (ti == typeid(float)) + return "float"; + else if (ti == typeid(double)) + return "double"; + else if (ti == typeid(mshadow::half::half_t)) + return "mshasow::half::half_t"; + else if (ti == typeid(uint8_t)) + return "uint8_t"; + else if (ti == typeid(int32_t)) + return "int32_t"; + else if (ti == typeid(int64_t)) + return "int64_t"; else return "unknown tyoe"; } @@ -609,7 +602,7 @@ class TensorInspector { * \brief generate the header following npy 1.0 format * \tparam DType the data type */ - template + template std::string get_header() { const int dimension = tb_.ndim(); std::string dict; @@ -624,7 +617,7 @@ class TensorInspector { dict += std::to_string(tb_.shape_[i]); } if (dimension == 1) { - dict += ","; + dict += ","; } dict += ")} "; int padding_size = 64 - ((10 + dict.size()) % 64); @@ -647,7 +640,7 @@ class TensorInspector { * \param header the header of the file * \param filename the file name */ - template + template void write_npy(const std::string& header, const std::string& filename) { std::ofstream file; file.exceptions(std::ofstream::failbit | std::ofstream::badbit); @@ -668,29 +661,28 @@ class TensorInspector { * \tparam DType the data type * \param tag the name given to this call */ - template + template void dump_to_file_helper(const std::string& tag) { #if MXNET_USE_CUDA if (tb_.dev_mask() == gpu::kDevMask) { - TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_) - .dump_to_file_helper(tag); + TensorInspector(test::CAccessAsCPU(ctx_, tb_, false)(), ctx_).dump_to_file_helper(tag); return; } #endif // MXNET_USE_CUDA std::string header = get_header(); InspectorManager::get()->dump_to_file_tag_counter_[tag] += 1; - const int visit = InspectorManager::get()->dump_to_file_tag_counter_[tag]; + const int visit = InspectorManager::get()->dump_to_file_tag_counter_[tag]; std::string filename = tag + "_" + std::to_string(visit) + ".npy"; write_npy(header, filename); } - /*! + /*! * \brief validate that the shape */ inline void validate_shape() { const int dimension = tb_.ndim(); - CHECK(dimension > 0) << "Tensor Inspector does not support empty tensors " << - "or tensors of unknow shape."; + CHECK(dimension > 0) << "Tensor Inspector does not support empty tensors " + << "or tensors of unknow shape."; for (int i = 0; i < dimension; ++i) { CHECK(tb_.shape_[i] != 0) << "Invalid tensor shape: shape_[" << i << "] is 0"; } @@ -702,7 +694,7 @@ class TensorInspector { const RunContext& ctx_; public: - /*! + /*! 
* \brief construct from Tensor object * \tparam Device the device the tensor resides in * \tparam dimension the dimension of the tensor @@ -710,9 +702,9 @@ class TensorInspector { * \param ts the source tensor object * \param ctx the run context of the tensor */ - template - TensorInspector(const mshadow::Tensor& ts, const RunContext& ctx): - tb_(ts), ctx_(ctx) { + template + TensorInspector(const mshadow::Tensor& ts, const RunContext& ctx) + : tb_(ts), ctx_(ctx) { validate_shape(); } @@ -721,8 +713,7 @@ class TensorInspector { * \param tb the source tblob object * \param ctx the run context of the tensor */ - TensorInspector(const TBlob& tb, const RunContext& ctx): - tb_(tb), ctx_(ctx) { + TensorInspector(const TBlob& tb, const RunContext& ctx) : tb_(tb), ctx_(ctx) { validate_shape(); } @@ -731,8 +722,7 @@ class TensorInspector { * \param arr the source ndarray object * \param ctx the run context of the tensor */ - TensorInspector(const NDArray& arr, const RunContext& ctx): - tb_(arr.data()), ctx_(ctx) { + TensorInspector(const NDArray& arr, const RunContext& ctx) : tb_(arr.data()), ctx_(ctx) { validate_shape(); } @@ -748,9 +738,7 @@ class TensorInspector { */ std::string to_string() { std::stringstream ss; - MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - to_string_helper(&ss); - }); + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { to_string_helper(&ss); }); return ss.str(); } @@ -759,9 +747,7 @@ class TensorInspector { * \param tag the name given to this call */ void interactive_print(std::string tag = "") { - MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - interactive_print_helper(tag); - }); + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { interactive_print_helper(tag); }); } /*! @@ -772,9 +758,10 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - template + template std::vector> check_value(const ValueChecker& checker, - bool interactive = false, std::string tag = "") { + bool interactive = false, + std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { check_value_helper(&ret, checker, ret, interactive, tag); @@ -789,8 +776,9 @@ class TensorInspector { * \param interactive wherether to allow the user to interactively check the coordinates * \param tag the name given to this call */ - std::vector> check_value(CheckerType ct, bool interactive = false, - std::string tag = "") { + std::vector> check_value(CheckerType ct, + bool interactive = false, + std::string tag = "") { std::vector> ret; MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { check_value_helper(&ret, get_checker(ct), interactive, tag); @@ -803,9 +791,7 @@ class TensorInspector { * \param tag the name given to this call */ void dump_to_file(std::string tag) { - MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { - dump_to_file_helper(tag); - }); + MSHADOW_TYPE_SWITCH(tb_.type_flag_, DType, { dump_to_file_helper(tag); }); } }; diff --git a/src/common/utils.cc b/src/common/utils.cc index 67f1f3137c9f..f400093cc9b5 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -29,14 +29,16 @@ namespace mxnet { namespace common { -template<> -void CheckFormatWrapper(const RunContext &rctx, const NDArray &input, - const TBlob &err_cpu, const bool full_check) { +template <> +void CheckFormatWrapper(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check) { CheckFormatImpl(rctx, input, err_cpu, full_check); } -template<> -void 
SparseRetainOpForwardRspWrapper(mshadow::Stream *s, +template <> +void SparseRetainOpForwardRspWrapper(mshadow::Stream* s, const NDArray& input_nd, const TBlob& idx_data, const OpReqType req, @@ -44,22 +46,20 @@ void SparseRetainOpForwardRspWrapper(mshadow::Stream *s, mxnet::op::SparseRetainOpForwardRspImpl(s, input_nd, idx_data, req, output_nd); } -template<> -void CastStorageDispatch(const OpContext& ctx, - const NDArray& input, - const NDArray& output) { +template <> +void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output) { mxnet::op::CastStorageComputeImpl(ctx, input, output); } void ExecuteMonInputCallback( - const nnvm::IndexedGraph &idx, const std::vector &state_arrays, - size_t nid, const std::function - &monitor_callback) { - static const auto &flist_inputs = - nnvm::Op::GetAttr("FListInputNames"); + const nnvm::IndexedGraph& idx, + const std::vector& state_arrays, + size_t nid, + const std::function& monitor_callback) { + static const auto& flist_inputs = nnvm::Op::GetAttr("FListInputNames"); std::vector input_names; - const nnvm::IndexedGraph::Node &inode = idx[nid]; - const nnvm::Node *node = inode.source; + const nnvm::IndexedGraph::Node& inode = idx[nid]; + const nnvm::Node* node = inode.source; if (flist_inputs.count(node->op())) { input_names = flist_inputs[node->op()](node->attrs); } else { @@ -69,26 +69,25 @@ void ExecuteMonInputCallback( } for (size_t i = 0; i < node->num_inputs(); ++i) { - const nnvm::NodeEntry &input = node->inputs[i]; + const nnvm::NodeEntry& input = node->inputs[i]; if (state_arrays[idx.entry_id(input)]->is_none()) { continue; } - NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(input)]); + NDArray* cpy = new NDArray(*state_arrays[idx.entry_id(input)]); std::string name = inode.source->attrs.name + "_" + input_names[i]; - monitor_callback(name.c_str(), inode.source->op()->name.c_str(), - reinterpret_cast(cpy)); + monitor_callback(name.c_str(), inode.source->op()->name.c_str(), reinterpret_cast(cpy)); } } void ExecuteMonOutputCallback( - const nnvm::IndexedGraph &idx, const std::vector &state_arrays, - size_t nid, const std::function - &monitor_callback) { - static const auto &flist_outputs = - nnvm::Op::GetAttr("FListOutputNames"); + const nnvm::IndexedGraph& idx, + const std::vector& state_arrays, + size_t nid, + const std::function& monitor_callback) { + static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); std::vector output_names; - const nnvm::IndexedGraph::Node &inode = idx[nid]; - const nnvm::Node *node = inode.source; + const nnvm::IndexedGraph::Node& inode = idx[nid]; + const nnvm::Node* node = inode.source; if (flist_outputs.count(node->op())) { output_names = flist_outputs[node->op()](node->attrs); } else { @@ -101,10 +100,9 @@ void ExecuteMonOutputCallback( if (state_arrays[idx.entry_id(nid, i)]->is_none()) { continue; } - NDArray *cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]); + NDArray* cpy = new NDArray(*state_arrays[idx.entry_id(nid, i)]); std::string name = inode.source->attrs.name + "_" + output_names[i]; - monitor_callback(name.c_str(), inode.source->op()->name.c_str(), - reinterpret_cast(cpy)); + monitor_callback(name.c_str(), inode.source->op()->name.c_str(), reinterpret_cast(cpy)); } } diff --git a/src/common/utils.cu b/src/common/utils.cu index 0937d7aa5145..e8f65d531096 100644 --- a/src/common/utils.cu +++ b/src/common/utils.cu @@ -29,14 +29,16 @@ namespace mxnet { namespace common { -template<> -void CheckFormatWrapper(const RunContext &rctx, const NDArray 
&input, - const TBlob &err_cpu, const bool full_check) { +template <> +void CheckFormatWrapper(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check) { CheckFormatImpl(rctx, input, err_cpu, full_check); } -template<> -void SparseRetainOpForwardRspWrapper(mshadow::Stream *s, +template <> +void SparseRetainOpForwardRspWrapper(mshadow::Stream* s, const NDArray& input_nd, const TBlob& idx_data, const OpReqType req, @@ -44,10 +46,8 @@ void SparseRetainOpForwardRspWrapper(mshadow::Stream *s, mxnet::op::SparseRetainOpForwardRspImpl(s, input_nd, idx_data, req, output_nd); } -template<> -void CastStorageDispatch(const OpContext& ctx, - const NDArray& input, - const NDArray& output) { +template <> +void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output) { mxnet::op::CastStorageComputeImpl(ctx, input, output); } diff --git a/src/common/utils.h b/src/common/utils.h index 40376e993a0b..b3eb98f3f320 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -59,25 +59,30 @@ #include #endif - namespace mxnet { namespace common { #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) -inline size_t current_process_id() { return ::GetCurrentProcessId(); } +inline size_t current_process_id() { + return ::GetCurrentProcessId(); +} #else -inline size_t current_process_id() { return getpid(); } +inline size_t current_process_id() { + return getpid(); +} #endif /*! * \brief IndPtr should be non-negative, in non-decreasing order, start with 0 * and end with value equal with size of indices. */ struct csr_indptr_check { - template - MSHADOW_XINLINE static void Map(int i, DType* out, const IType* indptr, - const nnvm::dim_t end, const nnvm::dim_t idx_size) { - if (indptr[i+1] < 0 || indptr[i+1] < indptr[i] || - (i == 0 && indptr[i] != 0) || + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const IType* indptr, + const nnvm::dim_t end, + const nnvm::dim_t idx_size) { + if (indptr[i + 1] < 0 || indptr[i + 1] < indptr[i] || (i == 0 && indptr[i] != 0) || (i == end - 1 && indptr[end] != idx_size)) *out = kCSRIndPtrErr; } @@ -88,12 +93,14 @@ struct csr_indptr_check { * and in ascending order per row. 
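As a plain-CPU illustration of the invariants enforced by the CSR format-check kernels in this hunk (csr_indptr_check above and csr_idx_check just below), here is a hypothetical helper, not the mshadow kernels themselves: indptr must start at 0, be non-decreasing and end at the number of stored values, and each row's column indices must lie in [0, ncols) in strictly ascending order.

// Hypothetical CPU-only check of the CSR invariants -- not the mshadow kernels above.
#include <cstdint>
#include <iostream>
#include <vector>

bool csr_is_valid(const std::vector<int64_t>& indptr,
                  const std::vector<int64_t>& idx,
                  int64_t ncols) {
  if (indptr.empty() || indptr.front() != 0 ||
      indptr.back() != static_cast<int64_t>(idx.size()))
    return false;
  for (size_t i = 0; i + 1 < indptr.size(); ++i) {
    if (indptr[i + 1] < indptr[i])  // rows may be empty, but indptr never decreases
      return false;
    for (int64_t j = indptr[i]; j < indptr[i + 1]; ++j) {
      if (idx[j] < 0 || idx[j] >= ncols)
        return false;
      if (j + 1 < indptr[i + 1] && idx[j] >= idx[j + 1])  // strictly ascending per row
        return false;
    }
  }
  return true;
}

int main() {
  // 2x3 matrix with stored entries at (0,1), (0,2) and (1,0).
  std::cout << csr_is_valid({0, 2, 3}, {1, 2, 0}, 3) << "\n";  // 1 (valid)
  std::cout << csr_is_valid({0, 2, 3}, {2, 1, 0}, 3) << "\n";  // 0 (row 0 not ascending)
  return 0;
}

The kernels report the same conditions through the kCSRIndPtrErr / kCSRIdxErr flags written to the error blob rather than a boolean return value.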
*/ struct csr_idx_check { - template - MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx, - const RType* indptr, const nnvm::dim_t ncols) { - for (RType j = indptr[i]; j < indptr[i+1]; j++) { - if (idx[j] >= ncols || idx[j] < 0 || - (j < indptr[i+1] - 1 && idx[j] >= idx[j+1])) { + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const IType* idx, + const RType* indptr, + const nnvm::dim_t ncols) { + for (RType j = indptr[i]; j < indptr[i + 1]; j++) { + if (idx[j] >= ncols || idx[j] < 0 || (j < indptr[i + 1] - 1 && idx[j] >= idx[j + 1])) { *out = kCSRIdxErr; break; } @@ -106,18 +113,22 @@ struct csr_idx_check { * less than the size of first dimension and in ascending order */ struct rsp_idx_check { - template - MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx, - const nnvm::dim_t end, const nnvm::dim_t nrows) { - if ((i < end && idx[i+1] <= idx[i]) - || idx[i] < 0 || idx[i] >= nrows) + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const IType* idx, + const nnvm::dim_t end, + const nnvm::dim_t nrows) { + if ((i < end && idx[i + 1] <= idx[i]) || idx[i] < 0 || idx[i] >= nrows) *out = kRSPIdxErr; } }; -template -void CheckFormatWrapper(const RunContext &rctx, const NDArray &input, - const TBlob &err_cpu, const bool full_check); +template +void CheckFormatWrapper(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check); /*! * \brief Check the validity of CSRNDArray. @@ -127,46 +138,50 @@ void CheckFormatWrapper(const RunContext &rctx, const NDArray &input, * \param full_check If true, rigorous check, O(N) operations, * otherwise basic check, O(1) operations. */ -template -void CheckFormatCSRImpl(const RunContext &rctx, const NDArray &input, - const TBlob &err_cpu, const bool full_check) { +template +void CheckFormatCSRImpl(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check) { using namespace op::mxnet_op; - CHECK_EQ(input.storage_type(), kCSRStorage) - << "CheckFormatCSRImpl is for CSRNDArray"; - const mxnet::TShape shape = input.shape(); - const mxnet::TShape idx_shape = input.aux_shape(csr::kIdx); - const mxnet::TShape indptr_shape = input.aux_shape(csr::kIndPtr); + CHECK_EQ(input.storage_type(), kCSRStorage) << "CheckFormatCSRImpl is for CSRNDArray"; + const mxnet::TShape shape = input.shape(); + const mxnet::TShape idx_shape = input.aux_shape(csr::kIdx); + const mxnet::TShape indptr_shape = input.aux_shape(csr::kIndPtr); const mxnet::TShape storage_shape = input.storage_shape(); if ((shape.ndim() != 2) || (idx_shape.ndim() != 1 || indptr_shape.ndim() != 1 || storage_shape.ndim() != 1) || - (indptr_shape[0] != shape[0] + 1) || - (idx_shape[0] != storage_shape[0])) { - MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, { - DType* err = err_cpu.dptr(); - *err = kCSRShapeErr; - }); - return; + (indptr_shape[0] != shape[0] + 1) || (idx_shape[0] != storage_shape[0])) { + MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, { + DType* err = err_cpu.dptr(); + *err = kCSRShapeErr; + }); + return; } if (full_check) { MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, { MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIndPtr), RType, { MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIdx), IType, { - mshadow::Stream *s = rctx.get_stream(); - NDArray ret_xpu = NDArray(mshadow::Shape1(1), - rctx.get_ctx(), false, err_cpu.type_flag_); - TBlob val_xpu = ret_xpu.data(); + mshadow::Stream* s = rctx.get_stream(); + NDArray ret_xpu = NDArray(mshadow::Shape1(1), 
rctx.get_ctx(), false, err_cpu.type_flag_); + TBlob val_xpu = ret_xpu.data(); Kernel, xpu>::Launch(s, val_xpu.Size(), val_xpu.dptr()); - Kernel::Launch(s, indptr_shape[0] - 1, val_xpu.dptr(), - input.aux_data(csr::kIndPtr).dptr(), - indptr_shape[0] - 1, idx_shape[0]); + Kernel::Launch(s, + indptr_shape[0] - 1, + val_xpu.dptr(), + input.aux_data(csr::kIndPtr).dptr(), + indptr_shape[0] - 1, + idx_shape[0]); // no need to check indices if indices are empty if (idx_shape[0] != 0) { - Kernel::Launch(s, indptr_shape[0] - 1, val_xpu.dptr(), - input.aux_data(csr::kIdx).dptr(), - input.aux_data(csr::kIndPtr).dptr(), shape[1]); + Kernel::Launch(s, + indptr_shape[0] - 1, + val_xpu.dptr(), + input.aux_data(csr::kIdx).dptr(), + input.aux_data(csr::kIndPtr).dptr(), + shape[1]); } - mshadow::Copy(err_cpu.get(), - val_xpu.get(s), s); + mshadow::Copy(err_cpu.get(), val_xpu.get(s), s); }); }); }); @@ -181,17 +196,18 @@ void CheckFormatCSRImpl(const RunContext &rctx, const NDArray &input, * \param full_check If true, rigorous check, O(N) operations, * otherwise basic check, O(1) operations. */ -template -void CheckFormatRSPImpl(const RunContext &rctx, const NDArray &input, - const TBlob &err_cpu, const bool full_check) { +template +void CheckFormatRSPImpl(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check) { using namespace op::mxnet_op; - CHECK_EQ(input.storage_type(), kRowSparseStorage) - << "CheckFormatRSPImpl is for RSPNDArray"; + CHECK_EQ(input.storage_type(), kRowSparseStorage) << "CheckFormatRSPImpl is for RSPNDArray"; const mxnet::TShape idx_shape = input.aux_shape(rowsparse::kIdx); if (idx_shape[0] != input.storage_shape()[0]) { MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, { DType* err = err_cpu.dptr(); - *err = kRSPShapeErr; + *err = kRSPShapeErr; }); return; } @@ -201,25 +217,28 @@ void CheckFormatRSPImpl(const RunContext &rctx, const NDArray &input, if (full_check) { MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, { MSHADOW_IDX_TYPE_SWITCH(input.aux_type(rowsparse::kIdx), IType, { - mshadow::Stream *s = rctx.get_stream(); - NDArray ret_xpu = NDArray(mshadow::Shape1(1), - rctx.get_ctx(), false, err_cpu.type_flag_); - TBlob val_xpu = ret_xpu.data(); + mshadow::Stream* s = rctx.get_stream(); + NDArray ret_xpu = NDArray(mshadow::Shape1(1), rctx.get_ctx(), false, err_cpu.type_flag_); + TBlob val_xpu = ret_xpu.data(); Kernel, xpu>::Launch(s, val_xpu.Size(), val_xpu.dptr()); - Kernel::Launch(s, idx_shape[0], - val_xpu.dptr(), input.aux_data(rowsparse::kIdx).dptr(), - idx_shape[0] - 1, input.shape()[0]); - mshadow::Copy(err_cpu.get(), - val_xpu.get(s), s); + Kernel::Launch(s, + idx_shape[0], + val_xpu.dptr(), + input.aux_data(rowsparse::kIdx).dptr(), + idx_shape[0] - 1, + input.shape()[0]); + mshadow::Copy(err_cpu.get(), val_xpu.get(s), s); }); }); } } -template -void CheckFormatImpl(const RunContext &rctx, const NDArray &input, - const TBlob &err_cpu, const bool full_check) { +template +void CheckFormatImpl(const RunContext& rctx, + const NDArray& input, + const TBlob& err_cpu, + const bool full_check) { int stype = input.storage_type(); if (stype == kCSRStorage) { CheckFormatCSRImpl(rctx, input, err_cpu, full_check); @@ -235,8 +254,8 @@ void CheckFormatImpl(const RunContext &rctx, const NDArray &input, /*! \brief Pick rows specified by user input index array from a row sparse ndarray * and save them in the output sparse ndarray. 
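The comment above, together with the SparseRetainOpForwardRspWrapper declaration just below it, refers to the sparse-retain operation; the following CPU-only sketch shows its semantics under the assumption that row indices are sorted ascending (as rsp_idx_check requires). The RowSparse struct and sparse_retain function are illustrative stand-ins, not SparseRetainOpForwardRspImpl.

// Illustrative sketch only -- not SparseRetainOpForwardRspImpl. Rows named in the
// index array are copied through when stored in the input; unstored rows stay
// absent, i.e. implicitly zero.
#include <cstdint>
#include <iostream>
#include <vector>

struct RowSparse {
  std::vector<int64_t> row_idx;            // indices of the stored rows, ascending
  std::vector<std::vector<float>> values;  // one dense row per stored index
};

RowSparse sparse_retain(const RowSparse& in, const std::vector<int64_t>& keep) {
  RowSparse out;
  size_t i = 0;
  for (int64_t r : keep) {                 // keep is assumed sorted ascending
    while (i < in.row_idx.size() && in.row_idx[i] < r)  // advance to candidate row
      ++i;
    if (i < in.row_idx.size() && in.row_idx[i] == r) {  // row is actually stored
      out.row_idx.push_back(r);
      out.values.push_back(in.values[i]);
    }
  }
  return out;
}

int main() {
  RowSparse x{{1, 4, 7}, {{1.f, 1.f}, {4.f, 4.f}, {7.f, 7.f}}};
  RowSparse y = sparse_retain(x, {0, 4, 7});  // row 0 is not stored, so it is skipped
  for (size_t k = 0; k < y.row_idx.size(); ++k)
    std::cout << "row " << y.row_idx[k] << ": " << y.values[k][0] << "\n";
  return 0;
}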
*/ -template -void SparseRetainOpForwardRspWrapper(mshadow::Stream *s, +template +void SparseRetainOpForwardRspWrapper(mshadow::Stream* s, const NDArray& input_nd, const TBlob& idx_data, const OpReqType req, @@ -244,17 +263,17 @@ void SparseRetainOpForwardRspWrapper(mshadow::Stream *s, /* \brief Casts tensor storage type to the new type. */ -template +template void CastStorageDispatch(const OpContext& ctx, const NDArray& input, const NDArray& output); /*! \brief returns true if all storage types in `vstorage` are the same as target `stype`. * false is returned for empty inputs. */ -inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage, - const NDArrayStorageType stype) { +inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage, const NDArrayStorageType stype) { if (!vstorage.empty()) { for (const auto& i : vstorage) { - if (i != stype) return false; + if (i != stype) + return false; } return true; } @@ -268,7 +287,7 @@ inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage, inline bool ContainsOnlyStorage(const StorageTypeVector& vstorage, const NDArrayStorageType stype1, const NDArrayStorageType stype2, - bool *has_both) { + bool* has_both) { if (has_both) { *has_both = false; } @@ -313,7 +332,7 @@ inline bool ContainsOnlyStorage(const std::vector& ndarrays, inline bool ContainsOnlyStorage(const std::vector& ndarrays, const NDArrayStorageType stype1, const NDArrayStorageType stype2, - bool *has_both) { + bool* has_both) { if (has_both) { *has_both = false; } @@ -355,8 +374,7 @@ inline bool ContainsStorageType(const std::vector& ndarrays, /*! \brief returns true if any storage type `ndstype` in `ndstypes` * is the same as the target `stype`. false is returned for empty inputs. */ -inline bool ContainsStorageType(const std::vector& ndstypes, - const NDArrayStorageType stype) { +inline bool ContainsStorageType(const std::vector& ndstypes, const NDArrayStorageType stype) { if (!ndstypes.empty()) { for (const auto& ndstype : ndstypes) { if (ndstype == stype) { @@ -384,7 +402,6 @@ inline std::string dispatch_mode_string(const DispatchMode x) { return "unknown"; } - /*! \brief get string representation of storage_type */ inline std::string stype_string(const int x) { switch (x) { @@ -428,8 +445,7 @@ inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs, const std::vector& in_attrs, const std::vector& out_attrs) { std::ostringstream os; - os << "operator = " << attrs.op->name - << "\ninput storage types = ["; + os << "operator = " << attrs.op->name << "\ninput storage types = ["; for (const int attr : in_attrs) { os << stype_string(attr) << ", "; } @@ -483,25 +499,31 @@ inline void LogStorageFallback(const nnvm::NodeAttrs& attrs, const std::vector* in_attrs, const std::vector* out_attrs) { static bool log = dmlc::GetEnv("MXNET_STORAGE_FALLBACK_LOG_VERBOSE", true); - if (!log) return; + if (!log) + return; const std::string op_str = operator_stype_string(attrs, dev_mask, *in_attrs, *out_attrs); std::ostringstream os; - const char* warning = "\nThe operator with default storage type will be dispatched " - "for execution. You're seeing this warning message because the operator above is unable " - "to process the given ndarrays with specified storage types, context and parameter. " - "Temporary dense ndarrays are generated in order to execute the operator. " - "This does not affect the correctness of the programme. 
" - "You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to " - "0 to suppress this warning."; + const char* warning = + "\nThe operator with default storage type will be dispatched " + "for execution. You're seeing this warning message because the operator above is unable " + "to process the given ndarrays with specified storage types, context and parameter. " + "Temporary dense ndarrays are generated in order to execute the operator. " + "This does not affect the correctness of the programme. " + "You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to " + "0 to suppress this warning."; os << "\nStorage type fallback detected:\n" << op_str << warning; LogOnce(os.str()); #if MXNET_USE_ONEDNN == 1 - if (!MKLDNNEnvSet()) common::LogOnce("MXNET_ONEDNN_ENABLED flag is off. " - "You can re-enable by setting MXNET_ONEDNN_ENABLED=1"); - if (GetMKLDNNCacheSize() != -1) common::LogOnce("MXNET_ONEDNN_CACHE_NUM is set." - "Should only be set if " - "your model has variable input shapes, " - "as cache size may grow unbounded"); + if (!MKLDNNEnvSet()) + common::LogOnce( + "MXNET_ONEDNN_ENABLED flag is off. " + "You can re-enable by setting MXNET_ONEDNN_ENABLED=1"); + if (GetMKLDNNCacheSize() != -1) + common::LogOnce( + "MXNET_ONEDNN_CACHE_NUM is set." + "Should only be set if " + "your model has variable input shapes, " + "as cache size may grow unbounded"); #endif } @@ -519,10 +541,10 @@ inline int GetExecNumMatchColor() { return std::min(num_match_color, GetNumThreadsPerGPU()); } -template +template V ParallelAccumulate(const T* a, const int n, V start) { V sum = start; -#pragma omp parallel for reduction(+:sum) +#pragma omp parallel for reduction(+ : sum) for (int i = 0; i < n; ++i) { sum += a[i]; } @@ -536,16 +558,15 @@ V ParallelAccumulate(const T* a, const int n, V start) { * Use the interface ParallelSort instead. * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h */ -template -void ParallelSortHelper(RandomIt first, size_t len, - size_t grainsize, const Compare& comp) { +template +void ParallelSortHelper(RandomIt first, size_t len, size_t grainsize, const Compare& comp) { if (len < grainsize) { - std::sort(first, first+len, comp); + std::sort(first, first + len, comp); } else { - std::thread thr(ParallelSortHelper, first, len/2, grainsize, comp); - ParallelSortHelper(first+len/2, len - len/2, grainsize, comp); + std::thread thr(ParallelSortHelper, first, len / 2, grainsize, comp); + ParallelSortHelper(first + len / 2, len - len / 2, grainsize, comp); thr.join(); - std::inplace_merge(first, first+len/2, first+len, comp); + std::inplace_merge(first, first + len / 2, first + len, comp); } } @@ -558,10 +579,10 @@ void ParallelSortHelper(RandomIt first, size_t len, * to sort each half range. * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h */ -template +template void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare comp) { - const auto num = std::distance(first, last); - size_t grainsize = std::max(num / num_threads + 5, static_cast(1024*16)); + const auto num = std::distance(first, last); + size_t grainsize = std::max(num / num_threads + 5, static_cast(1024 * 16)); ParallelSortHelper(first, num, grainsize, comp); } @@ -574,10 +595,10 @@ void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare com * to sort each half range. 
* Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h */ -template +template void ParallelSort(RandomIt first, RandomIt last, size_t num_threads) { - ParallelSort(first, last, num_threads, - std::less::value_type>()); + ParallelSort( + first, last, num_threads, std::less::value_type>()); } /*! @@ -667,9 +688,8 @@ typename helper::UniqueIf::UnknownBound MakeUnique(size_t n) { template typename helper::UniqueIf::KnownBound MakeUnique(Args&&... args) = delete; -template -FCompType GetFCompute(const nnvm::Op* op, const std::string& name, - const Context& ctx) { +template +FCompType GetFCompute(const nnvm::Op* op, const std::string& name, const Context& ctx) { static auto& fcompute_cpu = nnvm::Op::GetAttr(name + ""); static auto& fcompute_gpu = nnvm::Op::GetAttr(name + ""); @@ -688,9 +708,8 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name, */ template constexpr size_t MaxIntegerValue() { - return std::is_integral::value ? - std::numeric_limits::max(): - size_t(2) << (std::numeric_limits::digits - 1); + return std::is_integral::value ? std::numeric_limits::max() + : size_t(2) << (std::numeric_limits::digits - 1); } template <> @@ -705,21 +724,25 @@ constexpr size_t MaxIntegerValue() { MSHADOW_XINLINE int ilog2ul(size_t a) { int k = 1; - while (a >>= 1) ++k; + while (a >>= 1) + ++k; return k; } MSHADOW_XINLINE int ilog2ui(unsigned int a) { int k = 1; - while (a >>= 1) ++k; + while (a >>= 1) + ++k; return k; } /*! * \brief Return an NDArray of all zeros. */ -inline NDArray InitZeros(const NDArrayStorageType stype, const mxnet::TShape &shape, - const Context &ctx, const int dtype) { +inline NDArray InitZeros(const NDArrayStorageType stype, + const mxnet::TShape& shape, + const Context& ctx, + const int dtype) { // NDArray with default storage if (stype == kDefaultStorage) { NDArray ret(shape, ctx, false, dtype); @@ -734,10 +757,10 @@ inline NDArray InitZeros(const NDArrayStorageType stype, const mxnet::TShape &sh * \brief Helper to add a NDArray of zeros to a std::vector. */ inline void EmplaceBackZeros(const NDArrayStorageType stype, - const mxnet::TShape &shape, - const Context &ctx, + const mxnet::TShape& shape, + const Context& ctx, const int dtype, - std::vector *vec) { + std::vector* vec) { // NDArray with default storage if (stype == kDefaultStorage) { vec->emplace_back(shape, ctx, false, dtype); @@ -748,15 +771,14 @@ inline void EmplaceBackZeros(const NDArrayStorageType stype, } } - /*! * \brief parallelize copy by OpenMP. */ -template +template inline void ParallelCopy(DType* dst, const DType* src, index_t size) { static index_t copy_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000); if (size >= copy_block_size) { - #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t i = 0; i < size; ++i) { dst[i] = src[i]; } @@ -773,11 +795,11 @@ inline void ParallelCopy(DType* dst, const DType* src, index_t size) { /*! 
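The ParallelSort helpers reformatted above follow a fork/join pattern: spawn a std::thread for the left half, sort the right half on the calling thread, join, then std::inplace_merge the two sorted halves. A self-contained version of the same idiom is sketched below; the function name and grainsize value are illustrative, not MXNet's.

// Minimal standalone version of the fork/join sort idiom -- illustrative only.
// Requires a random-access iterator; build with -pthread (or your platform's equivalent).
#include <algorithm>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

template <typename It, typename Cmp>
void parallel_sort(It first, It last, size_t grainsize, Cmp cmp) {
  const size_t len = static_cast<size_t>(std::distance(first, last));
  if (len < grainsize) {
    std::sort(first, last, cmp);  // small ranges: sort serially
    return;
  }
  It mid = first + len / 2;
  // Sort the left half on a new thread, the right half on this one, then merge.
  std::thread left([=] { parallel_sort(first, mid, grainsize, cmp); });
  parallel_sort(mid, last, grainsize, cmp);
  left.join();
  std::inplace_merge(first, mid, last, cmp);
}

int main() {
  std::vector<int> v = {5, 3, 9, 1, 7, 2, 8, 6, 4, 0};
  parallel_sort(v.begin(), v.end(), 4, std::less<int>());
  for (int x : v)
    std::cout << x << ' ';
  std::cout << '\n';
  return 0;
}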
* \breif parallelize add by OpenMP */ -template +template inline void ParallelAdd(DType* dst, const DType* src, index_t size) { static index_t add_block_size = dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000); if (size >= add_block_size) { - #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (index_t i = 0; i < size; ++i) { dst[i] += src[i]; } @@ -807,12 +829,12 @@ inline void ParallelAdd(DType* dst, const DType* src, index_t size) { * compatible shapes. */ inline void ConvertToNumpyShape(mxnet::TShape* shape) { - if (shape->ndim() == 0) { // legacy shape ndim = 0 means unknown + if (shape->ndim() == 0) { // legacy shape ndim = 0 means unknown *shape = mxnet::TShape(); // unknown shape ndim = -1 } else { for (int j = 0; j < shape->ndim(); ++j) { if ((*shape)[j] == 0) { // legacy shape dim_size = 0 means unknown - (*shape)[j] = -1; // unknown dim size = -1 + (*shape)[j] = -1; // unknown dim size = -1 } } } @@ -846,26 +868,27 @@ inline void ConvertToLegacyShape(mxnet::ShapeVector* shapes) { } } void ExecuteMonInputCallback( - const nnvm::IndexedGraph &idx, const std::vector &state_arrays, - size_t nid, const std::function - &monitor_callback); + const nnvm::IndexedGraph& idx, + const std::vector& state_arrays, + size_t nid, + const std::function& monitor_callback); void ExecuteMonOutputCallback( - const nnvm::IndexedGraph &idx, const std::vector &state_arrays, - size_t nid, const std::function - &monitor_callback); + const nnvm::IndexedGraph& idx, + const std::vector& state_arrays, + size_t nid, + const std::function& monitor_callback); inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) { // convert negative axes to positive values - const int ndim = src.ndim(); + const int ndim = src.ndim(); mxnet::TShape axes = src; for (int i = 0; i < ndim; ++i) { if (axes[i] < 0) { axes[i] += ndim; } - CHECK(axes[i] >= 0 && axes[i] < ndim) << "axes[" << i << "]=" - << axes[i] << " exceeds the range [" - << 0 << ", " << ndim << ")"; + CHECK(axes[i] >= 0 && axes[i] < ndim) + << "axes[" << i << "]=" << axes[i] << " exceeds the range [" << 0 << ", " << ndim << ")"; } return axes; } @@ -875,12 +898,13 @@ inline bool is_float(const int dtype) { } inline bool is_int(const int dtype) { - return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || - dtype == mshadow::kInt32 || dtype == mshadow::kInt64; + return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || dtype == mshadow::kInt32 || + dtype == mshadow::kInt64; } inline int get_more_precise_type(const int type1, const int type2) { - if (type1 == type2) return type1; + if (type1 == type2) + return type1; if (is_float(type1) && is_float(type2)) { if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) { return mshadow::kFloat64; @@ -900,7 +924,7 @@ inline int get_more_precise_type(const int type1, const int type2) { } CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || (type1 == mshadow::kInt8 && type2 == mshadow::kUint8))) - << "1 is UInt8 and 1 is Int8 should not get here"; + << "1 is UInt8 and 1 is Int8 should not get here"; if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) { return mshadow::kUint8; } @@ -915,13 +939,12 @@ inline int np_binary_out_infer_type(const int type1, const int type2) { return get_more_precise_type(type1, type2); } -inline const std::string -NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) { +inline const std::string 
NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) { // obtain the profiler scope name, if assigned previously std::string profiler_scope = MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR; const std::unordered_map& node_attrs_dict = attrs.dict; - const std::unordered_map::const_iterator - profiler_scope_iter = node_attrs_dict.find("__profiler_scope__"); + const std::unordered_map::const_iterator profiler_scope_iter = + node_attrs_dict.find("__profiler_scope__"); if (profiler_scope_iter != node_attrs_dict.end()) { profiler_scope = profiler_scope_iter->second; } @@ -929,16 +952,13 @@ NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) { } inline int GetDefaultDtype() { - return Imperative::Get()->is_np_default_dtype() ? - mshadow::kFloat64 : - mshadow::kFloat32; + return Imperative::Get()->is_np_default_dtype() ? mshadow::kFloat64 : mshadow::kFloat32; } inline int GetDefaultDtype(int dtype) { - if (dtype != -1) return dtype; - return Imperative::Get()->is_np_default_dtype() ? - mshadow::kFloat64 : - mshadow::kFloat32; + if (dtype != -1) + return dtype; + return Imperative::Get()->is_np_default_dtype() ? mshadow::kFloat64 : mshadow::kFloat32; } struct MShadowTypeInfo { @@ -946,11 +966,10 @@ struct MShadowTypeInfo { int size; int acc_size; - MShadowTypeInfo(const std::string name, const int size, const int acc_size) : - name(std::move(name)), size(size), acc_size(acc_size) {} + MShadowTypeInfo(const std::string name, const int size, const int acc_size) + : name(std::move(name)), size(size), acc_size(acc_size) {} - MShadowTypeInfo(const std::string name, const int size) : - MShadowTypeInfo(name, size, size) {} + MShadowTypeInfo(const std::string name, const int size) : MShadowTypeInfo(name, size, size) {} }; MShadowTypeInfo mshadow_type_info(const int type_flag); @@ -976,7 +995,6 @@ inline void AlignedMemFree(void* ptr) { #endif } - inline index_t div_round(const index_t a, const index_t b) { return (a + b - 1) / b; } @@ -986,7 +1004,7 @@ inline bool IsPower2(size_t N) { } inline size_t RoundToPower2(size_t N) { - size_t ret = 1; + size_t ret = 1; size_t copyN = N; while (N >= 2) { ret *= 2; diff --git a/src/engine/engine.cc b/src/engine/engine.cc index a33f0b2c1442..5725fa4fd7df 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -30,13 +30,14 @@ namespace mxnet { namespace engine { inline Engine* CreateEngine() { - const char *type = getenv("MXNET_ENGINE_TYPE"); + const char* type = getenv("MXNET_ENGINE_TYPE"); const bool default_engine = (type == nullptr); - if (type == nullptr) type = "ThreadedEnginePerDevice"; + if (type == nullptr) + type = "ThreadedEnginePerDevice"; std::string stype = type; - Engine *ret = nullptr; - #if MXNET_PREDICT_ONLY == 0 + Engine* ret = nullptr; +#if MXNET_PREDICT_ONLY == 0 if (stype == "NaiveEngine") { ret = CreateNaiveEngine(); } else if (stype == "ThreadedEngine") { @@ -44,9 +45,9 @@ inline Engine* CreateEngine() { } else if (stype == "ThreadedEnginePerDevice") { ret = CreateThreadedEnginePerDevice(); } - #else +#else ret = CreateNaiveEngine(); - #endif +#endif if (ret == nullptr) { LOG(FATAL) << "Cannot find Engine " << type; @@ -64,7 +65,7 @@ std::shared_ptr Engine::_GetSharedRef() { } Engine* Engine::Get() { - static Engine *inst = _GetSharedRef().get(); + static Engine* inst = _GetSharedRef().get(); return inst; } } // namespace mxnet diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h index f15141f4e7a2..d83eea4ab662 100644 --- a/src/engine/engine_impl.h +++ b/src/engine/engine_impl.h @@ -50,8 +50,7 @@ struct Opr { // 
implementation of the inline functions template inline T* Var::Cast() { - static_assert(std::is_base_of::value, - "must inherit `mxnet::engine::Var`"); + static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Var`"); #if ENGINE_DEBUG return dynamic_cast(this); #else @@ -61,8 +60,7 @@ inline T* Var::Cast() { template inline T* Opr::Cast() { - static_assert(std::is_base_of::value, - "must inherit `mxnet::engine::Opr`"); + static_assert(std::is_base_of::value, "must inherit `mxnet::engine::Opr`"); #if ENGINE_DEBUG return dynamic_cast(this); #else @@ -75,12 +73,12 @@ static constexpr std::size_t kMaxNumGPUs = 16; // predeclare factory function for each type of engine /*! \return NaiveEngine instance */ -Engine *CreateNaiveEngine(); +Engine* CreateNaiveEngine(); #if MXNET_PREDICT_ONLY == 0 /*! \return ThreadedEnginePooled instance */ -Engine *CreateThreadedEnginePooled(); +Engine* CreateThreadedEnginePooled(); /*! \return ThreadedEnginePerDevie instance */ -Engine *CreateThreadedEnginePerDevice(); +Engine* CreateThreadedEnginePerDevice(); #endif } // namespace engine } // namespace mxnet diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 95528a934c6c..e1ab240bbde4 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -40,15 +40,13 @@ namespace engine { * \brief var used in Naive Engine for tracking the version * of the objects it is associated with. */ -class NaiveVar final - : public Var, public common::ObjectPoolAllocatable { +class NaiveVar final : public Var, public common::ObjectPoolAllocatable { public: inline static NaiveVar* CastFromBase(Var* ptr) { return ptr->Cast(); } }; // class NaiveVar - // implement naive engine class NaiveEngine final : public Engine { public: @@ -87,11 +85,9 @@ class NaiveEngine final : public Engine { ~NaiveEngine() override = default; #endif - void Stop() override { - } + void Stop() override {} - void Start() override { - } + void Start() override {} // new variables VarHandle NewVariable() override { @@ -101,86 +97,84 @@ class NaiveEngine final : public Engine { OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal, + FnProperty prop = FnProperty::kNormal, const char* opr_name = nullptr, - bool wait = false) override { - NaiveOpr *opr = new NaiveOpr(); - opr->fn = fn; - opr->const_vars = const_vars; + bool wait = false) override { + NaiveOpr* opr = new NaiveOpr(); + opr->fn = fn; + opr->const_vars = const_vars; opr->mutable_vars = mutable_vars; - opr->prop = prop; - opr->opr_name = opr_name ? std::string(opr_name) : std::string(); + opr->prop = prop; + opr->opr_name = opr_name ? 
std::string(opr_name) : std::string(); return opr; } void DeleteOperator(OprHandle op) override { - NaiveOpr *opr = op->Cast(); + NaiveOpr* opr = op->Cast(); delete opr; } void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override { - profiler::Profiler *profiler = profiler::Profiler::Get(); - NaiveOpr *opr = op->Cast(); + profiler::Profiler* profiler = profiler::Profiler::Get(); + NaiveOpr* opr = op->Cast(); opr->profiling = profiling && profiler->IsProfiling(profiler::Profiler::kSymbolic); - this->PushAsync([&](RunContext ctx, CallbackOnComplete on_complete) { - if (opr->profiling) { - std::unique_ptr attrs; - if (profiler->AggregateEnabled()) { - attrs = std::make_unique(); + this->PushAsync( + [&](RunContext ctx, CallbackOnComplete on_complete) { + if (opr->profiling) { + std::unique_ptr attrs; + if (profiler->AggregateEnabled()) { + attrs = std::make_unique(); + } + opr->opr_profile = + std::make_unique(opr->opr_name.c_str(), attrs.release()); + opr->opr_profile->startForDevice(exec_ctx.dev_type, exec_ctx.dev_id); } - opr->opr_profile = std::make_unique(opr->opr_name.c_str(), - attrs.release()); - opr->opr_profile->startForDevice(exec_ctx.dev_type, exec_ctx.dev_id); - } - opr->fn(ctx, on_complete); - if (opr->profiling) { - opr->opr_profile->stop(); - } - }, - exec_ctx, - opr->const_vars, - opr->mutable_vars, - opr->prop, - priority, - opr->opr_name.c_str()); + opr->fn(ctx, on_complete); + if (opr->profiling) { + opr->opr_profile->stop(); + } + }, + exec_ctx, + opr->const_vars, + opr->mutable_vars, + opr->prop, + priority, + opr->opr_name.c_str()); } -/*! - * \brief NaiveEngine's PushAsync was intentionally synchronous. - * User should not make any assumption about execution order when using async interface of any engine. - */ + /*! + * \brief NaiveEngine's PushAsync was intentionally synchronous. + * User should not make any assumption about execution order when using async interface of any + * engine. + */ void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal, - int priority = 0, + FnProperty prop = FnProperty::kNormal, + int priority = 0, const char* opr_name = nullptr, - bool wait = false) override { + bool wait = false) override { std::promise promise; - std::future future = promise.get_future(); - CallbackOnComplete callback = CreateCallback( - NaiveEngine::OnComplete, &promise); - profiler::Profiler *profiler = profiler::Profiler::Get(); - auto opr_deleter = [this](NaiveOpr* p) { - this->DeleteOperator(p); - }; + std::future future = promise.get_future(); + CallbackOnComplete callback = CreateCallback(NaiveEngine::OnComplete, &promise); + profiler::Profiler* profiler = profiler::Profiler::Get(); + auto opr_deleter = [this](NaiveOpr* p) { this->DeleteOperator(p); }; std::unique_ptr opr(nullptr, opr_deleter); const bool profiling = opr_name && profiler->IsProfiling(profiler::Profiler::kImperative); // GenerateDisplayName() will return a pointer to the correct name of the operator - const char* display_name = profiling ? - profiler::CustomOpProfiler::Get()->GenerateDisplayName(opr_name) : - opr_name; + const char* display_name = + profiling ? 
profiler::CustomOpProfiler::Get()->GenerateDisplayName(opr_name) : opr_name; if (profiling) { - opr.reset(NewOperator(exec_fun, const_vars, mutable_vars, - prop, display_name)->Cast()); + opr.reset( + NewOperator(exec_fun, const_vars, mutable_vars, prop, display_name)->Cast()); opr->profiling = profiling; std::unique_ptr attrs; if (profiler->AggregateEnabled()) { attrs = std::make_unique(); } - opr->opr_profile = std::make_unique(opr->opr_name.c_str(), - attrs.release()); + opr->opr_profile = + std::make_unique(opr->opr_name.c_str(), attrs.release()); opr->opr_profile->startForDevice(exec_ctx.dev_type, exec_ctx.dev_id); } if (exec_ctx.dev_mask() == gpu::kDevMask) { @@ -193,7 +187,7 @@ class NaiveEngine final : public Engine { aux_streams_.resize(dev_id + 1, nullptr); } if (streams_[dev_id] == nullptr) { - streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0, dev_id); + streams_[dev_id] = mshadow::NewStream(true, MXNET_USE_CUDNN != 0, dev_id); aux_streams_[dev_id] = new GPUAuxStream(streams_[dev_id]); } exec_fun(RunContext{exec_ctx, streams_[dev_id], aux_streams_[dev_id], false}, callback); @@ -215,21 +209,25 @@ class NaiveEngine final : public Engine { void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override { NaiveVar* naive_var = NaiveVar::CastFromBase(var); - this->PushAsync([delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable { - delete_fn(ctx); - NaiveVar::Delete(naive_var); - on_complete(); - }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, "DeleteVariable"); + this->PushAsync( + [delete_fn, naive_var](RunContext ctx, CallbackOnComplete on_complete) mutable { + delete_fn(ctx); + NaiveVar::Delete(naive_var); + on_complete(); + }, + exec_ctx, + {}, + {var}, + FnProperty::kDeleteVar, + 0, + "DeleteVariable"); } - void WaitForVar(VarHandle var) override { - } + void WaitForVar(VarHandle var) override {} - void WaitForAll() override { - } + void WaitForAll() override {} - void Throw(VarHandle var) override { - } + void Throw(VarHandle var) override {} void NotifyShutdown() override { shutdown_phase_.store(true); @@ -237,8 +235,7 @@ class NaiveEngine final : public Engine { private: // callback to oncomplete - static void OnComplete(Engine *engine, void *param, - const dmlc::Error* error) { + static void OnComplete(Engine* engine, void* param, const dmlc::Error* error) { static_cast*>(param)->set_value(); } /*! \brief whether it is during shutdown phase*/ @@ -251,16 +248,17 @@ class NaiveEngine final : public Engine { // GPU auxiliary streams std::vector aux_streams_; #endif -/*! - * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early - * See also #309 (https://github.com/dmlc/mxnet/issues/309) and similar fix in threaded_engine.h. - * Without this, segfaults seen on CentOS7 in test_operator_gpu.py:test_convolution_multiple_streams - */ + /*! + * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early + * See also #309 (https://github.com/dmlc/mxnet/issues/309) and similar fix in threaded_engine.h. 
+ * Without this, segfaults seen on CentOS7 in + * test_operator_gpu.py:test_convolution_multiple_streams + */ std::shared_ptr > objpool_opr_ref_; std::shared_ptr > objpool_var_ref_; }; // class NaiveEngine -Engine *CreateNaiveEngine() { +Engine* CreateNaiveEngine() { return new NaiveEngine(); } diff --git a/src/engine/openmp.cc b/src/engine/openmp.cc index 0d31f71aa9a3..c031b089b61a 100644 --- a/src/engine/openmp.cc +++ b/src/engine/openmp.cc @@ -29,17 +29,16 @@ namespace engine { #define ARCH_IS_INTEL_X86 #endif -static inline bool is_env_set(const char *var) { +static inline bool is_env_set(const char* var) { return dmlc::GetEnv(var, INT_MIN) != INT_MIN; } -OpenMP *OpenMP::Get() { +OpenMP* OpenMP::Get() { static OpenMP openMP; return &openMP; } -OpenMP::OpenMP() - : omp_num_threads_set_in_environment_(is_env_set("OMP_NUM_THREADS")) { +OpenMP::OpenMP() : omp_num_threads_set_in_environment_(is_env_set("OMP_NUM_THREADS")) { #ifdef _OPENMP initialize_process(); const int max = dmlc::GetEnv("MXNET_OMP_MAX_THREADS", INT_MIN); @@ -57,12 +56,12 @@ OpenMP::OpenMP() } } #else - enabled_ = false; + enabled_ = false; omp_thread_max_ = 1; #endif } -void OpenMP:: initialize_process() { +void OpenMP::initialize_process() { #ifdef _OPENMP omp_get_num_procs(); // will force OpenMP to be initialized #endif @@ -116,8 +115,7 @@ int OpenMP::GetRecommendedOMPThreadCount(bool exclude_reserved) const { #endif } -OpenMP *__init_omp__ = OpenMP::Get(); +OpenMP* __init_omp__ = OpenMP::Get(); } // namespace engine } // namespace mxnet - diff --git a/src/engine/openmp.h b/src/engine/openmp.h index 94b83e3aa25b..83f22a4bf42c 100644 --- a/src/engine/openmp.h +++ b/src/engine/openmp.h @@ -42,19 +42,27 @@ class OpenMP { * \brief Set whether clients of this class receive pro-OMP behavior guidance * \param enabled Set to 'true' if this class should provide OMP behavior */ - void set_enabled(bool enabled) { enabled_ = enabled; } - bool enabled() const { return enabled_; } + void set_enabled(bool enabled) { + enabled_ = enabled; + } + bool enabled() const { + return enabled_; + } /*! * \brief Set maximum number of threads to be used in an OMP region * \param thread_max Maximum number of threads to be used in an OMP region */ - void set_thread_max(int thread_max) { omp_thread_max_ = thread_max; } + void set_thread_max(int thread_max) { + omp_thread_max_ = thread_max; + } /*! * \brief Maximum number of threads to be used in an OMP region * \return Maximum number of threads */ - int thread_max() const { return omp_thread_max_; } + int thread_max() const { + return omp_thread_max_; + } /*! * \brief Reserve cores to be excluded from OMP regions @@ -65,7 +73,9 @@ class OpenMP { * \brief Get number of cores to be excluded from OMP regions * \return Number of cores to be excluded from OMP regions */ - int reserve_cores() const { return reserve_cores_; } + int reserve_cores() const { + return reserve_cores_; + } /*! * \brief Call at the beginning of a worker thread's life. This will set the omp_num_threads @@ -85,7 +95,7 @@ class OpenMP { * \brief Get the OpenMP object's singleton pointer * \return Singleton OpenMP object pointer */ - static OpenMP *Get(); + static OpenMP* Get(); private: /*! 
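(Aside, not part of the patch: the openmp.cc/openmp.h hunks above only reformat the engine's OpenMP helper, but the pattern they support is the one used by ParallelCopy/ParallelAdd earlier in this diff: consult engine::OpenMP::Get()->GetRecommendedOMPThreadCount() for the thread count and only parallelize once the element count crosses the MXNET_CPU_PARALLEL_SIZE threshold. The standalone sketch below mirrors that pattern under stated assumptions; ParallelAddSketch and ParallelThreshold are illustrative names, and std::getenv plus omp_get_max_threads() stand in for dmlc::GetEnv and the MXNet OpenMP singleton. It assumes compilation with -fopenmp.)

// Illustrative sketch only: threshold-gated OpenMP elementwise add,
// mirroring the ParallelAdd/ParallelCopy pattern reformatted in this diff.
#include <cstdint>
#include <cstdlib>
#include <vector>
#include <omp.h>

static int64_t ParallelThreshold() {
  // Stand-in for dmlc::GetEnv("MXNET_CPU_PARALLEL_SIZE", 200000).
  const char* v = std::getenv("MXNET_CPU_PARALLEL_SIZE");
  return v ? std::atoll(v) : 200000;
}

template <typename DType>
void ParallelAddSketch(DType* dst, const DType* src, int64_t size) {
  static const int64_t block = ParallelThreshold();
  if (size >= block) {
    // omp_get_max_threads() stands in for
    // engine::OpenMP::Get()->GetRecommendedOMPThreadCount().
#pragma omp parallel for num_threads(omp_get_max_threads())
    for (int64_t i = 0; i < size; ++i) {
      dst[i] += src[i];
    }
  } else {
    // Small inputs stay single-threaded to avoid OpenMP launch overhead.
    for (int64_t i = 0; i < size; ++i) {
      dst[i] += src[i];
    }
  }
}

int main() {
  std::vector<float> a(1 << 20, 1.0f), b(1 << 20, 2.0f);
  ParallelAddSketch(a.data(), b.data(), static_cast<int64_t>(a.size()));
  return 0;
}

(End of aside; the patch resumes below.)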
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index da1e4bc436ab..342e47519557 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -50,13 +50,12 @@ class StreamManager { RunContext GetRunContext(Context const& ctx); RunContext GetIORunContext(Context const& ctx); void Finalize(); + private: std::mutex mutex_; #if MXNET_USE_CUDA - std::array*, kStreams>, kNumGpus> - gpu_streams_; - std::array, kNumGpus> - gpu_aux_streams_; + std::array*, kStreams>, kNumGpus> gpu_streams_; + std::array, kNumGpus> gpu_aux_streams_; std::array*, kNumGpus> gpu_io_streams_; std::array gpu_cnt_; #endif // MXNET_USE_CUDA @@ -64,8 +63,7 @@ class StreamManager { }; // class StreamManager template -RunContext StreamManager::GetRunContext( - Context const& ctx) { +RunContext StreamManager::GetRunContext(Context const& ctx) { RunContext ret; switch (ctx.dev_mask()) { case cpu::kDevMask: @@ -85,12 +83,12 @@ RunContext StreamManager::GetRunContext( int idx = 0; for (auto&& aux_stream : gpu_aux_streams_.at(ctx.dev_id)) { auto primary_stream = gpu_streams_.at(ctx.dev_id).at(idx++); - aux_stream = new GPUAuxStream(primary_stream); + aux_stream = new GPUAuxStream(primary_stream); } counter = 0; } use_counter = counter; - counter = (counter + 1) % kStreams; + counter = (counter + 1) % kStreams; } ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter), @@ -100,16 +98,15 @@ RunContext StreamManager::GetRunContext( #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA - default: - LOG(FATAL) << "Not Reached"; + default: + LOG(FATAL) << "Not Reached"; } } return ret; } template -RunContext StreamManager::GetIORunContext( - Context const& ctx) { +RunContext StreamManager::GetIORunContext(Context const& ctx) { RunContext ret; switch (ctx.dev_mask()) { case cpu::kDevMask: @@ -129,8 +126,8 @@ RunContext StreamManager::GetIORunContext( #else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA - default: - LOG(FATAL) << "Not Reached"; + default: + LOG(FATAL) << "Not Reached"; } } return ret; diff --git a/src/engine/thread_pool.h b/src/engine/thread_pool.h index a48ac1bb1555..580f22820843 100644 --- a/src/engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -43,14 +43,13 @@ class ThreadPool { /*! \brief Signal event upon destruction, even for exceptions (RAII) */ struct SetReadyOnDestroy { explicit inline SetReadyOnDestroy(const std::shared_ptr& event) - : event_(event) { - } + : event_(event) {} inline ~SetReadyOnDestroy() { if (event_) { event_->signal(); } } - std::shared_ptr event_; + std::shared_ptr event_; }; /*! @@ -58,8 +57,7 @@ class ThreadPool { * \param size size of the thread pool. * \param func the function to run on the thread pool. */ - explicit ThreadPool(size_t size, std::function func) - : worker_threads_(size) { + explicit ThreadPool(size_t size, std::function func) : worker_threads_(size) { CHECK_GT(size, 0); for (auto& i : worker_threads_) { i = std::thread(func); diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 3eda2c8712f7..0e206c89a201 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -63,9 +63,9 @@ inline void ThreadedVar::AppendReadDependency(OprBlock* opr_block) { assert(head_->trigger == nullptr); assert(head_->write == false); // append things to next. 
- head_->next = new_var_block; + head_->next = new_var_block; head_->trigger = opr_block; - head_ = new_var_block; + head_ = new_var_block; } } @@ -77,9 +77,9 @@ inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { assert(head_->trigger == nullptr); assert(head_->write == false); // attach to head. - head_->next = new_var_block; + head_->next = new_var_block; head_->trigger = opr_block; - head_->write = true; + head_->write = true; // check if it is ready to write if (pending_write_ == nullptr) { @@ -99,7 +99,7 @@ inline void ThreadedVar::AppendWriteDependency(OprBlock* opr_block) { template inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { - OprBlock *trigger = nullptr; + OprBlock* trigger = nullptr; { // this is lock scope std::lock_guard lock{mutex_}; @@ -108,7 +108,7 @@ inline void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { if (--num_pending_reads_ == 0) { if (pending_write_ != nullptr) { // STATE CHANGE - trigger = pending_write_->trigger; + trigger = pending_write_->trigger; num_pending_reads_ = kWriteTriggered; } } @@ -135,7 +135,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // really delete if (to_delete_) { - VersionedVarBlock *head = pending_write_->next; + VersionedVarBlock* head = pending_write_->next; VersionedVarBlock::Delete(pending_write_); assert(head_ == head); VersionedVarBlock::Delete(head); @@ -147,8 +147,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { end_of_read_chain = old_pending_write->next; // reset to 0 pending reads num_pending_reads_ = 0; - while (end_of_read_chain != head_ && - end_of_read_chain->write == false) { + while (end_of_read_chain != head_ && end_of_read_chain->write == false) { ++num_pending_reads_; end_of_read_chain = end_of_read_chain->next; } @@ -161,7 +160,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { if (num_pending_reads_ == 0) { // mark write as already activated in this var num_pending_reads_ = kWriteTriggered; - trigger_write = end_of_read_chain->trigger; + trigger_write = end_of_read_chain->trigger; } } } @@ -171,7 +170,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // The linked list \in [old_pending_write, end_of_read_chain) // is already detached from this Var. // So it is safe to modify these - VersionedVarBlock *cur_head = old_pending_write->next; + VersionedVarBlock* cur_head = old_pending_write->next; VersionedVarBlock::Delete(old_pending_write); // dispatch all the events while (cur_head != end_of_read_chain) { @@ -179,7 +178,7 @@ inline bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { dispatcher(cur_head->trigger); } auto prev = cur_head; - cur_head = cur_head->next; + cur_head = cur_head->next; assert(cur_head != nullptr); VersionedVarBlock::Delete(prev); } @@ -209,24 +208,25 @@ ThreadedVar* ThreadedEngine::NewVariable() { return ThreadedVar::New(VersionedVarBlock::New()); } -ThreadedOpr* ThreadedEngine::NewOperator( - ThreadedEngine::AsyncFn fn, - std::vector const& const_vars, - std::vector const& mutable_vars, - FnProperty prop, - const char* opr_name, - bool wait) { - auto ret = ThreadedOpr::New(); +ThreadedOpr* ThreadedEngine::NewOperator(ThreadedEngine::AsyncFn fn, + std::vector const& const_vars, + std::vector const& mutable_vars, + FnProperty prop, + const char* opr_name, + bool wait) { + auto ret = ThreadedOpr::New(); ret->opr_name = opr_name ? 
std::string(opr_name) : std::string(); - ret->fn = std::move(fn); - ret->prop = prop; + ret->fn = std::move(fn); + ret->prop = prop; ret->const_vars.resize(const_vars.size()); ret->mutable_vars.resize(mutable_vars.size()); ret->wait = wait; - std::transform(const_vars.begin(), const_vars.end(), - ret->const_vars.begin(), ThreadedVar::CastFromBase); - std::transform(mutable_vars.begin(), mutable_vars.end(), - ret->mutable_vars.begin(), ThreadedVar::CastFromBase); + std::transform( + const_vars.begin(), const_vars.end(), ret->const_vars.begin(), ThreadedVar::CastFromBase); + std::transform(mutable_vars.begin(), + mutable_vars.end(), + ret->mutable_vars.begin(), + ThreadedVar::CastFromBase); if (ENGINE_DEBUG != 0) { CheckDuplicate(const_vars, mutable_vars); } @@ -236,9 +236,9 @@ ThreadedOpr* ThreadedEngine::NewOperator( void ThreadedEngine::CheckDuplicate(std::vector const& const_vars, std::vector const& mutable_vars) { // Check for duplicates. - auto use = const_vars; - auto mutate = mutable_vars; - const size_t use_size = use.size(); + auto use = const_vars; + auto mutate = mutable_vars; + const size_t use_size = use.size(); const size_t mutate_size = mutate.size(); std::sort(use.begin(), use.end()); std::sort(mutate.begin(), mutate.end()); @@ -261,8 +261,7 @@ void ThreadedEngine::CheckDuplicate(std::vector const& const_vars, break; } if (mutate.at(j) == use.at(i)) { - LOG(FATAL) - << "duplicate items found between `const_vars` and `mutable_vars`"; + LOG(FATAL) << "duplicate items found between `const_vars` and `mutable_vars`"; } } } @@ -270,19 +269,20 @@ void ThreadedEngine::CheckDuplicate(std::vector const& const_vars, void ThreadedEngine::DeleteOperator(OprHandle op) { ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); std::vector deps; - deps.reserve(threaded_opr->const_vars.size() + - threaded_opr->mutable_vars.size()); - deps.insert(deps.end(), - threaded_opr->const_vars.begin(), - threaded_opr->const_vars.end()); - deps.insert(deps.end(), - threaded_opr->mutable_vars.begin(), - threaded_opr->mutable_vars.end()); - this->PushAsync([threaded_opr](RunContext, CallbackOnComplete on_complete) { - ThreadedOpr::Delete(threaded_opr); - on_complete(); - }, Context::CPU(), {}, deps, FnProperty::kDeleteVar, 0, - "DeleteOperator"); + deps.reserve(threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size()); + deps.insert(deps.end(), threaded_opr->const_vars.begin(), threaded_opr->const_vars.end()); + deps.insert(deps.end(), threaded_opr->mutable_vars.begin(), threaded_opr->mutable_vars.end()); + this->PushAsync( + [threaded_opr](RunContext, CallbackOnComplete on_complete) { + ThreadedOpr::Delete(threaded_opr); + on_complete(); + }, + Context::CPU(), + {}, + deps, + FnProperty::kDeleteVar, + 0, + "DeleteOperator"); } void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) { @@ -293,13 +293,12 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool pro profiler::CustomOpProfiler::Get()->GenerateDisplayName(threaded_opr->opr_name.c_str()); } OprBlock* opr_block = OprBlock::New(); - opr_block->opr = threaded_opr; + opr_block->opr = threaded_opr; - opr_block->wait.store(static_cast( - threaded_opr->const_vars.size() + - threaded_opr->mutable_vars.size() + 1)); - opr_block->ctx = exec_ctx; - opr_block->priority = priority; + opr_block->wait.store( + static_cast(threaded_opr->const_vars.size() + threaded_opr->mutable_vars.size() + 1)); + opr_block->ctx = exec_ctx; + opr_block->priority = priority; opr_block->profiling = profiling; 
++pending_; // Add read dependencies. @@ -315,7 +314,8 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool pro } } -void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, +void ThreadedEngine::PushAsync(AsyncFn fn, + Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop, @@ -332,48 +332,59 @@ void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, } CHECK_LT(exec_ctx.dev_id, device_count_) << "Invalid GPU Id: " << exec_ctx.dev_id - << ", Valid device id should be less than device_count: " - << device_count_; + << ", Valid device id should be less than device_count: " << device_count_; } #endif const bool profiling = profiler_->IsProfiling(profiler::Profiler::kImperative); - ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, - prop, opr_name, wait); - opr->temporary = true; + ThreadedOpr* opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name, wait); + opr->temporary = true; Push(opr, exec_ctx, priority, profiling); } -void ThreadedEngine::PushSync(SyncFn exec_fn, Context exec_ctx, +void ThreadedEngine::PushSync(SyncFn exec_fn, + Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop, int priority, const char* opr_name) { if (!bulk_size() || prop != FnProperty::kNormal || priority) { - this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { - exec_fn(ctx); - on_complete(); - }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name); + this->PushAsync( + [exec_fn](RunContext ctx, CallbackOnComplete on_complete) { + exec_fn(ctx); + on_complete(); + }, + exec_ctx, + const_vars, + mutable_vars, + prop, + priority, + opr_name); return; } const BulkStatus& bulk_status = *BulkStatusStore::Get(); - if (bulk_status.count && exec_ctx != bulk_status.ctx) BulkFlush(); + if (bulk_status.count && exec_ctx != bulk_status.ctx) + BulkFlush(); BulkAppend(exec_fn, exec_ctx, const_vars, mutable_vars); } -void ThreadedEngine::DeleteVariable(SyncFn delete_fn, - Context exec_ctx, - VarHandle var) { +void ThreadedEngine::DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) { ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); - this->PushAsync([delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) { - // Mark variable as orphan, - // so during `ThreadedEngine::OnComplete` it could be recycled. - threaded_var->SetToDelete(); - delete_fn(ctx); - on_complete(); - }, exec_ctx, {}, {var}, FnProperty::kDeleteVar, 0, - "DeleteVariable"); + this->PushAsync( + [delete_fn, threaded_var](RunContext ctx, CallbackOnComplete on_complete) { + // Mark variable as orphan, + // so during `ThreadedEngine::OnComplete` it could be recycled. 
+ threaded_var->SetToDelete(); + delete_fn(ctx); + on_complete(); + }, + exec_ctx, + {}, + {var}, + FnProperty::kDeleteVar, + 0, + "DeleteVariable"); } void ThreadedEngine::WaitForVar(VarHandle var) { @@ -388,26 +399,31 @@ void ThreadedEngine::WaitForVar(VarHandle var) { debug_wait_var_ = threaded_var; } std::atomic done{false}; - this->PushAsync([this, &done](RunContext, CallbackOnComplete on_complete) { - if (engine_info_) { - LOG(INFO) << "Sync is executed"; - } - { - std::unique_lock lock{finished_m_}; - done.store(true); - } - finished_cv_.notify_all(); - if (engine_info_) { - LOG(INFO) << "Sync is notified"; - } - on_complete(); - }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0, - "WaitForVar", true); + this->PushAsync( + [this, &done](RunContext, CallbackOnComplete on_complete) { + if (engine_info_) { + LOG(INFO) << "Sync is executed"; + } + { + std::unique_lock lock{finished_m_}; + done.store(true); + } + finished_cv_.notify_all(); + if (engine_info_) { + LOG(INFO) << "Sync is notified"; + } + on_complete(); + }, + Context::CPU(), + {var}, + {}, + FnProperty::kNormal, + 0, + "WaitForVar", + true); { std::unique_lock lock{finished_m_}; - finished_cv_.wait(lock, [this, &done]() { - return done.load() || kill_.load(); - }); + finished_cv_.wait(lock, [this, &done]() { return done.load() || kill_.load(); }); } ThrowException(threaded_var); @@ -416,9 +432,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) { void ThreadedEngine::WaitForAll() { BulkFlush(); std::unique_lock lock{finished_m_}; - finished_cv_.wait(lock, [this]() { - return pending_.load() == 0 || kill_.load(); - }); + finished_cv_.wait(lock, [this]() { return pending_.load() == 0 || kill_.load(); }); std::exception_ptr exception_to_rethrow = nullptr; if (!global_exception_refs_.empty()) { // iterate through all exception refs @@ -442,8 +456,7 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { bool is_temporary_opr = threaded_opr->temporary; // Mark complete for read variables for (auto&& i : threaded_opr->const_vars) { - i->CompleteReadDependency( - [this](OprBlock* opr) { this->PushToExecute(opr, false); }); + i->CompleteReadDependency([this](OprBlock* opr) { this->PushToExecute(opr, false); }); } // Mark complete for write variables. 
for (auto&& i : threaded_opr->mutable_vars) { @@ -457,17 +470,16 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { if (debug_info) { LOG(INFO) << "Complete write dep for " << i; } - const bool to_delete = - i->CompleteWriteDependency([this, debug_info](OprBlock* opr) { - if (debug_info) { - LOG(INFO) << "PushToExecute " << opr; - debug_push_opr_ = opr; - } - this->PushToExecute(opr, false); - if (debug_info) { - LOG(INFO) << "Fin PushToExecute " << opr; - } - }); + const bool to_delete = i->CompleteWriteDependency([this, debug_info](OprBlock* opr) { + if (debug_info) { + LOG(INFO) << "PushToExecute " << opr; + debug_push_opr_ = opr; + } + this->PushToExecute(opr, false); + if (debug_info) { + LOG(INFO) << "Fin PushToExecute " << opr; + } + }); if (to_delete) { ThreadedVar::Delete(i); } @@ -495,7 +507,7 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) { if (threaded_var->var_exception && *threaded_var->var_exception) { - std::exception_ptr tmp = *threaded_var->var_exception; + std::exception_ptr tmp = *threaded_var->var_exception; *threaded_var->var_exception = nullptr; std::rethrow_exception(tmp); } @@ -503,16 +515,15 @@ inline void ThreadedEngine::ThrowException(ThreadedVar* threaded_var) { } void ThreadedEngine::Throw(VarHandle var) { - ThreadedVar *threaded_var = ThreadedVar::CastFromBase(var); + ThreadedVar* threaded_var = ThreadedVar::CastFromBase(var); ThrowException(threaded_var); } -void ThreadedEngine::OnCompleteStatic(Engine *engine, void *opr_block_, - const dmlc::Error* error) { - OprBlock *opr_block = static_cast(opr_block_); - ThreadedOpr *threaded_opr = opr_block->opr; +void ThreadedEngine::OnCompleteStatic(Engine* engine, void* opr_block_, const dmlc::Error* error) { + OprBlock* opr_block = static_cast(opr_block_); + ThreadedOpr* threaded_opr = opr_block->opr; if (error != nullptr) { - auto ex_p = std::make_exception_ptr(*error); + auto ex_p = std::make_exception_ptr(*error); threaded_opr->opr_exception = std::make_shared(ex_p); } if (opr_block->profiling && threaded_opr->opr_name.size()) { diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index aa0e5a22fb1e..45a02a57a931 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -50,10 +50,14 @@ namespace engine { // Define helper macros for debug information. #if ENGINE_DEBUG -#define DEFINE_ENGINE_DEBUG_INFO(Type) \ - static std::atomic counter; \ - Type() { LOG(INFO) << __func__ << " " << ++counter; } \ - ~Type() { LOG(INFO) << __func__ << " " << --counter; } +#define DEFINE_ENGINE_DEBUG_INFO(Type) \ + static std::atomic counter; \ + Type() { \ + LOG(INFO) << __func__ << " " << ++counter; \ + } \ + ~Type() { \ + LOG(INFO) << __func__ << " " << --counter; \ + } #else #define DEFINE_ENGINE_DEBUG_INFO(Type) #endif @@ -101,8 +105,7 @@ struct OprBlock : public common::ObjectPoolAllocatable { * \brief VersionedVarBlock that corresponding to a variable version. * This is a basic unit of LinkedList in the ThreadedVar. */ -struct VersionedVarBlock - : public common::ObjectPoolAllocatable { +struct VersionedVarBlock : public common::ObjectPoolAllocatable { /*! \brief next block in the LinkedList */ VersionedVarBlock* next{nullptr}; /*! \brief the operation this block triggers */ @@ -117,8 +120,7 @@ struct VersionedVarBlock * \brief Variable implementation. * Each ThreadedVar is a linked list(queue) of operations to be performed. 
*/ -class ThreadedVar final - : public Var, public common::ObjectPoolAllocatable { +class ThreadedVar final : public Var, public common::ObjectPoolAllocatable { public: /*! * \brief constructor @@ -179,7 +181,9 @@ class ThreadedVar final // code for debug. #if ENGINE_DEBUG static std::atomic counter; - ~ThreadedVar() { LOG(INFO) << __func__ << " " << --counter; } + ~ThreadedVar() { + LOG(INFO) << __func__ << " " << --counter; + } #endif // ENGINE_DEBUG /*! * \brief exception_ptr associated with the ThreadedOpr @@ -231,8 +235,7 @@ class ThreadedVar final /*! * \brief Operator used in ThreadedEngine. */ -struct ThreadedOpr final : public Opr, - public common::ObjectPoolAllocatable { +struct ThreadedOpr final : public Opr, public common::ObjectPoolAllocatable { /*! \brief The function to be invoked each time. */ Engine::AsyncFn fn; /*! \brief The variable this operation will read from. */ @@ -286,23 +289,25 @@ class ThreadedEngine : public Engine { ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal, + FnProperty prop = FnProperty::kNormal, const char* opr_name = nullptr, - bool wait = false) override; + bool wait = false) override; void DeleteOperator(OprHandle op) override; void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override; - void PushAsync(AsyncFn exec_fun, Context exec_ctx, + void PushAsync(AsyncFn exec_fun, + Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal, - int priority = 0, + FnProperty prop = FnProperty::kNormal, + int priority = 0, const char* opr_name = nullptr, - bool wait = false) override; - void PushSync(SyncFn exec_fn, Context exec_ctx, + bool wait = false) override; + void PushSync(SyncFn exec_fn, + Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal, - int priority = 0, + FnProperty prop = FnProperty::kNormal, + int priority = 0, const char* opr_name = nullptr) override; void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; @@ -357,16 +362,14 @@ class ThreadedEngine : public Engine { attrs.reset(new profiler::ProfileOperator::Attributes()); } const Context& ctx = opr_block->ctx; - opr_block->opr_profile.reset(new profiler::ProfileOperator(threaded_opr->opr_name.c_str(), - attrs.release())); + opr_block->opr_profile.reset( + new profiler::ProfileOperator(threaded_opr->opr_name.c_str(), attrs.release())); opr_block->opr_profile->startForDevice(ctx.dev_type, ctx.dev_id); } - CallbackOnComplete callback = - this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); - const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); + CallbackOnComplete callback = this->CreateCallback(ThreadedEngine::OnCompleteStatic, opr_block); + const bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); if (debug_info) { - LOG(INFO) << "ExecuteOprBlock " << opr_block - << "shutdown_phase=" << shutdown_phase_; + LOG(INFO) << "ExecuteOprBlock " << opr_block << "shutdown_phase=" << shutdown_phase_; } // still run cleanup in shutdown_phase if (!shutdown_phase_ || threaded_opr->prop == FnProperty::kDeleteVar) { @@ -377,7 +380,8 @@ class ThreadedEngine : public Engine { } try { if ((!(threaded_opr->opr_exception && *threaded_opr->opr_exception) || - threaded_opr->prop == FnProperty::kNoSkip) || threaded_opr->wait) { + 
threaded_opr->prop == FnProperty::kNoSkip) || + threaded_opr->wait) { threaded_opr->fn(run_ctx, callback); } else { callback(); @@ -392,18 +396,16 @@ class ThreadedEngine : public Engine { } } catch (std::exception& e) { std::string what = e.what(); - if (what.find("driver shutting down") == std::string::npos && - !shutdown_phase_) { - LOG(FATAL) - << e.what() << "\n" - << "A fatal error occurred in asynchronous engine operation. " - "If you do not know what caused this error, " - "you can try set environment variable MXNET_ENGINE_TYPE " - "to NaiveEngine and run with debugger (i.e. gdb). " - "This will force all operations to be synchronous and " - "backtrace will give you the series of calls that lead " - "to this error. Remember to set MXNET_ENGINE_TYPE back to " - "empty after debugging."; + if (what.find("driver shutting down") == std::string::npos && !shutdown_phase_) { + LOG(FATAL) << e.what() << "\n" + << "A fatal error occurred in asynchronous engine operation. " + "If you do not know what caused this error, " + "you can try set environment variable MXNET_ENGINE_TYPE " + "to NaiveEngine and run with debugger (i.e. gdb). " + "This will force all operations to be synchronous and " + "backtrace will give you the series of calls that lead " + "to this error. Remember to set MXNET_ENGINE_TYPE back to " + "empty after debugging."; } } } else { @@ -412,14 +414,15 @@ class ThreadedEngine : public Engine { } int bulk_size() const override { - const profiler::Profiler *prof = profiler::Profiler::Get(); - return (prof && prof->AggregateRunning()) ? 0 : BulkStatusStore::Get()->bulk_size; + const profiler::Profiler* prof = profiler::Profiler::Get(); + return (prof && prof->AggregateRunning()) ? 0 : BulkStatusStore::Get()->bulk_size; } int set_bulk_size(int bulk_size) override { BulkStatus& bulk_status = *BulkStatusStore::Get(); std::swap(bulk_status.bulk_size, bulk_size); - if (bulk_status.count >= bulk_status.bulk_size) BulkFlush(); + if (bulk_status.count >= bulk_status.bulk_size) + BulkFlush(); if (!bulk_status.functions) { bulk_status.functions.reset(new std::vector()); } @@ -489,22 +492,22 @@ class ThreadedEngine : public Engine { } } - static void OnCompleteStatic(Engine *engine, void *threaded_opr, - const dmlc::Error* error); + static void OnCompleteStatic(Engine* engine, void* threaded_opr, const dmlc::Error* error); /*! * \brief find exception in global_exception_refs and add it if missing * \param opr_exception the exception to be added to global_exception_refs */ inline void AddToGlobalExceptions(const ExceptionRef& opr_exception) { - auto it = std::find(global_exception_refs_.begin(), - global_exception_refs_.end(), opr_exception); + auto it = + std::find(global_exception_refs_.begin(), global_exception_refs_.end(), opr_exception); if (it == global_exception_refs_.end()) { global_exception_refs_.push_back(opr_exception); } return; } /*! \brief append an operator to bulk */ - inline void BulkAppend(SyncFn exec_fn, Context exec_ctx, + inline void BulkAppend(SyncFn exec_fn, + Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars) { BulkStatus& bulk_status = *BulkStatusStore::Get(); @@ -522,28 +525,36 @@ class ThreadedEngine : public Engine { bulk_status.mutable_vars.insert( bulk_status.mutable_vars.end(), mutable_vars.begin(), mutable_vars.end()); - if (bulk_status.count >= bulk_status.bulk_size) BulkFlush(); + if (bulk_status.count >= bulk_status.bulk_size) + BulkFlush(); } /*! 
\brief flush current bulk to execution */ inline void BulkFlush() { BulkStatus& bulk_status = *BulkStatusStore::Get(); - if (!bulk_status.count) return; + if (!bulk_status.count) + return; bulk_status.count = 0; DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars); auto functions = bulk_status.functions; - this->PushAsync([functions](RunContext ctx, CallbackOnComplete on_complete) { - ctx.is_bulk = true; - for (auto& fn : *functions) { - fn(ctx); - } - ctx.is_bulk = false; - bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask; - if (is_gpu) { - ctx.get_stream()->Wait(); - } - on_complete(); - }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars, - FnProperty::kNormal, 0, "ImperativeBulk"); + this->PushAsync( + [functions](RunContext ctx, CallbackOnComplete on_complete) { + ctx.is_bulk = true; + for (auto& fn : *functions) { + fn(ctx); + } + ctx.is_bulk = false; + bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask; + if (is_gpu) { + ctx.get_stream()->Wait(); + } + on_complete(); + }, + bulk_status.ctx, + bulk_status.const_vars, + bulk_status.mutable_vars, + FnProperty::kNormal, + 0, + "ImperativeBulk"); bulk_status.functions.reset(new std::vector()); bulk_status.functions->reserve(bulk_status.bulk_size); @@ -577,10 +588,10 @@ class ThreadedEngine : public Engine { * \brief Holding a shared_ptr to the object pool to prevent it from being destructed too early * See also #309 (https://github.com/dmlc/mxnet/issues/309) */ - std::shared_ptr > objpool_opr_ref_; - std::shared_ptr > objpool_blk_ref_; - std::shared_ptr > objpool_varblk_ref_; - std::shared_ptr > objpool_var_ref_; + std::shared_ptr> objpool_opr_ref_; + std::shared_ptr> objpool_blk_ref_; + std::shared_ptr> objpool_varblk_ref_; + std::shared_ptr> objpool_var_ref_; /*! 
* \brief Async destruction of some objects is relied on storage, diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc index 4c5d1befb8b3..8faec911d3d8 100644 --- a/src/engine/threaded_engine_perdevice.cc +++ b/src/engine/threaded_engine_perdevice.cc @@ -48,11 +48,11 @@ namespace engine { */ class ThreadedEnginePerDevice : public ThreadedEngine { public: - static auto constexpr kFIFO = dmlc::ConcurrentQueueType::kFIFO; - static auto constexpr kPriority = dmlc::ConcurrentQueueType::kPriority; - static auto constexpr kCopyQueue = kPriority; + static auto constexpr kFIFO = dmlc::ConcurrentQueueType::kFIFO; + static auto constexpr kPriority = dmlc::ConcurrentQueueType::kPriority; + static auto constexpr kCopyQueue = kPriority; static auto constexpr kPriorityQueue = kPriority; - static auto constexpr kWorkerQueue = kFIFO; + static auto constexpr kWorkerQueue = kFIFO; ThreadedEnginePerDevice() noexcept(false) { this->Start(); @@ -71,37 +71,41 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } void Stop() override { - if (is_worker_) return; + if (is_worker_) + return; WaitForAll(); StopNoWait(); } void Start() override { - if (is_worker_) return; + if (is_worker_) + return; gpu_worker_nthreads_ = common::GetNumThreadsPerGPU(); // MXNET_CPU_WORKER_NTHREADS cpu_worker_nthreads_ = LibraryInitializer::Get()->cpu_worker_nthreads_; - gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 2); + gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 2); // create CPU task - int cpu_priority_nthreads = dmlc::GetEnv("MXNET_CPU_PRIORITY_NTHREADS", 4); - cpu_priority_worker_ = std::make_unique>(); + int cpu_priority_nthreads = dmlc::GetEnv("MXNET_CPU_PRIORITY_NTHREADS", 4); + cpu_priority_worker_ = std::make_unique>(); cpu_priority_worker_->pool = std::make_unique( cpu_priority_nthreads, [this](std::shared_ptr ready_event) { this->CPUWorker(Context(), cpu_priority_worker_.get(), ready_event); - }, true); + }, + true); // GPU tasks will be created lazily } protected: - void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + void PushToExecute(OprBlock* opr_block, bool pusher_thread) override { const Context& ctx = opr_block->ctx; if ((opr_block->opr->prop == FnProperty::kAsync || - opr_block->opr->prop == FnProperty::kDeleteVar) && pusher_thread) { + opr_block->opr->prop == FnProperty::kDeleteVar) && + pusher_thread) { if (ctx.dev_mask() == Context::kGPU) { - #if MXNET_USE_CUDA +#if MXNET_USE_CUDA MSHADOW_CATCH_ERROR(mshadow::SetDevice(ctx.dev_id)); - #endif +#endif } this->ExecuteOprBlock(RunContext{ctx, nullptr, nullptr, false}, opr_block); } else { @@ -110,15 +114,16 @@ class ThreadedEnginePerDevice : public ThreadedEngine { if (opr_block->opr->prop == FnProperty::kCPUPrioritized) { cpu_priority_worker_->task_queue.Push(opr_block, opr_block->priority); } else { - int dev_id = ctx.dev_id; + int dev_id = ctx.dev_id; int nthread = cpu_worker_nthreads_; - auto ptr = - cpu_normal_workers_.Get(dev_id, [this, ctx, nthread]() { - auto blk = new ThreadWorkerBlock(); - blk->pool = std::make_unique(nthread, - [this, ctx, blk](std::shared_ptr ready_event) { - this->CPUWorker(ctx, blk, ready_event); - }, true); + auto ptr = cpu_normal_workers_.Get(dev_id, [this, ctx, nthread]() { + auto blk = new ThreadWorkerBlock(); + blk->pool = std::make_unique( + nthread, + [this, ctx, blk](std::shared_ptr ready_event) { + this->CPUWorker(ctx, blk, ready_event); + }, + true); return blk; }); if (ptr) { @@ -133,22 +138,21 @@ class 
ThreadedEnginePerDevice : public ThreadedEngine { CHECK_EQ(ctx.dev_mask(), Context::kGPU); // GPU execution. const FnProperty prop = opr_block->opr->prop; - const bool is_copy = (prop == FnProperty::kCopyFromGPU || - prop == FnProperty::kCopyToGPU); + const bool is_copy = (prop == FnProperty::kCopyFromGPU || prop == FnProperty::kCopyToGPU); if (is_copy) { const size_t nthread = gpu_copy_nthreads_; - auto ptr = gpu_copy_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { + auto ptr = gpu_copy_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { // Signify to kernel that GPU is being used, so reserve cores as necessary OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); - auto blk = new ThreadWorkerBlock(); - blk->pool = std::make_unique( + auto blk = new ThreadWorkerBlock(); + blk->pool = std::make_unique( nthread, - [this, ctx, is_copy, blk] - (std::shared_ptr ready_event) { - this->GPUWorker(ctx, is_copy, blk, ready_event); - }, true); - return blk; - }); + [this, ctx, is_copy, blk](std::shared_ptr ready_event) { + this->GPUWorker(ctx, is_copy, blk, ready_event); + }, + true); + return blk; + }); if (ptr) { if (opr_block->opr->prop == FnProperty::kDeleteVar) { ptr->task_queue.PushFront(opr_block, opr_block->priority); @@ -163,14 +167,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { auto ptr = gpu_priority_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { // Signify to kernel that GPU is being used, so reserve cores as necessary OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); - auto blk = new ThreadWorkerBlock(); - blk->pool = std::make_unique( + auto blk = new ThreadWorkerBlock(); + blk->pool = std::make_unique( nthread, - [this, ctx, is_copy, blk] - (std::shared_ptr ready_event) { - this->GPUWorker(ctx, is_copy, blk, ready_event); - }, true); - return blk; + [this, ctx, is_copy, blk](std::shared_ptr ready_event) { + this->GPUWorker(ctx, is_copy, blk, ready_event); + }, + true); + return blk; }); if (ptr) { ptr->task_queue.Push(opr_block, opr_block->priority); @@ -180,14 +184,14 @@ class ThreadedEnginePerDevice : public ThreadedEngine { auto ptr = gpu_normal_workers_.Get(ctx.dev_id, [this, ctx, is_copy, nthread]() { // Signify to kernel that GPU is being used, so reserve cores as necessary OpenMP::Get()->set_reserve_cores(GetReserveCoreCount(true)); - auto blk = new ThreadWorkerBlock(); - blk->pool = std::make_unique( + auto blk = new ThreadWorkerBlock(); + blk->pool = std::make_unique( nthread, - [this, ctx, is_copy, blk] - (std::shared_ptr ready_event) { - this->GPUWorker(ctx, is_copy, blk, ready_event); - }, true); - return blk; + [this, ctx, is_copy, blk](std::shared_ptr ready_event) { + this->GPUWorker(ctx, is_copy, blk, ready_event); + }, + true); + return blk; }); if (ptr) { if (opr_block->opr->prop == FnProperty::kDeleteVar) { @@ -204,10 +208,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine { private: // working unit for each of the task. - template + template struct ThreadWorkerBlock { // task queue on this task - dmlc::ConcurrentBlockingQueue task_queue; + dmlc::ConcurrentBlockingQueue task_queue; // thread pool that works on this task std::unique_ptr pool; // constructor @@ -225,31 +229,31 @@ class ThreadedEnginePerDevice : public ThreadedEngine { /*! 
\brief number of concurrent thread each gpu copy worker uses */ size_t gpu_copy_nthreads_; // cpu worker - common::LazyAllocArray > cpu_normal_workers_; + common::LazyAllocArray> cpu_normal_workers_; // cpu priority worker - std::unique_ptr > cpu_priority_worker_; + std::unique_ptr> cpu_priority_worker_; // workers doing normal works on GPU - common::LazyAllocArray > gpu_normal_workers_; + common::LazyAllocArray> gpu_normal_workers_; // workers doing copy works from/to GPU - common::LazyAllocArray > gpu_copy_workers_; + common::LazyAllocArray> gpu_copy_workers_; // gpu priority workers - common::LazyAllocArray > gpu_priority_workers_; + common::LazyAllocArray> gpu_priority_workers_; /*! * \brief GPU worker that performs operations on a certain device. * \param dev_id The device id of the worker. * \param is_copy_worker whether the worker only do copy job * \param block The task block of the worker. */ - template + template inline void GPUWorker(Context ctx, bool is_copy_worker, - ThreadWorkerBlock *block, + ThreadWorkerBlock* block, const std::shared_ptr& ready_event) { this->is_worker_ = true; #if MXNET_USE_CUDA CHECK(block != nullptr); - mshadow::Stream *stream = nullptr; - GPUAuxStream *aux_stream = nullptr; + mshadow::Stream* stream = nullptr; + GPUAuxStream* aux_stream = nullptr; do { ThreadPool::SetReadyOnDestroy setReady(ready_event); // allocate stream @@ -257,7 +261,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { if (is_copy_worker) { stream = mshadow::NewStream(false, false, ctx.dev_id); } else { - stream = mshadow::NewStream(true, MXNET_USE_CUDNN != 0, ctx.dev_id); + stream = mshadow::NewStream(true, MXNET_USE_CUDNN != 0, ctx.dev_id); aux_stream = new GPUAuxStream(stream); } } while (false); @@ -281,9 +285,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine { * \brief CPU worker that performs operations on CPU. * \param block The task block of the worker. */ - template + template inline void CPUWorker(Context ctx, - ThreadWorkerBlock *block, + ThreadWorkerBlock* block, const std::shared_ptr& ready_event) { this->is_worker_ = true; auto* task_queue = &(block->task_queue); @@ -323,11 +327,9 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } /*! \brief Signal a single queue for shutdown */ - template - static inline void SignalQueueForKill(common::LazyAllocArray *array) { - array->ForEach([](size_t i, Object *block) { - block->task_queue.SignalForKill(); - }); + template + static inline void SignalQueueForKill(common::LazyAllocArray* array) { + array->ForEach([](size_t i, Object* block) { block->task_queue.SignalForKill(); }); } /*! 
Signal all queues for shutdown */ @@ -342,7 +344,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine { } }; -Engine *CreateThreadedEnginePerDevice() { +Engine* CreateThreadedEnginePerDevice() { return new ThreadedEnginePerDevice(); } diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index dde16bc8fe5d..99b7726fc4e4 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -59,11 +59,11 @@ class ThreadedEnginePooled : public ThreadedEngine { streams_->Finalize(); task_queue_->SignalForKill(); io_task_queue_->SignalForKill(); - task_queue_ = nullptr; - io_task_queue_ = nullptr; - thread_pool_ = nullptr; + task_queue_ = nullptr; + io_task_queue_ = nullptr; + thread_pool_ = nullptr; io_thread_pool_ = nullptr; - streams_ = nullptr; + streams_ = nullptr; } void Stop() override { @@ -75,18 +75,22 @@ class ThreadedEnginePooled : public ThreadedEngine { streams_ = std::make_unique>(); task_queue_.reset(new dmlc::ConcurrentBlockingQueue()); io_task_queue_.reset(new dmlc::ConcurrentBlockingQueue()); - thread_pool_ = std::make_unique(kNumWorkingThreads, - [this](std::shared_ptr ready_event) { - ThreadWorker(task_queue_, ready_event); }, - true); - io_thread_pool_ = std::make_unique(1, - [this](std::shared_ptr ready_event) { - ThreadWorker(io_task_queue_, ready_event); }, - true); + thread_pool_ = std::make_unique( + kNumWorkingThreads, + [this](std::shared_ptr ready_event) { + ThreadWorker(task_queue_, ready_event); + }, + true); + io_thread_pool_ = std::make_unique( + 1, + [this](std::shared_ptr ready_event) { + ThreadWorker(io_task_queue_, ready_event); + }, + true); } protected: - void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + void PushToExecute(OprBlock* opr_block, bool pusher_thread) override { if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { DoExecute(opr_block); } else { @@ -139,17 +143,16 @@ class ThreadedEnginePooled : public ThreadedEngine { #endif assert(opr_block->wait.load() == 0); if (opr_block->ctx.dev_mask() == gpu::kDevMask) { - #if MXNET_USE_CUDA +#if MXNET_USE_CUDA device_store.SetDevice(opr_block->ctx.dev_id); - #else // MXNET_USE_CUDA +#else // MXNET_USE_CUDA LOG(FATAL) << "Please compile with CUDA enabled"; - #endif // MXNET_USE_CUDA +#endif // MXNET_USE_CUDA } bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU || opr_block->opr->prop == FnProperty::kCopyToGPU); - auto&& rctx = is_copy - ? streams_->GetIORunContext(opr_block->ctx) - : streams_->GetRunContext(opr_block->ctx); + auto&& rctx = is_copy ? streams_->GetIORunContext(opr_block->ctx) + : streams_->GetRunContext(opr_block->ctx); this->ExecuteOprBlock(rctx, opr_block); } /*! 
@@ -171,7 +174,7 @@ class ThreadedEnginePooled : public ThreadedEngine { } }; -Engine *CreateThreadedEnginePooled() { +Engine* CreateThreadedEnginePooled() { return new ThreadedEnginePooled(); } } // namespace engine diff --git a/src/imperative/attach_op_execs_pass.cc b/src/imperative/attach_op_execs_pass.cc index 30e67f44a80b..57be3d5a1001 100644 --- a/src/imperative/attach_op_execs_pass.cc +++ b/src/imperative/attach_op_execs_pass.cc @@ -37,13 +37,12 @@ namespace mxnet { namespace exec { #if MXNET_USE_ONEDNN == 1 -#define CREATE_DEFAULT_INPUTS_MKLDNN(in_array, in_array_fallback, attrs) \ - CREATE_DEFAULT_INPUTS(true, attrs, CreateDefaultInputs(in_array, in_array_fallback)) +#define CREATE_DEFAULT_INPUTS_MKLDNN(in_array, in_array_fallback, attrs) \ + CREATE_DEFAULT_INPUTS(true, attrs, CreateDefaultInputs(in_array, in_array_fallback)) #else #define CREATE_DEFAULT_INPUTS_MKLDNN(in_array, in_array_fallback, attrs) // empty macro #endif - // abstract OpExecutor which provides storage fallback procedure on // non-default inputs and outputs // FComputeExecutor and FStatefulComputeExecutor inherit from this class @@ -76,16 +75,27 @@ class StorageFallbackOpExecutor : public OpExecutor { void PreFCompute(bool is_gpu) { using namespace common; InitBlobs(); - in_data_.clear(); out_data_.clear(); - pre_temp_src_.clear(); pre_temp_dst_.clear(); - post_temp_src_.clear(); post_temp_dst_.clear(); + in_data_.clear(); + out_data_.clear(); + pre_temp_src_.clear(); + pre_temp_dst_.clear(); + post_temp_src_.clear(); + post_temp_dst_.clear(); in_temp_idx_map_.clear(); tmp_req = req; - SetupDefaultBlobsInOut(in_array, out_array, &pre_temp_buf_, &post_temp_buf_, &req, - &in_data_, &out_data_, - &pre_temp_src_, &pre_temp_dst_, - &post_temp_src_, &post_temp_dst_, - &in_temp_idx_map_, mutate_idx_); + SetupDefaultBlobsInOut(in_array, + out_array, + &pre_temp_buf_, + &post_temp_buf_, + &req, + &in_data_, + &out_data_, + &pre_temp_src_, + &pre_temp_dst_, + &post_temp_src_, + &post_temp_dst_, + &in_temp_idx_map_, + mutate_idx_); common::CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx, is_gpu); } @@ -137,12 +147,14 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { return state_; } - explicit StatefulComputeExecutor(OpStatePtr state, - FStatefulCompute fcompute, + explicit StatefulComputeExecutor(OpStatePtr state, + FStatefulCompute fcompute, ExecType exec_type, - const std::vector &mutate_idx) + const std::vector& mutate_idx) : StorageFallbackOpExecutor(mutate_idx), - state_(std::move(state)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} + state_(std::move(state)), + fcompute_(std::move(fcompute)), + exec_type_(exec_type) {} private: OpStatePtr state_; @@ -150,14 +162,13 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor { ExecType exec_type_; }; - // stateful compute_ex executor class StatefulComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; INVALIDATE_OUTPUTS(out_array, req); - std::vector *pInArray = &in_array; + std::vector* pInArray = &in_array; CREATE_DEFAULT_INPUTS_MKLDNN(in_array, pInArray = &in_array_fallback, attrs_); fcompute_(state_, op_ctx, *pInArray, req, out_array); } @@ -176,11 +187,13 @@ class StatefulComputeExExecutor : public OpExecutor { return state_; } - explicit StatefulComputeExExecutor(NodeAttrs attrs, - OpStatePtr state, - FStatefulComputeEx fcompute, + explicit StatefulComputeExExecutor(NodeAttrs attrs, + OpStatePtr state, + FStatefulComputeEx fcompute, ExecType 
exec_type) - : attrs_(std::move(attrs)), state_(std::move(state)), fcompute_(std::move(fcompute)), + : attrs_(std::move(attrs)), + state_(std::move(state)), + fcompute_(std::move(fcompute)), exec_type_(exec_type) {} private: @@ -190,7 +203,6 @@ class StatefulComputeExExecutor : public OpExecutor { ExecType exec_type_; }; - // fcompute executor class FComputeExecutor : public StorageFallbackOpExecutor { public: @@ -207,11 +219,14 @@ class FComputeExecutor : public StorageFallbackOpExecutor { return exec_type_; } - explicit FComputeExecutor(NodeAttrs attrs, FCompute fcompute, - ExecType exec_type, const std::vector &mutate_idx) + explicit FComputeExecutor(NodeAttrs attrs, + FCompute fcompute, + ExecType exec_type, + const std::vector& mutate_idx) : StorageFallbackOpExecutor(mutate_idx), - attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) { - } + attrs_(std::move(attrs)), + fcompute_(std::move(fcompute)), + exec_type_(exec_type) {} private: NodeAttrs attrs_; @@ -225,7 +240,7 @@ class FComputeExExecutor : public OpExecutor { void Run(RunContext rctx, bool is_gpu) override { op_ctx.run_ctx = rctx; INVALIDATE_OUTPUTS(out_array, req); - std::vector *pInArray = &in_array; + std::vector* pInArray = &in_array; CREATE_DEFAULT_INPUTS_MKLDNN(in_array, pInArray = &in_array_fallback, attrs_); fcompute_(attrs_, op_ctx, *pInArray, req, out_array); } @@ -236,10 +251,8 @@ class FComputeExExecutor : public OpExecutor { return exec_type_; } - explicit FComputeExExecutor(NodeAttrs attrs, FComputeEx fcompute, - ExecType exec_type) - : attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) { - } + explicit FComputeExExecutor(NodeAttrs attrs, FComputeEx fcompute, ExecType exec_type) + : attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {} private: NodeAttrs attrs_; @@ -248,27 +261,28 @@ class FComputeExExecutor : public OpExecutor { }; void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) { - using nnvm::DTypeVector; using mxnet::ShapeVector; + using nnvm::DTypeVector; using nnvm::FMutateInputs; - static auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); - static auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); - static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); + static auto& fcreate_op_state = nnvm::Op::GetAttr("FCreateOpState"); + static auto& fmutate_inputs = nnvm::Op::GetAttr("FMutateInputs"); + static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); static auto& is_layer_backward = nnvm::Op::GetAttr("TIsLayerOpBackward"); - const auto& vdtype = g.GetAttr("dtype"); - const auto& vshape = g.GetAttr("shape"); - const auto& vctx = g.GetAttr("context"); + const auto& vdtype = g.GetAttr("dtype"); + const auto& vshape = g.GetAttr("shape"); + const auto& vctx = g.GetAttr("context"); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); // get the graph - const auto& idx = g.indexed_graph(); + const auto& idx = g.indexed_graph(); OpExecVector& ret = *p_ret; // initialize the nodes const auto& inode = idx[i]; - if (inode.source->is_variable()) return; - const nnvm::Op *op = inode.source->op(); + if (inode.source->is_variable()) + return; + const nnvm::Op* op = inode.source->op(); ExecType exec_type = ExecType::kSync; std::vector mutate_index; if (fmutate_inputs.count(op)) { @@ -286,42 +300,39 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, itype.emplace_back(vdtype[idx.entry_id(e)]); } - OpStatePtr state = fcreate_op_state[op]( - 
inode.source->attrs, vctx[i], ishape, itype); + OpStatePtr state = fcreate_op_state[op](inode.source->attrs, vctx[i], ishape, itype); if (p_state) { CHECK_GT(p_state->size(), i); p_state->at(i) = state; } - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); + FStatefulComputeEx fcompute_ex = + common::GetFCompute(op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared(inode.source->attrs, state, - fcompute_ex, exec_type); + ret[i] = std::make_shared( + inode.source->attrs, state, fcompute_ex, exec_type); } else { - FStatefulCompute fcompute = common::GetFCompute( - op, "FStatefulCompute", vctx[i]); + FStatefulCompute fcompute = + common::GetFCompute(op, "FStatefulCompute", vctx[i]); CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; - ret[i] = std::make_shared(state, fcompute, - exec_type, mutate_index); + ret[i] = std::make_shared(state, fcompute, exec_type, mutate_index); } } else if (is_layer_backward.get(op, false)) { CHECK_GE(inode.control_deps.size(), 1); uint32_t fwd_id = inode.control_deps[0]; CHECK(vctx[fwd_id] == vctx[i]); CHECK(ret[fwd_id] != nullptr); - FStatefulComputeEx fcompute_ex = common::GetFCompute( - op, "FStatefulComputeEx", vctx[i]); + FStatefulComputeEx fcompute_ex = + common::GetFCompute(op, "FStatefulComputeEx", vctx[i]); // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { ret[i] = std::make_shared( - inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, - exec_type); + inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, exec_type); } else { - FStatefulCompute fcompute = common::GetFCompute( - op, "FStatefulCompute", vctx[i]); + FStatefulCompute fcompute = + common::GetFCompute(op, "FStatefulCompute", vctx[i]); CHECK(fcompute != nullptr) << "One of FStatefulCompute and FStatefulComputeEx must be registered " << "for stateful operator " << op->name; @@ -329,11 +340,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index); } } else { - FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); + FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) { - ret[i] = std::make_shared( - inode.source->attrs, fcomp_ex, exec_type); + ret[i] = std::make_shared(inode.source->attrs, fcomp_ex, exec_type); } else if (fcompute != nullptr) { ret[i] = std::make_shared( inode.source->attrs, fcompute, exec_type, mutate_index); @@ -343,7 +353,6 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, } } - // pass to attach operator executors Graph AttachOpExecs(Graph g) { const auto& idx = g.indexed_graph(); diff --git a/src/imperative/attach_op_resource_pass.cc b/src/imperative/attach_op_resource_pass.cc index 160ba8fb8d63..7364e674e9b6 100644 --- a/src/imperative/attach_op_resource_pass.cc +++ b/src/imperative/attach_op_resource_pass.cc @@ -17,7 +17,6 @@ * under the License. */ - /*! 
* Copyright (c) 2016 by Contributors * \file attach_op_resource_pass.cc @@ -30,35 +29,31 @@ namespace mxnet { namespace exec { -void AttachOpResources( - const Graph& g, - const OpExecVector& op_execs, - size_t start_nid, - size_t end_nid) { - static auto& fresource = - nnvm::Op::GetAttr("FResourceRequest"); - static auto& fresource_ex = - nnvm::Op::GetAttr("FResourceRequestEx"); - const auto& vctx = g.GetAttr("context"); - const auto& vdispatch = g.GetAttr("dispatch_mode"); - const auto& dev_masks = g.GetAttr("dev_mask"); - const auto& idx = g.indexed_graph(); +void AttachOpResources(const Graph& g, + const OpExecVector& op_execs, + size_t start_nid, + size_t end_nid) { + static auto& fresource = nnvm::Op::GetAttr("FResourceRequest"); + static auto& fresource_ex = nnvm::Op::GetAttr("FResourceRequestEx"); + const auto& vctx = g.GetAttr("context"); + const auto& vdispatch = g.GetAttr("dispatch_mode"); + const auto& dev_masks = g.GetAttr("dev_mask"); + const auto& idx = g.indexed_graph(); // Use global resource pool for each executor for now. std::map cached_temp; // Resource allocation for (uint32_t nid = start_nid; nid < end_nid; ++nid) { const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; - const Context &ctx = vctx[nid]; - auto& requested = op_execs[nid]->op_ctx.requested; + if (inode.source->is_variable()) + continue; + const Context& ctx = vctx[nid]; + auto& requested = op_execs[nid]->op_ctx.requested; requested.clear(); - const auto op = inode.source->op(); - const bool rsc_req = (fresource.count(op) != 0); + const auto op = inode.source->op(); + const bool rsc_req = (fresource.count(op) != 0); const bool rsc_ex_req = (fresource_ex.count(op) != 0); if (rsc_req || rsc_ex_req) { - auto reqs = rsc_ex_req ? fresource_ex[op](inode.source->attrs, - dev_masks[nid], - vdispatch[nid]) + auto reqs = rsc_ex_req ? fresource_ex[op](inode.source->attrs, dev_masks[nid], vdispatch[nid]) : fresource[op](inode.source->attrs); // Get the resource of temporal space. 
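// Illustrative sketch (not part of the diff): the reqs dispatch above prefers an
// operator's FResourceRequestEx registration (which also sees the device mask and
// dispatch mode) over its plain FResourceRequest one. A minimal standalone analogue
// with hypothetical registries and a made-up ResourceKind, std-library only:
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

enum class ResourceKind { kTempSpace, kRandom };

using ResourceFn   = std::function<std::vector<ResourceKind>()>;
using ResourceFnEx = std::function<std::vector<ResourceKind>(int dev_mask)>;

std::vector<ResourceKind> RequestsFor(const std::string& op,
                                      const std::unordered_map<std::string, ResourceFn>& basic,
                                      const std::unordered_map<std::string, ResourceFnEx>& extended,
                                      int dev_mask) {
  const bool has_ex = extended.count(op) != 0;
  const bool has    = basic.count(op) != 0;
  if (!has_ex && !has)
    return {};                              // operator requested no resources
  return has_ex ? extended.at(op)(dev_mask) // extended form sees the device mask
                : basic.at(op)();           // otherwise fall back to the basic form
}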
for (const ResourceRequest& req : reqs) { diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index ab640587276c..692f9d6b55ad 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -26,7 +26,6 @@ #include "../operator/operator_common.h" #include "../operator/subgraph/common.h" - namespace mxnet { DMLC_REGISTER_PARAMETER(CachedOpConfig); @@ -40,9 +39,9 @@ nnvm::Symbol CachedOp::GetOptimizedSymbol() const { return ret.Copy(); } -CachedOp::CachedOp( - const nnvm::Symbol& sym, - const std::vector >& flags) : sym_(sym), flags_(flags) { +CachedOp::CachedOp(const nnvm::Symbol& sym, + const std::vector >& flags) + : sym_(sym), flags_(flags) { config_.Init(flags); this->dynamic_shape_checked_ = false; @@ -52,25 +51,30 @@ CachedOp::CachedOp( auto grad_graph = nnvm::Graph(); std::unordered_map fwd_input_to_grad_output; - CreateFullGraph(sym.Copy(), &fwd_graph_, &grad_graph, &full_graph_, - &ograd_entries_, &fwd_input_to_grad_output); + CreateFullGraph(sym.Copy(), + &fwd_graph_, + &grad_graph, + &full_graph_, + &ograd_entries_, + &fwd_input_to_grad_output); { - const auto& idx = fwd_graph_.indexed_graph(); + const auto& idx = fwd_graph_.indexed_graph(); bwd_output_reqs_ = std::vector(grad_graph.outputs.size(), kWriteTo); - inlining_ = !config_.static_alloc && - (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; + inlining_ = !config_.static_alloc && + (idx.num_nodes() - idx.input_nodes().size()) <= config_.inline_limit; } SetInputIndices(fwd_graph_, config_.param_indices, &config_.data_indices); // Set the backward dependency vectors { - const auto& idx = full_graph_.indexed_graph(); - size_t num_forward_inputs = num_inputs(); + const auto& idx = full_graph_.indexed_graph(); + size_t num_forward_inputs = num_inputs(); size_t num_forward_outputs = num_outputs(); for (uint32_t i = 0; i < ograd_entries_.size(); ++i) { - if (!idx.exist(ograd_entries_[i].node.get())) continue; + if (!idx.exist(ograd_entries_[i].node.get())) + continue; bwd_ograd_dep_.push_back(i); } save_inputs_.resize(num_forward_inputs, false); @@ -90,16 +94,15 @@ CachedOp::CachedOp( CachedOp::~CachedOp() = default; -std::vector CachedOp::Gradient( - const nnvm::ObjectPtr& node, - const std::vector& ograds) const { +std::vector CachedOp::Gradient(const nnvm::ObjectPtr& node, + const std::vector& ograds) const { using namespace nnvm; static const auto _backward_CachedOp = Op::Get("_backward_CachedOp"); - static const auto _NoGrad = Op::Get("_NoGradient"); + static const auto _NoGrad = Op::Get("_NoGradient"); - auto p = Node::Create(); - p->attrs.op = _backward_CachedOp; - p->attrs.name = node->attrs.name + "_backward"; + auto p = Node::Create(); + p->attrs.op = _backward_CachedOp; + p->attrs.name = node->attrs.name + "_backward"; p->attrs.parsed = node->attrs.parsed; p->control_deps.push_back(node); p->inputs.reserve(bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size()); @@ -113,10 +116,10 @@ std::vector CachedOp::Gradient( ret.reserve(num_inputs()); const auto& auxs = mutable_input_nodes(); if (auxs.size()) { - auto nop = Node::Create(); - nop->attrs.op = _NoGrad; + auto nop = Node::Create(); + nop->attrs.op = _NoGrad; nop->attrs.name = "NoGradient"; - uint32_t k = 0; + uint32_t k = 0; for (const auto& i : fwd_graph_.indexed_graph().input_nodes()) { if (auxs.count(i)) { ret.emplace_back(nop); @@ -126,7 +129,7 @@ std::vector CachedOp::Gradient( } } else { for (uint32_t i = 0; i < num_inputs(); ++i) - ret.emplace_back(p, i, 0); + ret.emplace_back(p, i, 0); } return 
ret; } @@ -144,7 +147,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, CHECK_EQ(inputs.size(), num_inputs()); auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); nnvm::Graph& g = state.info.fwd_graph; ShapeVector shape_inputs(inputs.size()); @@ -155,9 +158,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, // If so, the pass will fail with `contain_dynamic_shape = true`, // This method is only called once, so the overhead is negligible. bool contain_dynamic_shape = false; - CheckAndInferShape(&g, std::move(shape_inputs), true, - {0, 0}, {0, 0}, - &contain_dynamic_shape); + CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); if (!config_.static_shape && erase_result) { g.attrs.erase("shape"); g.attrs.erase("shape_inputs"); @@ -165,11 +166,10 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx, return contain_dynamic_shape; } -bool CachedOp::SetForwardGraph( - const Context& default_ctx, - GraphInfo* info, - const bool recording, - const std::vector& inputs) { +bool CachedOp::SetForwardGraph(const Context& default_ctx, + GraphInfo* info, + const bool recording, + const std::vector& inputs) { using namespace nnvm; using namespace imperative; CHECK_EQ(inputs.size(), num_inputs()); @@ -179,19 +179,18 @@ bool CachedOp::SetForwardGraph( DTypeVector dtype_inputs(inputs.size()); StorageTypeVector storage_type_inputs(inputs.size()); for (size_t i = 0; i < inputs.size(); ++i) { - shape_inputs[i] = inputs[info->input_map[i]]->shape(); - dtype_inputs[i] = inputs[info->input_map[i]]->dtype(); + shape_inputs[i] = inputs[info->input_map[i]]->shape(); + dtype_inputs[i] = inputs[info->input_map[i]]->dtype(); storage_type_inputs[i] = inputs[info->input_map[i]]->storage_type(); } - bool match = true; + bool match = true; bool contain_dynamic_shape = false; - match &= CheckAndInferShape(&g, std::move(shape_inputs), true, - {0, 0}, {0, 0}, &contain_dynamic_shape); + match &= + CheckAndInferShape(&g, std::move(shape_inputs), true, {0, 0}, {0, 0}, &contain_dynamic_shape); match &= CheckAndInferType(&g, std::move(dtype_inputs), true); exec::DevMaskVector dev_mask(g.indexed_graph().num_nodes(), default_ctx.dev_mask()); - match &= CheckAndInferStorageType(&g, std::move(dev_mask), - std::move(storage_type_inputs), true); + match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(storage_type_inputs), true); // When dynmaic shape exists, it is not feasible to plan memory ahead of time if (contain_dynamic_shape) { @@ -213,7 +212,8 @@ bool CachedOp::SetForwardGraph( const auto& stypes = g.GetAttr("storage_type"); CHECK_EQ(stypes.size(), storage.size()); for (size_t i = 0; i < stypes.size(); i++) { - if (stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; + if (stypes[i] != kDefaultStorage) + storage[i] = exec::kDynamicStorageID; } for (const auto i : idx.input_nodes()) { storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; @@ -222,11 +222,11 @@ bool CachedOp::SetForwardGraph( storage[idx.entry_id(idx.outputs()[i])] = exec::kExternalStorageID; } - auto mem_plan = MXPlanMemory( - &g, std::move(storage), g.GetAttr >(AddPrefix(prefix, REF_COUNT)), - AddPrefix(prefix, STORAGE_PLAN)); - g.attrs[AddPrefix(prefix, MEM_PLAN)] = - std::make_shared(std::move(mem_plan)); + auto mem_plan = MXPlanMemory(&g, + std::move(storage), + g.GetAttr >(AddPrefix(prefix, REF_COUNT)), + AddPrefix(prefix, STORAGE_PLAN)); + 
g.attrs[AddPrefix(prefix, MEM_PLAN)] = std::make_shared(std::move(mem_plan)); return false; } @@ -237,7 +237,7 @@ void SetBackwardInputEid(const std::vector& bwd_in_dep, const std::vector& bwd_ograd_dep, const std::vector& ograd_entries, const nnvm::IndexedGraph& idx, - std::vector *bwd_input_eid) { + std::vector* bwd_input_eid) { for (const auto& i : bwd_ograd_dep) { auto ograd = ograd_entries[i]; if (idx.exist(ograd.node.get())) { @@ -256,24 +256,24 @@ void SetBackwardInputEid(const std::vector& bwd_in_dep, } } -bool CachedOp::SetBackwardGraph( - GraphInfo* info, - const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto) { +bool CachedOp::SetBackwardGraph(GraphInfo* info, + const std::vector& reqs, + const std::vector& inputs, + bool detect_inplace_addto) { using namespace nnvm; using namespace imperative; std::lock_guard lock(mutex_); Context default_ctx = inputs[0]->ctx(); - nnvm::Graph& g = info->full_graph; + nnvm::Graph& g = info->full_graph; if (info->bwd_output_reqs != reqs) { info->bwd_output_reqs = reqs; info->bwd_input_eid.clear(); - g = nnvm::Graph(); + g = nnvm::Graph(); g.outputs = info->fwd_graph.outputs; for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) { - if (info->bwd_output_reqs[i] == kNullOp) continue; + if (info->bwd_output_reqs[i] == kNullOp) + continue; g.outputs.emplace_back(info->grad_graph.outputs[i]); } g.attrs["context"] = std::make_shared( @@ -284,25 +284,27 @@ bool CachedOp::SetBackwardGraph( if (info->bwd_input_eid.size() != inputs.size()) { info->bwd_input_eid.clear(); - SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, - info->ograd_entries, idx, &info->bwd_input_eid); + SetBackwardInputEid( + bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, info->ograd_entries, idx, &info->bwd_input_eid); CHECK_EQ(inputs.size(), info->bwd_input_eid.size()); } - size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_nodes = info->fwd_graph.indexed_graph().num_nodes(); size_t num_forward_entries = info->fwd_graph.indexed_graph().num_node_entries(); if (!g.attrs.count(AddPrefix(BACKWARD, REF_COUNT))) { std::vector ref_count(idx.num_node_entries(), 0); for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) ++ref_count[idx.entry_id(j)]; + for (const auto& j : idx[i].inputs) + ++ref_count[idx.entry_id(j)]; } for (size_t i = 0; i < inputs.size(); ++i) { if (info->bwd_input_eid[i] != kEidNotExist) { ++ref_count[info->bwd_input_eid[i]]; } } - for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; + for (const auto& i : idx.outputs()) + ++ref_count[idx.entry_id(i)]; g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared(std::move(ref_count)); } @@ -317,24 +319,22 @@ bool CachedOp::SetBackwardGraph( if (info->bwd_input_eid[i] == kEidNotExist) { continue; } - size_t oi = BwdOriginalInput(info->input_map, i); + size_t oi = BwdOriginalInput(info->input_map, i); shapes[info->bwd_input_eid[i]] = inputs[oi]->shape(); dtypes[info->bwd_input_eid[i]] = inputs[oi]->dtype(); stypes[info->bwd_input_eid[i]] = inputs[oi]->storage_type(); } std::pair node_range, entry_range; - node_range = {num_forward_nodes, idx.num_nodes()}; + node_range = {num_forward_nodes, idx.num_nodes()}; entry_range = {num_forward_entries, idx.num_node_entries()}; bool match = true; - match &= CheckAndInferShape(&g, std::move(shapes), false, - node_range, entry_range); - match &= CheckAndInferType(&g, std::move(dtypes), false, - node_range, entry_range); + match &= 
CheckAndInferShape(&g, std::move(shapes), false, node_range, entry_range); + match &= CheckAndInferType(&g, std::move(dtypes), false, node_range, entry_range); exec::DevMaskVector dev_mask(idx.num_nodes(), default_ctx.dev_mask()); - match &= CheckAndInferStorageType(&g, std::move(dev_mask), std::move(stypes), - false, node_range, entry_range); + match &= CheckAndInferStorageType( + &g, std::move(dev_mask), std::move(stypes), false, node_range, entry_range); if (!match) { g.attrs.erase(AddPrefix(BACKWARD, MEM_PLAN)); @@ -345,26 +345,29 @@ bool CachedOp::SetBackwardGraph( StorageVector storage(idx.num_node_entries(), exec::kBadStorageID); const auto& bwd_stypes = g.GetAttr("storage_type"); for (size_t i = 0; i < bwd_stypes.size(); i++) { - if (bwd_stypes[i] != kDefaultStorage) storage[i] = exec::kDynamicStorageID; - } - for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID; - for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; - for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID; - - auto mem_plan = MXPlanMemory( - &g, std::move(storage), - g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)), - AddPrefix(BACKWARD, STORAGE_PLAN), - {num_forward_nodes, idx.num_nodes()}, - {num_forward_entries, idx.num_node_entries()}, - detect_inplace_addto); + if (bwd_stypes[i] != kDefaultStorage) + storage[i] = exec::kDynamicStorageID; + } + for (size_t i = 0; i < num_forward_entries; ++i) + storage[i] = exec::kExternalStorageID; + for (const auto i : idx.input_nodes()) + storage[idx.entry_id(i, 0)] = exec::kExternalStorageID; + for (const auto i : idx.outputs()) + storage[idx.entry_id(i)] = exec::kExternalStorageID; + + auto mem_plan = MXPlanMemory(&g, + std::move(storage), + g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)), + AddPrefix(BACKWARD, STORAGE_PLAN), + {num_forward_nodes, idx.num_nodes()}, + {num_forward_entries, idx.num_node_entries()}, + detect_inplace_addto); g.attrs[AddPrefix(BACKWARD, MEM_PLAN)] = std::make_shared(std::move(mem_plan)); return false; } -OpStatePtr CachedOp::GetCachedOpState( - const Context& ctx) { +OpStatePtr CachedOp::GetCachedOpState(const Context& ctx) { std::lock_guard lock(mutex_); for (const auto& i : cached_op_states_[ctx]) { // only create one state per device when not using static memory @@ -372,52 +375,50 @@ OpStatePtr CachedOp::GetCachedOpState( return i; } } - auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_, - inlining_); + auto state_ptr = OpStatePtr::Create(ctx, fwd_graph_, full_graph_, inlining_); cached_op_states_[ctx].push_back(state_ptr); return state_ptr; } -void CachedOp::StaticAllocMemory( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd) { +void CachedOp::StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bool keep_fwd) { using namespace nnvm; using namespace imperative; - auto& state = state_ptr.get_state(); - const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); + auto& state = state_ptr.get_state(); + const auto& default_ctx = state.context; + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); const std::string& graph_type = keep_fwd ? BACKWARD : (recording ? 
FULL : FORWARD); const auto& storage_plan_attr = AddPrefix(graph_type, STORAGE_PLAN); - const auto& storage_plan = g.GetAttr >(storage_plan_attr); - const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); + const auto& storage_plan = g.GetAttr >(storage_plan_attr); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); std::vector addto_entry; if (g.attrs.count("addto_entry")) { addto_entry = g.GetAttr >("addto_entry"); } - size_t start_eid = - keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0; - size_t end_eid = idx.num_node_entries(); + size_t start_eid = keep_fwd ? state.info.fwd_graph.indexed_graph().num_node_entries() : 0; + size_t end_eid = idx.num_node_entries(); - if (!keep_fwd) state.fwd_alloc = false; + if (!keep_fwd) + state.fwd_alloc = false; state.bwd_alloc = false; for (size_t i = start_eid; i < state.buff.size(); ++i) { - state.buff[i] = NDArray(); - state.arrays[i] = &state.buff[i]; - state.array_reqs[i] = kNullOp; + state.buff[i] = NDArray(); + state.arrays[i] = &state.buff[i]; + state.array_reqs[i] = kNullOp; state.dynamic_entries[i] = false; } for (auto i : idx.input_nodes()) { auto eid = idx.entry_id(i, 0); - if (eid >= start_eid) state.dynamic_entries[eid] = true; + if (eid >= start_eid) + state.dynamic_entries[eid] = true; } for (auto i : idx.outputs()) { auto eid = idx.entry_id(i); - if (eid >= start_eid) state.dynamic_entries[eid] = true; + if (eid >= start_eid) + state.dynamic_entries[eid] = true; } for (size_t i = start_eid; i < end_eid; ++i) { @@ -434,9 +435,15 @@ void CachedOp::StaticAllocMemory( } auto& reuse_pool = keep_fwd ? state.bwd_reuse_pool : state.fwd_reuse_pool; - reuse_pool = imperative::AllocateMemory( - g, idx, default_ctx, start_eid, end_eid, mem_plan, - state.arrays, &state.array_reqs, std::move(reuse_pool)); + reuse_pool = imperative::AllocateMemory(g, + idx, + default_ctx, + start_eid, + end_eid, + mem_plan, + state.arrays, + &state.array_reqs, + std::move(reuse_pool)); state.recording = recording; if (keep_fwd) { @@ -446,26 +453,23 @@ void CachedOp::StaticAllocMemory( } } -void CachedOp::StaticInitExec( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd) { +void CachedOp::StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool keep_fwd) { using namespace nnvm; using namespace imperative; - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); const auto& default_ctx = state.context; - nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; - const auto& idx = g.indexed_graph(); + nnvm::Graph& g = keep_fwd ? state.info.full_graph : state.info.fwd_graph; + const auto& idx = g.indexed_graph(); std::vector skip_plus_node; if (g.attrs.count("skip_plus_node")) { skip_plus_node = g.GetAttr >("skip_plus_node"); } - size_t start_nid = - keep_fwd ? state.info.fwd_graph.indexed_graph().num_nodes() : 0; - size_t end_nid = idx.num_nodes(); + size_t start_nid = keep_fwd ? 
state.info.fwd_graph.indexed_graph().num_nodes() : 0; + size_t end_nid = idx.num_nodes(); - if (!keep_fwd) state.fwd_exec_init = false; + if (!keep_fwd) + state.fwd_exec_init = false; state.bwd_exec_init = false; for (size_t i = start_nid; i < state.execs.size(); ++i) { @@ -476,7 +480,7 @@ void CachedOp::StaticInitExec( if (!config_.static_shape) { for (size_t i = start_nid; i < end_nid; ++i) { state.opr_segs[i].next_nid = i + 1; - state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; + state.opr_segs[i].skip = skip_plus_node.size() && skip_plus_node[i]; } } else { for (size_t i = start_nid; i < end_nid; ++i) { @@ -492,7 +496,8 @@ void CachedOp::StaticInitExec( for (size_t j = 0; !skip && j < idx[i].source->num_outputs(); ++j) { skip = state.dynamic_entries[idx.entry_id(i, j)]; } - if (skip) continue; + if (skip) + continue; SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs); } @@ -510,8 +515,14 @@ void CachedOp::StaticInitExec( bulk_size = 0; } - CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, - state.execs, skip_plus_node, &state.opr_segs); + CreateEngineOpSeg(idx, + default_ctx, + start_nid, + end_nid, + bulk_size, + state.execs, + skip_plus_node, + &state.opr_segs); } if (keep_fwd) { @@ -521,22 +532,21 @@ void CachedOp::StaticInitExec( } } -void CachedOp::StaticRunOps( - const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - const std::vector &state_arrays, - size_t start_nid, - size_t end_nid) { - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); +void CachedOp::StaticRunOps(const Context& default_ctx, + const nnvm::Graph& g, + const OpStatePtr& state_ptr, + const std::vector& state_arrays, + size_t start_nid, + size_t end_nid) { + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; + bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning; bool is_training = Imperative::Get()->is_training(); - auto& state = state_ptr.get_state(); - const auto& idx = g.indexed_graph(); + auto& state = state_ptr.get_state(); + const auto& idx = g.indexed_graph(); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - const auto& op_execs = state.execs; + const auto& op_execs = state.execs; std::vector ndinputs, ndoutputs; mxnet::ShapeVector arg_shapes; @@ -544,17 +554,20 @@ void CachedOp::StaticRunOps( std::vector req; for (size_t i = start_nid; config_.static_shape && i < end_nid; ++i) { - if (op_execs[i]) op_execs[i]->op_ctx.is_train = is_training; + if (op_execs[i]) + op_execs[i]->op_ctx.is_train = is_training; } for (size_t i = start_nid; i < end_nid; i = state.opr_segs[i].next_nid) { const auto& opr_seg = state.opr_segs[i]; - if (opr_seg.skip) continue; + if (opr_seg.skip) + continue; if (opr_seg.opr != nullptr) { Engine::Get()->Push(opr_seg.opr.get(), default_ctx, 0, profiling); } else { const nnvm::IndexedGraph::Node& node = idx[i]; - if (node.source->is_variable()) continue; + if (node.source->is_variable()) + continue; auto num_outputs = node.source->num_outputs(); ndinputs.clear(); ndinputs.reserve(node.inputs.size()); @@ -563,7 +576,7 @@ void CachedOp::StaticRunOps( CHECK(!ndinputs.back()->is_none()); } if (monitor_callback_ && monitor_all_) { - mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, monitor_callback_); + mxnet::common::ExecuteMonInputCallback(idx, state_arrays, i, 
monitor_callback_); } ndoutputs.clear(); ndoutputs.reserve(num_outputs); @@ -590,39 +603,50 @@ void CachedOp::StaticRunOps( state.op_states[i] = createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes); } - Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, req, - dispatch_mode, state.op_states[i]); + Imperative::Get()->InvokeOp(default_ctx, + node.source->attrs, + ndinputs, + ndoutputs, + req, + dispatch_mode, + state.op_states[i]); } else if (is_layer_backward.get(node.source->op(), false)) { nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); - Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, state.op_states[fwd_node_id]); + auto fwd_node_id = idx.node_id(fwd_node); + Imperative::Get()->InvokeOp(default_ctx, + node.source->attrs, + ndinputs, + ndoutputs, + req, + dispatch_mode, + state.op_states[fwd_node_id]); } else { Imperative::Get()->InvokeOp( - default_ctx, node.source->attrs, ndinputs, ndoutputs, req, - dispatch_mode); + default_ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode); } if (monitor_callback_) { - mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_); + mxnet::common::ExecuteMonOutputCallback(idx, state_arrays, i, monitor_callback_); } } } } -#define INIT_DETACHED(x, y) if (!y->is_none()) x->InitDetached(y) +#define INIT_DETACHED(x, y) \ + if (!y->is_none()) \ + x->InitDetached(y) -static void PrepareOutputs(const nnvm::Graph& g, const Context& default_ctx, - const std::vector &outputs, - std::vector *pArrays, bool detach) { +static void PrepareOutputs(const nnvm::Graph& g, + const Context& default_ctx, + const std::vector& outputs, + std::vector* pArrays, + bool detach) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); const auto& shapes = g.GetAttr("shape"); const auto& stypes = g.GetAttr("storage_type"); const auto& idx = g.indexed_graph(); - auto &arrays = *pArrays; + auto& arrays = *pArrays; for (size_t i = 0; i < outputs.size(); ++i) { const auto eid = idx.entry_id(idx.outputs()[i]); // An input and an output may share the same array. @@ -631,24 +655,22 @@ static void PrepareOutputs(const nnvm::Graph& g, const Context& default_ctx, arrays[eid] = outputs[i]; if (arrays[eid]->is_none()) - arrays[eid]->ReInit(static_cast(stypes[eid]), - shapes[eid], default_ctx, dtypes[eid]); + arrays[eid]->ReInit( + static_cast(stypes[eid]), shapes[eid], default_ctx, dtypes[eid]); const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs; - outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), - attrs.name); + outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); } } -OpStatePtr CachedOp::StaticForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs) { +OpStatePtr CachedOp::StaticForward(const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; bool recording = Imperative::Get()->is_recording(); auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); // Need to lock the mutex on the state, this allows // for multi context push of ops to dependency engine. 
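// Illustrative sketch (not part of the diff): the CachedOp code above caches one op
// state per context and locks a per-state mutex before touching it, so work for
// different devices can be pushed concurrently. A simplified standalone analogue,
// with a hypothetical State type keyed by a device id:
#include <memory>
#include <mutex>
#include <unordered_map>

struct State { int dev_id; /* buffers, graphs, ... */ };

class StateCache {
 public:
  std::shared_ptr<State> GetOrCreate(int dev_id) {
    std::lock_guard<std::mutex> lock(mutex_);  // serialize state creation only
    auto& slot = states_[dev_id];
    if (!slot)
      slot = std::make_shared<State>(State{dev_id});
    return slot;                               // callers then lock the state itself
  }

 private:
  std::mutex mutex_;
  std::unordered_map<int, std::shared_ptr<State>> states_;
};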
@@ -658,11 +680,11 @@ OpStatePtr CachedOp::StaticForward( std::lock_guard lock(state.mutex); bool match = SetForwardGraph(default_ctx, &state.info, recording, inputs); - match = match && state.recording == recording; + match = match && state.recording == recording; - nnvm::Graph& g = state.info.fwd_graph; + nnvm::Graph& g = state.info.fwd_graph; const auto& idx = g.indexed_graph(); - if (!state.fwd_alloc || !match) { + if (!state.fwd_alloc || !match) { StaticAllocMemory(state_ptr, recording, false); } @@ -670,25 +692,25 @@ OpStatePtr CachedOp::StaticForward( // The input and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. state.arrays_with_in_out = state.arrays; - auto& arrays = state.arrays_with_in_out; + auto& arrays = state.arrays_with_in_out; if (config_.static_shape) { for (auto i : config_.param_indices) { auto nid = idx.input_nodes()[i]; if (!arrays[idx.entry_id(nid, 0)]->IsSame(*inputs[state.info.input_map[i]])) { - match = false; + match = false; auto ptr = &state.buff[idx.entry_id(nid, 0)]; CHECK_EQ(arrays[idx.entry_id(nid, 0)], ptr); - *arrays[idx.entry_id(nid, 0)] = *inputs[state.info.input_map[i]]; + *arrays[idx.entry_id(nid, 0)] = *inputs[state.info.input_map[i]]; state.dynamic_entries[idx.entry_id(nid, 0)] = false; } } for (auto i : config_.data_indices) { - auto eid = idx.entry_id(idx.input_nodes()[i], 0); + auto eid = idx.entry_id(idx.input_nodes()[i], 0); arrays[eid] = inputs[state.info.input_map[i]]; } } else { for (size_t i = 0; i < num_inputs(); ++i) { - auto nid = idx.input_nodes()[i]; + auto nid = idx.input_nodes()[i]; arrays[idx.entry_id(nid, 0)] = inputs[state.info.input_map[i]]; } } @@ -703,31 +725,29 @@ OpStatePtr CachedOp::StaticForward( return recording ? state_ptr : OpStatePtr(); } - -OpStatePtr CachedOp::DynamicForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs, - bool use_naive_run) { +OpStatePtr CachedOp::DynamicForward(const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs, + bool use_naive_run) { using namespace nnvm; using namespace imperative; // Initialize bool recording = Imperative::Get()->is_recording(); - auto op_state = OpStatePtr::Create(); - auto& runtime = op_state.get_state(); + auto op_state = OpStatePtr::Create(); + auto& runtime = op_state.get_state(); { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); SetForwardGraph(default_ctx, &state.info, recording, inputs); runtime.info.fwd_graph = state.info.fwd_graph; runtime.info.input_map = state.info.input_map; } - nnvm::Graph& g = runtime.info.fwd_graph; + nnvm::Graph& g = runtime.info.fwd_graph; const auto& idx = g.indexed_graph(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; // Allocate entries buff.resize(idx.num_node_entries()); @@ -738,32 +758,54 @@ OpStatePtr CachedOp::DynamicForward( arrays.push_back(&buffered_array); } std::vector array_reqs(arrays.size(), kWriteTo); - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); const std::string& graph_type = recording ? 
FULL : FORWARD; std::vector ref_count = - g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); + g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; + if (ref_count[i] == 0) + array_reqs[i] = kNullOp; } CollectInputOutputNDRefs(g, inputs, runtime.info.input_map, outputs, &arrays); if (!use_naive_run) { - const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); CreateGraphNDs(g, default_ctx, mem_plan, &array_reqs, &arrays); // If CachedOp is running in the inline mode, it uses RunGraph to record // computation; otherwise, CachedOp records computation itself. // So if it's not the inline mode, we disable recording. - RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes, - recording && inlining_, nullptr, monitor_callback_, monitor_all_); + RunGraph(false, + idx, + arrays, + 0, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + recording && inlining_, + nullptr, + monitor_callback_, + monitor_all_); } else { mxnet::ShapeVector shapes = g.GetAttr("shape"); - NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, - dispatch_modes, recording && inlining_, &shapes, monitor_callback_, monitor_all_); + NaiveRunGraph(false, + default_ctx, + idx, + arrays, + 0, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + recording && inlining_, + &shapes, + monitor_callback_, + monitor_all_); { - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); auto copied_shape = shapes; std::lock_guard lock(state.mutex); state.info.fwd_graph.attrs["shape"] = std::make_shared(std::move(copied_shape)); @@ -773,23 +815,22 @@ OpStatePtr CachedOp::DynamicForward( return op_state; } -OpStatePtr CachedOp::Forward( - const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_ctx) { +OpStatePtr CachedOp::Forward(const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs, + const Context& default_ctx) { static const auto cached_op = nnvm::Op::Get("_CachedOp"); CHECK_EQ(inputs.size(), num_inputs()); // Assign the storage information for the input arguments. Similar to the // implementation in `graph_executor.cc`, we use `mutable_input_nodes()` to // distinguish between weight parameters and auxiliary states. 
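// Illustrative sketch (not part of the diff): the storage-info loop above decides
// whether each forward input is an auxiliary state or an argument by checking the
// graph's mutable-input set. A standalone analogue with hypothetical label strings:
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

// Returns "aux_state:" or "arg:" prefixed labels, one per input node.
std::vector<std::string> LabelInputs(const std::vector<uint32_t>& input_nodes,
                                     const std::unordered_set<uint32_t>& mutable_nodes,
                                     const std::vector<std::string>& names) {
  std::vector<std::string> labels;
  labels.reserve(input_nodes.size());
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    const bool is_aux = mutable_nodes.count(input_nodes[i]) != 0;
    labels.push_back((is_aux ? "aux_state:" : "arg:") + names[i]);
  }
  return labels;
}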
- const auto& fwd_idx = fwd_graph_.indexed_graph(); + const auto& fwd_idx = fwd_graph_.indexed_graph(); const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes(); for (size_t i = 0; i < fwd_idx.input_nodes().size(); ++i) { - const uint32_t nid = fwd_idx.input_nodes().at(i); - const nnvm::NodeAttrs& attrs = fwd_idx[nid].source->attrs; - const std::string& arg_name = attrs.name; + const uint32_t nid = fwd_idx.input_nodes().at(i); + const nnvm::NodeAttrs& attrs = fwd_idx[nid].source->attrs; + const std::string& arg_name = attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs); if (mutable_input_nodes.count(nid)) { inputs[i]->AssignStorageInfo(profiler_scope + "aux_state:", arg_name); @@ -800,16 +841,14 @@ OpStatePtr CachedOp::Forward( { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); const auto& idx = state.info.fwd_graph.indexed_graph(); for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx(), default_ctx) << "CachedOp requires all inputs to live on the same context. But " - << idx[idx.input_nodes()[0]].source->attrs.name - << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name - << " is on " << inputs[i]->ctx(); + << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name << " is on " << inputs[i]->ctx(); } } @@ -818,9 +857,9 @@ OpStatePtr CachedOp::Forward( OpStatePtr op_state; try { if (config_.is_dynamic || CheckDynamicShapeExists(default_ctx, inputs, true)) { - config_.is_dynamic = true; + config_.is_dynamic = true; config_.static_alloc = false; - op_state = DynamicForward(default_ctx, inputs, outputs, true); + op_state = DynamicForward(default_ctx, inputs, outputs, true); } else if (config_.static_alloc) { op_state = StaticForward(default_ctx, inputs, outputs); } else { @@ -835,45 +874,43 @@ OpStatePtr CachedOp::Forward( if (Imperative::Get()->is_recording() && !inlining_) { nnvm::NodeAttrs attrs; - attrs.op = cached_op; - attrs.name = "_cachedop"; + attrs.op = cached_op; + attrs.name = "_cachedop"; attrs.parsed = op_ptr; Imperative::Get()->RecordOp( - std::move(attrs), inputs, outputs, op_state, - &save_inputs(), &save_outputs()); + std::move(attrs), inputs, outputs, op_state, &save_inputs(), &save_outputs()); } return op_state; } -void CachedOp::DynamicBackward( - const bool retain_graph, - const OpStatePtr& op_state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { +void CachedOp::DynamicBackward(const bool retain_graph, + const OpStatePtr& op_state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; // Initialize Context default_ctx = outputs[0]->ctx(); - auto& runtime = op_state.get_state(); + auto& runtime = op_state.get_state(); { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); state.info.fwd_graph = runtime.info.fwd_graph; state.info.input_map = runtime.info.input_map; SetBackwardGraph(&state.info, reqs, inputs); - runtime.info.full_graph = state.info.full_graph; + runtime.info.full_graph = state.info.full_graph; runtime.info.bwd_input_eid = state.info.bwd_input_eid; } - nnvm::Graph& g = runtime.info.full_graph; + nnvm::Graph& g = runtime.info.full_graph; const auto& idx = g.indexed_graph(); - auto& 
buff = runtime.buff; - auto& states = runtime.op_states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; size_t num_forward_outputs = runtime.info.fwd_graph.outputs.size(); - size_t num_forward_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); + size_t num_forward_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); size_t num_forward_entries = runtime.info.fwd_graph.indexed_graph().num_node_entries(); buff.resize(idx.num_node_entries()); std::vector arrays; @@ -888,7 +925,8 @@ void CachedOp::DynamicBackward( arrays[runtime.info.bwd_input_eid[i]] = inputs[BwdOriginalInput(runtime.info.input_map, i)]; } for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { - if (reqs[i] == kNullOp) continue; + if (reqs[i] == kNullOp) + continue; const auto eid = idx.entry_id(idx.outputs()[j++]); // An input and an output may share the same array. INIT_DETACHED(outputs[i], arrays[eid]); @@ -898,29 +936,47 @@ void CachedOp::DynamicBackward( // Allocate NDArrays auto ref_count = g.GetAttr >(AddPrefix(BACKWARD, REF_COUNT)); if (retain_graph) { - for (size_t i = 0; i < num_forward_entries; ++i) ++ref_count[i]; + for (size_t i = 0; i < num_forward_entries; ++i) + ++ref_count[i]; } std::vector array_reqs(arrays.size(), kWriteTo); // set output reqs for (size_t i = 0, j = num_forward_outputs; i < reqs.size(); ++i) { - if (reqs[i] == kNullOp) continue; + if (reqs[i] == kNullOp) + continue; array_reqs[idx.entry_id(idx.outputs()[j++])] = reqs[i]; } // set null reqs based on ref counts for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; + if (ref_count[i] == 0) + array_reqs[i] = kNullOp; } - const auto& mem_plan = g.GetAttr(AddPrefix(BACKWARD, MEM_PLAN)); - AllocateMemory(g, idx, default_ctx, num_forward_entries, idx.num_node_entries(), - mem_plan, arrays, &array_reqs); + const auto& mem_plan = g.GetAttr(AddPrefix(BACKWARD, MEM_PLAN)); + AllocateMemory(g, + idx, + default_ctx, + num_forward_entries, + idx.num_node_entries(), + mem_plan, + arrays, + &array_reqs); const auto& dispatch_modes = g.GetAttr("dispatch_mode"); - RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, - Imperative::Get()->is_recording(), nullptr, monitor_callback_); + RunGraph(retain_graph, + idx, + arrays, + num_forward_nodes, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + Imperative::Get()->is_recording(), + nullptr, + monitor_callback_); if (retain_graph) { buff.resize(num_forward_entries); @@ -930,12 +986,11 @@ void CachedOp::DynamicBackward( } } -void CachedOp::StaticBackward( - const bool retain_graph, - const OpStatePtr& state_ptr, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { +void CachedOp::StaticBackward(const bool retain_graph, + const OpStatePtr& state_ptr, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { using namespace nnvm; using namespace imperative; @@ -946,8 +1001,8 @@ void CachedOp::StaticBackward( bool match = SetBackwardGraph(&state.info, reqs, inputs, true); - nnvm::Graph& g = state.info.full_graph; - const auto& idx = g.indexed_graph(); + nnvm::Graph& g = state.info.full_graph; + const auto& idx = g.indexed_graph(); auto num_forward_nodes = state.info.fwd_graph.indexed_graph().num_nodes(); if (!state.bwd_alloc || !match) { @@ -958,37 +1013,41 @@ void CachedOp::StaticBackward( // The input 
and output arrays should only be valid for this run, // so we shouldn't modify the state's array list. state.arrays_with_in_out = state.arrays; - auto& arrays = state.arrays_with_in_out; + auto& arrays = state.arrays_with_in_out; for (size_t i = 0; i < state.info.bwd_input_eid.size(); ++i) { auto eid = state.info.bwd_input_eid[i]; - if (eid == kEidNotExist || !state.dynamic_entries[eid]) continue; + if (eid == kEidNotExist || !state.dynamic_entries[eid]) + continue; arrays[eid] = inputs[BwdOriginalInput(state.info.input_map, i)]; } if (config_.static_shape) { for (auto i : config_.param_indices) { const auto iter = state.info.fwd_input_to_grad_output.find(i); - if (iter == state.info.fwd_input_to_grad_output.end()) continue; + if (iter == state.info.fwd_input_to_grad_output.end()) + continue; auto entry = state.info.grad_graph.outputs[iter->second]; - if (!idx.exist(entry.node.get())) continue; + if (!idx.exist(entry.node.get())) + continue; auto eid = idx.entry_id(entry); - if ((!arrays[eid]->IsSame(*outputs[iter->second]) && - state.array_reqs[eid] != kNullOp) || + if ((!arrays[eid]->IsSame(*outputs[iter->second]) && state.array_reqs[eid] != kNullOp) || !(state.array_reqs[eid] == reqs[iter->second])) { - match = false; + match = false; state.array_reqs[eid] = reqs[iter->second]; // An input and an output may share the same array. INIT_DETACHED(outputs[iter->second], arrays[eid]); - *arrays[eid] = *outputs[iter->second]; + *arrays[eid] = *outputs[iter->second]; state.dynamic_entries[eid] = false; } } for (auto i : config_.data_indices) { const auto iter = state.info.fwd_input_to_grad_output.find(i); - if (iter == state.info.fwd_input_to_grad_output.end()) continue; + if (iter == state.info.fwd_input_to_grad_output.end()) + continue; auto entry = state.info.grad_graph.outputs[iter->second]; - if (!idx.exist(entry.node.get())) continue; - auto eid = idx.entry_id(entry); + if (!idx.exist(entry.node.get())) + continue; + auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[iter->second]; // An input and an output may share the same array. INIT_DETACHED(outputs[iter->second], arrays[eid]); @@ -997,8 +1056,9 @@ void CachedOp::StaticBackward( } else { for (size_t i = 0; i < state.info.grad_graph.outputs.size(); ++i) { auto entry = state.info.grad_graph.outputs[i]; - if (!idx.exist(entry.node.get())) continue; - auto eid = idx.entry_id(entry); + if (!idx.exist(entry.node.get())) + continue; + auto eid = idx.entry_id(entry); state.array_reqs[eid] = reqs[i]; // An input and an output may share the same array. 
INIT_DETACHED(outputs[i], arrays[eid]); @@ -1013,17 +1073,16 @@ void CachedOp::StaticBackward( StaticRunOps(default_ctx, g, state_ptr, arrays, num_forward_nodes, idx.num_nodes()); } -void CachedOp::Backward( - const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs) { - const auto& fwd_idx = fwd_graph_.indexed_graph(); - const auto& full_idx = full_graph_.indexed_graph(); +void CachedOp::Backward(const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs) { + const auto& fwd_idx = fwd_graph_.indexed_graph(); + const auto& full_idx = full_graph_.indexed_graph(); const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes(); for (size_t i = 0, j = 0; i < fwd_idx.input_nodes().size(); ++i) { - const uint32_t nid = fwd_idx.input_nodes().at(i); + const uint32_t nid = fwd_idx.input_nodes().at(i); const std::string& arg_name = fwd_idx[nid].source->attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(fwd_idx[nid].source->attrs); @@ -1032,10 +1091,9 @@ void CachedOp::Backward( } outputs[j++]->AssignStorageInfo(profiler_scope + "arg_grad:", arg_name); } - for (size_t i = fwd_idx.input_nodes().size(), j = 0; - i < full_idx.input_nodes().size(); ++i) { - const nnvm::NodeAttrs& attrs = full_idx[full_idx.input_nodes().at(i)].source->attrs; - const std::string& entry_name = attrs.name; + for (size_t i = fwd_idx.input_nodes().size(), j = 0; i < full_idx.input_nodes().size(); ++i) { + const nnvm::NodeAttrs& attrs = full_idx[full_idx.input_nodes().at(i)].source->attrs; + const std::string& entry_name = attrs.name; const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs); inputs[j++]->AssignStorageInfo(profiler_scope, entry_name); } @@ -1089,11 +1147,11 @@ void CachedOpForward(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CachedOpActualState &s = state_ptr.get_state(); - std::vector in_bufs = inputs; + CachedOpActualState& s = state_ptr.get_state(); + std::vector in_bufs = inputs; std::vector out_bufs = outputs; - std::vector in_ptrs(in_bufs.size()); - std::vector out_ptrs(out_bufs.size()); + std::vector in_ptrs(in_bufs.size()); + std::vector out_ptrs(out_bufs.size()); for (size_t i = 0; i < in_ptrs.size(); i++) in_ptrs[i] = &in_bufs[i]; for (size_t i = 0; i < out_ptrs.size(); i++) @@ -1113,7 +1171,7 @@ void CachedOpForward(const OpStatePtr& state_ptr, orig_is_train = Imperative::Get()->is_training(); CHECK(inputs.size() > 0) << "cached op forward requires at least 1 input"; Context default_ctx = inputs[0].ctx(); - s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); + s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); Imperative::Get()->set_is_training(orig_is_train); Imperative::Get()->set_is_recording(orig_is_record); // The arrays in out_ptrs may be changed by CachedOp. 
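// Illustrative sketch (not part of the diff): CachedOpForward above manually saves the
// imperative is_training / is_recording flags, runs Forward(), and restores them
// afterwards. The same save/restore idea expressed as a small RAII guard (hypothetical
// FlagState struct and RunForward() call, standalone code):
struct FlagState {
  bool is_training  = false;
  bool is_recording = false;
};

class FlagGuard {
 public:
  explicit FlagGuard(FlagState* state) : state_(state), saved_(*state) {}
  ~FlagGuard() { *state_ = saved_; }  // restore on scope exit, even if Forward throws

  FlagGuard(const FlagGuard&) = delete;
  FlagGuard& operator=(const FlagGuard&) = delete;

 private:
  FlagState* state_;
  FlagState saved_;
};

// Usage: { FlagGuard g(&flags); flags.is_training = true; RunForward(); }  // flags restored here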
@@ -1134,29 +1192,29 @@ void CachedOpBackward(const OpStatePtr& state_ptr, const std::vector& outputs) { using namespace nnvm; using namespace imperative; - CachedOpActualState &s = state_ptr.get_state(); - std::vector in_bufs = inputs; + CachedOpActualState& s = state_ptr.get_state(); + std::vector in_bufs = inputs; std::vector out_bufs = outputs; - std::vector in_ptrs; - std::vector out_ptrs; + std::vector in_ptrs; + std::vector out_ptrs; CHECK_EQ(s.op->num_backward_inputs(), inputs.size()); in_ptrs.reserve(s.op->num_backward_inputs()); out_ptrs.reserve(s.op->num_inputs()); - const std::vector &save_inputs = s.op->save_inputs(); - const std::vector &save_outputs = s.op->save_outputs(); - size_t bwd_in_dep = s.op->num_inputs(); - size_t bwd_out_dep = s.op->num_outputs(); + const std::vector& save_inputs = s.op->save_inputs(); + const std::vector& save_outputs = s.op->save_outputs(); + size_t bwd_in_dep = s.op->num_inputs(); + size_t bwd_out_dep = s.op->num_outputs(); CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep); size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep; // Find inputs, outputs and ograds auto ograds_begin = in_bufs.begin(); - auto ograds_end = in_bufs.begin() + bwd_ograd_dep; - auto in_begin = ograds_end; - auto in_end = in_begin + bwd_in_dep; - auto out_begin = in_end; - auto out_end = in_bufs.end(); + auto ograds_end = in_bufs.begin() + bwd_ograd_dep; + auto in_begin = ograds_end; + auto in_end = in_begin + bwd_in_dep; + auto out_begin = in_end; + auto out_end = in_bufs.end(); for (auto it = ograds_begin; it != ograds_end; it++) in_ptrs.push_back(&(*it)); @@ -1209,11 +1267,10 @@ void CachedOpBackward(const OpStatePtr& state_ptr, /* * Register the callback to be called when the operator is executed */ -void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, - bool monitor_all) { - CHECK(callback) << "invalid callback"; - monitor_callback_ = callback; - monitor_all_ = monitor_all; +void CachedOp::RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, bool monitor_all) { + CHECK(callback) << "invalid callback"; + monitor_callback_ = callback; + monitor_all_ = monitor_all; } OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, @@ -1227,19 +1284,19 @@ OpStatePtr CreateCachedOpState(const NodeAttrs& attrs, bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { using namespace imperative; nnvm::Graph g(full_graph_); - const auto& idx = g.indexed_graph(); - const auto &outputs = idx.outputs(); + const auto& idx = g.indexed_graph(); + const auto& outputs = idx.outputs(); const size_t num_forward_outputs = fwd_graph_.outputs.size(); CHECK_EQ(outputs.size(), num_forward_outputs + out_attrs->size()); // Construct bwd_input_eid std::vector bwd_input_eid; - SetBackwardInputEid(bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, - ograd_entries_, idx, &bwd_input_eid); + SetBackwardInputEid( + bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_, ograd_entries_, idx, &bwd_input_eid); CHECK_EQ(in_attrs->size(), bwd_input_eid.size()); // Prepare stypes and contexts based on inputs @@ -1305,92 +1362,98 @@ size_t CachedOp::BwdOriginalInput(const std::vector& input_map, size_t n } NNVM_REGISTER_OP(_CachedOp) -.set_num_inputs([](const NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_inputs(); - }) -.set_num_outputs([](const NodeAttrs& attrs) { - 
const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_outputs(); - }) -.set_attr_parser(CachedOpParamParser) -.set_attr("FGradient", - [](const nnvm::ObjectPtr& n, const std::vector& ograds) { - const CachedOpPtr& op = nnvm::get(n->attrs.parsed); - return op->Gradient(n, ograds); - }) -.set_attr("FListInputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ListForwardInputNames(); - }) -.set_attr("FListOutputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->ListForwardOutputNames(); - }) -.set_attr("FCreateOpState", CreateCachedOpState) -.set_attr("FInferShape", - [](const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shapes, - mxnet::ShapeVector *out_shapes) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes); - }) -.set_attr("FInferType", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_types, - std::vector *out_types) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); - }) -.set_attr("FInferStorageType", - [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(), - dev_mask, dispatch_mode, - in_stypes, out_stypes); - }) -.set_attr("FStatefulComputeEx", CachedOpForward) -.set_attr("FStatefulComputeEx", CachedOpForward) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym()); - }) -.set_attr("FResourceRequest", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym()); - }) -.set_attr("FExecType", op::DefaultSubgraphOpExecType) -.add_argument("data", "NDArray-or-Symbol[]", "input data list"); + .set_num_inputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_inputs(); + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_outputs(); + }) + .set_attr_parser(CachedOpParamParser) + .set_attr("FGradient", + [](const nnvm::ObjectPtr& n, + const std::vector& ograds) { + const CachedOpPtr& op = nnvm::get(n->attrs.parsed); + return op->Gradient(n, ograds); + }) + .set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->ListForwardInputNames(); + }) + .set_attr("FListOutputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = + nnvm::get(attrs.parsed); + return op->ListForwardOutputNames(); + }) + .set_attr("FCreateOpState", CreateCachedOpState) + .set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpShapeHelper( + op->GetForwardSym(), in_shapes, out_shapes); + }) + .set_attr( + "FInferType", + [](const nnvm::NodeAttrs& attrs, std::vector* in_types, std::vector* out_types) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); + }) 
+ .set_attr( + "FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpStorageTypeHelper( + op->GetForwardSym(), dev_mask, dispatch_mode, in_stypes, out_stypes); + }) + .set_attr("FStatefulComputeEx", CachedOpForward) + .set_attr("FStatefulComputeEx", CachedOpForward) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpMutableInputsHelper( + op->GetForwardSym()); + }) + .set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpResourceRequestHelper( + op->GetForwardSym()); + }) + .set_attr("FExecType", op::DefaultSubgraphOpExecType) + .add_argument("data", "NDArray-or-Symbol[]", "input data list"); NNVM_REGISTER_OP(_backward_CachedOp) -.set_num_inputs([](const NodeAttrs& attrs){ - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_backward_inputs(); - }) -.set_num_outputs([](const NodeAttrs& attrs){ - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->num_inputs() - op->mutable_input_nodes().size(); - }) -.set_attr("FInferStorageType", [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const CachedOpPtr& op = nnvm::get(attrs.parsed); - return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); - }) -.set_attr("FStatefulComputeEx", CachedOpBackward) -.set_attr("FStatefulComputeEx", CachedOpBackward) -.set_attr("FExecType", op::DefaultSubgraphOpExecType) -.set_attr("TIsLayerOpBackward", true) -.set_attr("TIsBackward", true); + .set_num_inputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_backward_inputs(); + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->num_inputs() - op->mutable_input_nodes().size(); + }) + .set_attr("FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const CachedOpPtr& op = nnvm::get(attrs.parsed); + return op->BackwardStorageType( + attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); + }) + .set_attr("FStatefulComputeEx", CachedOpBackward) + .set_attr("FStatefulComputeEx", CachedOpBackward) + .set_attr("FExecType", op::DefaultSubgraphOpExecType) + .set_attr("TIsLayerOpBackward", true) + .set_attr("TIsBackward", true); } // namespace mxnet diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 4ac9a0f52497..97ac23cc3a11 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -36,41 +36,41 @@ namespace mxnet { namespace { - static const char FULL[] = "full"; - static const char FORWARD[] = "forward"; - static const char BACKWARD[] = "backward"; - static const char REF_COUNT[] = "ref_count"; - static const char MEM_PLAN[] = "mem_plan"; - static const char STORAGE_PLAN[] = "storage_plan"; - -std::string AddPrefix(const std::string& prefix, - const std::string& s) { +static const char FULL[] = "full"; +static const char FORWARD[] = "forward"; +static const char BACKWARD[] = "backward"; +static const char REF_COUNT[] = "ref_count"; +static const char MEM_PLAN[] = "mem_plan"; +static const char 
STORAGE_PLAN[] = "storage_plan"; + +std::string AddPrefix(const std::string& prefix, const std::string& s) { return prefix + "_" + s; } nnvm::NodeEntry AggregateGradient(std::vector&& v) { using nnvm::Op; - static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8); + static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8); static const Op* ewise_plus_op = Op::Get("_grad_add"); - static const Op* ewise_sum_op = Op::Get("ElementWiseSum"); - static const Op* identity_op = Op::Get("identity"); - static const Op* zeros_op = Op::Get("_zeros"); + static const Op* ewise_sum_op = Op::Get("ElementWiseSum"); + static const Op* identity_op = Op::Get("identity"); + static const Op* zeros_op = Op::Get("_zeros"); static const Op* zeros_like_op = Op::Get("zeros_like"); if (v.empty()) { nnvm::ObjectPtr ng = nnvm::Node::Create(); - ng->attrs.op = Op::Get("_zeros_without_dtype"); - ng->attrs.name = "zeros_without_dtype"; + ng->attrs.op = Op::Get("_zeros_without_dtype"); + ng->attrs.name = "zeros_without_dtype"; ng->attrs.op->attr_parser(&(ng->attrs)); return nnvm::NodeEntry(std::move(ng), 0, 0); } // remove zero in the sum. at least keep 1. auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) { - CHECK(nodeEntry.node); - return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op; + CHECK(nodeEntry.node); + return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op; }); - if (begin == v.begin()) ++begin; + if (begin == v.begin()) + ++begin; v.erase(begin, v.end()); CHECK(!v.empty()); @@ -78,9 +78,9 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { return std::move(v[0]); } else { if (v.size() < inplace_sum_cap) { - nnvm::ObjectPtr sum_node = nnvm::Node::Create(); - sum_node->attrs.op = ewise_sum_op; - sum_node->attrs.name = "sum_grad"; + nnvm::ObjectPtr sum_node = nnvm::Node::Create(); + sum_node->attrs.op = ewise_sum_op; + sum_node->attrs.name = "sum_grad"; sum_node->attrs.dict["num_args"] = std::to_string(v.size()); sum_node->attrs.op->attr_parser(&(sum_node->attrs)); sum_node->inputs = std::move(v); @@ -104,30 +104,29 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { // the node entries v passed in here are of the same node of // op _identity_with_attr_like_rhs. We should skip adding a node // to its own control_deps. - if (v[i-1].node != v[i].node) { + if (v[i - 1].node != v[i].node) { v[i].node->control_deps.push_back(ret.node); } std::ostringstream os; os << "sum_grad_" << i; nnvm::ObjectPtr x = nnvm::Node::Create(); - x->attrs.op = ewise_plus_op; - x->attrs.name = os.str(); - x->inputs = {ret, v[i]}; - ret = nnvm::NodeEntry(std::move(x), 0, 0); + x->attrs.op = ewise_plus_op; + x->attrs.name = os.str(); + x->inputs = {ret, v[i]}; + ret = nnvm::NodeEntry(std::move(x), 0, 0); } // identity node is used to avoid exposure of dummy plus node // when its output get assigned to another space. 
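  // Summary of the aggregation policy above (behavior unchanged by this patch):
  // after zero-gradient entries are pruned, a single surviving entry is returned
  // as-is; fewer than MXNET_EXEC_INPLACE_GRAD_SUM_CAP entries are summed with one
  // ElementWiseSum node ("sum_grad"); otherwise the entries are folded
  // left-to-right with _grad_add nodes ("sum_grad_" + i), which may accumulate in
  // place, and the chain is capped with the identity node created below
  // ("sum_grad_final").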
nnvm::ObjectPtr id_node = nnvm::Node::Create(); - id_node->attrs.op = identity_op; - id_node->attrs.name = "sum_grad_final"; - id_node->inputs = {ret}; + id_node->attrs.op = identity_op; + id_node->attrs.name = "sum_grad_final"; + id_node->inputs = {ret}; return nnvm::NodeEntry{id_node, 0, 0}; } } } - /* \brief collect pointers to input and output ndarrays * into a single data structure, this data structure can * be used for Memory allocation pass*/ @@ -142,7 +141,7 @@ void CollectInputOutputNDRefs(const nnvm::Graph& g, const std::vector& input_map, const std::vector& outputs, std::vector* arrays) { - const auto& idx = g.indexed_graph(); + const auto& idx = g.indexed_graph(); size_t num_inputs = idx.input_nodes().size(); for (size_t i = 0; i < num_inputs; ++i) { (*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[input_map[i]]; @@ -168,27 +167,24 @@ void CreateGraphNDs(const nnvm::Graph& g, std::vector* array_reqs, std::vector* arrays) { const auto& idx = g.indexed_graph(); - mxnet::imperative::AllocateMemory(g, idx, default_ctx, 0, - idx.num_node_entries(), mem_plan, *arrays, - array_reqs); - const auto &dtypes = g.GetAttr("dtype"); - const auto &shapes = g.GetAttr("shape"); - const auto &stypes = g.GetAttr("storage_type"); + mxnet::imperative::AllocateMemory( + g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, *arrays, array_reqs); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); + const auto& stypes = g.GetAttr("storage_type"); for (size_t i = 0; i < idx.outputs().size(); ++i) { auto eid = idx.entry_id(idx.outputs()[i]); if (!(*arrays)[eid]->is_none()) continue; - *((*arrays)[eid]) = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); + *((*arrays)[eid]) = NDArray( + static_cast(stypes[eid]), shapes[eid], default_ctx, true, dtypes[eid]); const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs; - (*arrays)[eid]->AssignStorageInfo( - common::NodeAttrsGetProfilerScope(attrs), - attrs.name); + (*arrays)[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); } } /* \brief create a forward graph from they Symbol */ -void CreateForwardGraph(const nnvm::Symbol &sym, nnvm::Graph *fwd_graph) { +void CreateForwardGraph(const nnvm::Symbol& sym, nnvm::Graph* fwd_graph) { using namespace nnvm; static const auto _copy_op = Op::Get("_copy"); NodeEntryMap dedup_out; @@ -196,12 +192,12 @@ void CreateForwardGraph(const nnvm::Symbol &sym, nnvm::Graph *fwd_graph) { // to graph outputs. 
Since node entry stores information about the node // as well as the input node of the graph, a graph can be recreated from a // symbol by just copying the outputs - for (const NodeEntry &nodeEntry : sym.outputs) { + for (const NodeEntry& nodeEntry : sym.outputs) { if (dedup_out.find(nodeEntry) != dedup_out.end()) { ObjectPtr copy_node = Node::Create(); copy_node->attrs.op = _copy_op; - copy_node->attrs.name = nodeEntry.node->attrs.name + "_copy" + - std::to_string(dedup_out[nodeEntry]++); + copy_node->attrs.name = + nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++); copy_node->inputs.emplace_back(nodeEntry); if (_copy_op->attr_parser != nullptr) { _copy_op->attr_parser(&(copy_node->attrs)); @@ -223,15 +219,15 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; ograd_entries->reserve(fwd_graph->outputs.size()); for (size_t i = 0; i < fwd_graph->outputs.size(); ++i) { - nnvm::ObjectPtr np = Node::Create(); - const nnvm::NodeAttrs& attrs = fwd_graph->outputs[i].node->attrs; - np->attrs.name = attrs.name + "_head_grad"; + nnvm::ObjectPtr np = Node::Create(); + const nnvm::NodeAttrs& attrs = fwd_graph->outputs[i].node->attrs; + np->attrs.name = attrs.name + "_head_grad"; np->attrs.dict["__profiler_scope__"] = common::NodeAttrsGetProfilerScope(attrs); ograd_entries->emplace_back(np); } std::vector xs; - const IndexedGraph &indexed_graph = fwd_graph->indexed_graph(); + const IndexedGraph& indexed_graph = fwd_graph->indexed_graph(); // Create vector of inputs to be passed to the gradient pass for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) { const uint32_t node_id = indexed_graph.input_nodes()[i]; @@ -249,11 +245,15 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, // There are inputs in computation graph that require gradients if (!xs.empty()) { try { - *grad_graph = pass::MXGradient( - *fwd_graph, fwd_graph->outputs, xs, *ograd_entries, - mxnet::AggregateGradient, nullptr, - zero_ops, "_copy"); - } catch (const nnvm::pass::InvalidGraphError &e) { + *grad_graph = pass::MXGradient(*fwd_graph, + fwd_graph->outputs, + xs, + *ograd_entries, + mxnet::AggregateGradient, + nullptr, + zero_ops, + "_copy"); + } catch (const nnvm::pass::InvalidGraphError& e) { *grad_graph = nnvm::Graph(); } } else { @@ -276,25 +276,27 @@ void CreateFullGraph(const nnvm::Symbol& sym, *fwd_graph = exec::EliminateCommonExpr(std::move(*fwd_graph)); // construct backward graph - CreateBackwardGraph(fwd_graph, grad_graph, ograd_entries, - fwd_input_to_grad_output); + CreateBackwardGraph(fwd_graph, grad_graph, ograd_entries, fwd_input_to_grad_output); full_graph->outputs = fwd_graph->outputs; // add backward graph outputs to full graph - for (const auto &i : grad_graph->outputs) { + for (const auto& i : grad_graph->outputs) { full_graph->outputs.emplace_back(i); } } /* \brief Set Ref counts for node entries for forward graph */ -void SetForwardRefCounts(nnvm::Graph *fwd_graph) { +void SetForwardRefCounts(nnvm::Graph* fwd_graph) { const auto& idx = fwd_graph->indexed_graph(); std::vector ref_count(idx.num_node_entries(), 0); - for (const auto& i : idx.input_nodes()) ++ref_count[idx.entry_id(i, 0)]; - for (const auto& i : idx.outputs()) ++ref_count[idx.entry_id(i)]; + for (const auto& i : idx.input_nodes()) + ++ref_count[idx.entry_id(i, 0)]; + for (const auto& i : idx.outputs()) + ++ref_count[idx.entry_id(i)]; for (size_t i = 0; i < idx.num_nodes(); ++i) { - for (const auto& j : idx[i].inputs) 
++ref_count[idx.entry_id(j)]; + for (const auto& j : idx[i].inputs) + ++ref_count[idx.entry_id(j)]; } fwd_graph->attrs[AddPrefix(FORWARD, REF_COUNT)] = @@ -306,7 +308,7 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) { const auto& idx = fwd_graph->indexed_graph(); SetForwardRefCounts(fwd_graph); - size_t num_forward_nodes = idx.num_nodes(); + size_t num_forward_nodes = idx.num_nodes(); size_t num_forward_entries = idx.num_node_entries(); const auto& full_idx = full_graph.indexed_graph(); @@ -314,38 +316,39 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) { std::vector temp_ref_count(full_idx.num_node_entries(), 0); for (size_t i = num_forward_nodes; i < full_idx.num_nodes(); ++i) { for (const auto& j : full_idx[i].inputs) { - ++temp_ref_count[full_idx.entry_id(j)]; + ++temp_ref_count[full_idx.entry_id(j)]; } } - auto full_ref_count = fwd_graph->GetAttr >(AddPrefix(FORWARD, - REF_COUNT)); - for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += temp_ref_count[i]; + auto full_ref_count = fwd_graph->GetAttr>(AddPrefix(FORWARD, REF_COUNT)); + for (size_t i = 0; i < num_forward_entries; ++i) + full_ref_count.at(i) += temp_ref_count[i]; fwd_graph->attrs[AddPrefix(FULL, REF_COUNT)] = std::make_shared(std::move(full_ref_count)); } -void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* grad_graph, - std::vector* input_map, const Context& context, - size_t num_forward_outputs, const bool inlining) { +void OptimizeGraph(nnvm::Graph* full_graph, + nnvm::Graph* fwd_graph, + nnvm::Graph* grad_graph, + std::vector* input_map, + const Context& context, + size_t num_forward_outputs, + const bool inlining) { input_map->resize(full_graph->indexed_graph().input_nodes().size()); std::iota(input_map->begin(), input_map->end(), 0); #if MXNET_USE_CUDA && !defined(_WIN32) - if (context.dev_mask() == kGPU && - !inlining && - dmlc::GetEnv("MXNET_USE_FUSION", true)) { + if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", true)) { nnvm::Graph unoptimized_graph; common::CopyGraph(&unoptimized_graph, *full_graph, false); if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) { *full_graph = exec::FusePointwise(*full_graph, num_forward_outputs); // Fill in input_map - mapping from the new to the original input indices. - const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes(); - const auto &new_inputs = full_graph->indexed_graph().input_nodes(); + const auto& original_inputs = unoptimized_graph.indexed_graph().input_nodes(); + const auto& new_inputs = full_graph->indexed_graph().input_nodes(); if (original_inputs.size() != new_inputs.size()) { - LOG(WARNING) - << "Number of inputs after fusion does not match original number of inputs. " - << "This is most probably a bug. Disabling fusion for this run."; + LOG(WARNING) << "Number of inputs after fusion does not match original number of inputs. " + << "This is most probably a bug. 
Disabling fusion for this run."; *full_graph = unoptimized_graph; } else { std::unordered_map original_input_map; @@ -363,25 +366,22 @@ void OptimizeGraph(nnvm::Graph* full_graph, nnvm::Graph* fwd_graph, nnvm::Graph* } } else { LOG(WARNING) - << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; - } + << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!"; + } } #else // Only warn user if MXNET_USE_FUSION env var is explicitly set - if (context.dev_mask() == kGPU && !inlining && - dmlc::GetEnv("MXNET_USE_FUSION", false)) { + if (context.dev_mask() == kGPU && !inlining && dmlc::GetEnv("MXNET_USE_FUSION", false)) { exec::WarnFusionNotSupported(); } #endif // MXNET_USE_CUDA && !defined(_WIN32) - *fwd_graph = nnvm::Graph(); - fwd_graph->outputs = std::vector(full_graph->outputs.begin(), - full_graph->outputs.begin() + - num_forward_outputs); - *grad_graph = nnvm::Graph(); - grad_graph->outputs = std::vector(full_graph->outputs.begin() + - num_forward_outputs, - full_graph->outputs.end()); + *fwd_graph = nnvm::Graph(); + fwd_graph->outputs = std::vector( + full_graph->outputs.begin(), full_graph->outputs.begin() + num_forward_outputs); + *grad_graph = nnvm::Graph(); + grad_graph->outputs = std::vector( + full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end()); SetRefCounts(fwd_graph, *full_graph); } @@ -421,35 +421,37 @@ struct CachedOpConfig : public dmlc::Parameter { std::string subgraph; DMLC_DECLARE_PARAMETER(CachedOpConfig) { DMLC_DECLARE_FIELD(static_alloc) - .set_default(false) - .describe("Statically allocate memory to improve speed. " - "Memory usage may increase."); + .set_default(false) + .describe( + "Statically allocate memory to improve speed. " + "Memory usage may increase."); DMLC_DECLARE_FIELD(static_shape) - .set_default(false) - .describe("Optimize for invariant input shapes between iterations. " - "Must also set static_alloc to True. " - "Change of input shapes is still allowed but slower."); + .set_default(false) + .describe( + "Optimize for invariant input shapes between iterations. " + "Must also set static_alloc to True. 
" + "Change of input shapes is still allowed but slower."); DMLC_DECLARE_FIELD(inline_limit) - .set_default(2) - .describe("Maximum number of operators that can be inlined."); + .set_default(2) + .describe("Maximum number of operators that can be inlined."); DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(Imperative::BulkExecMaxNodeTrainFwd()) - .describe("Segment size of bulk execution during forward pass."); + .set_default(Imperative::BulkExecMaxNodeTrainFwd()) + .describe("Segment size of bulk execution during forward pass."); DMLC_DECLARE_FIELD(backward_bulk_size) - .set_default(Imperative::BulkExecMaxNodeTrainBwd()) - .describe("Segment size of bulk execution during backward pass."); + .set_default(Imperative::BulkExecMaxNodeTrainBwd()) + .describe("Segment size of bulk execution during backward pass."); DMLC_DECLARE_FIELD(data_indices) - .set_default(mxnet::Tuple()) - .describe("Position of argument variables."); + .set_default(mxnet::Tuple()) + .describe("Position of argument variables."); DMLC_DECLARE_FIELD(param_indices) - .set_default(mxnet::Tuple()) - .describe("Position of parameters."); + .set_default(mxnet::Tuple()) + .describe("Position of parameters."); DMLC_DECLARE_FIELD(subgraph) - .set_default(std::string("")) - .describe("JSON string of a subgraph."); + .set_default(std::string("")) + .describe("JSON string of a subgraph."); DMLC_DECLARE_FIELD(is_dynamic) - .set_default(false) - .describe("Whether the graph contains dynamic shape operators."); + .set_default(false) + .describe("Whether the graph contains dynamic shape operators."); } }; @@ -458,13 +460,10 @@ class LazyTransformDataset; } class CachedOp { - using CachedOpMonCallback = - std::function; + using CachedOpMonCallback = std::function; public: - CachedOp( - const nnvm::Symbol& sym, - const std::vector >& flags); + CachedOp(const nnvm::Symbol& sym, const std::vector>& flags); virtual ~CachedOp(); nnvm::Symbol GetOptimizedSymbol() const; uint32_t num_inputs() const { @@ -477,7 +476,7 @@ class CachedOp { return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size(); } uint32_t num_backward_outputs() const { - auto &idx = fwd_graph_.indexed_graph(); + auto& idx = fwd_graph_.indexed_graph(); return idx.input_nodes().size() - idx.mutable_input_nodes().size(); } std::vector& save_inputs() { @@ -489,27 +488,23 @@ class CachedOp { const std::unordered_set& mutable_input_nodes() const { return fwd_graph_.indexed_graph().mutable_input_nodes(); } - virtual std::vector Gradient( - const nnvm::ObjectPtr& node, - const std::vector& ograds) const; - virtual OpStatePtr Forward( - const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs, - const Context &default_context); - virtual void Backward( - const bool retain_graph, - const OpStatePtr& state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); + virtual std::vector Gradient(const nnvm::ObjectPtr& node, + const std::vector& ograds) const; + virtual OpStatePtr Forward(const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs, + const Context& default_context); + virtual void Backward(const bool retain_graph, + const OpStatePtr& state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); // backward storage type inference - virtual bool BackwardStorageType( - const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs); + virtual bool 
BackwardStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs); std::vector ListForwardInputNames() const { nnvm::Symbol sym = GetForwardSym(); return sym.ListInputNames(nnvm::Symbol::kAll); @@ -523,8 +518,7 @@ class CachedOp { sym.outputs = fwd_graph_.outputs; return sym; } - void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, - bool monitor_all = false); + void RegisterOpHook(const CachedOp::CachedOpMonCallback& callback, bool monitor_all = false); protected: struct GraphInfo { @@ -539,23 +533,32 @@ class CachedOp { }; struct CachedOpState { - CachedOpState(const Context &context_, const nnvm::Graph &fwd_graph_, - const nnvm::Graph &full_graph_, const bool inlining_) { + CachedOpState(const Context& context_, + const nnvm::Graph& fwd_graph_, + const nnvm::Graph& full_graph_, + const bool inlining_) { context = context_; nnvm::Symbol sym; sym.outputs = fwd_graph_.outputs; - CreateFullGraph(sym.Copy(), &info.fwd_graph, &info.grad_graph, - &info.full_graph, &info.ograd_entries, + CreateFullGraph(sym.Copy(), + &info.fwd_graph, + &info.grad_graph, + &info.full_graph, + &info.ograd_entries, &info.fwd_input_to_grad_output); - OptimizeGraph(&info.full_graph, &info.fwd_graph, &info.grad_graph, &info.input_map, - context_, fwd_graph_.outputs.size(), inlining_); - - size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); - size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); - info.fwd_graph.attrs["context"] = - std::make_shared(std::vector( - info.fwd_graph.indexed_graph().num_nodes(), context)); + OptimizeGraph(&info.full_graph, + &info.fwd_graph, + &info.grad_graph, + &info.input_map, + context_, + fwd_graph_.outputs.size(), + inlining_); + + size_t max_nodes = info.full_graph.indexed_graph().num_nodes(); + size_t max_entries = info.full_graph.indexed_graph().num_node_entries(); + info.fwd_graph.attrs["context"] = std::make_shared( + std::vector(info.fwd_graph.indexed_graph().num_nodes(), context)); info.full_graph.attrs["context"] = std::make_shared(std::vector(max_nodes, context)); @@ -572,15 +575,15 @@ class CachedOp { Context context; GraphInfo info; - bool recording = false; - bool fwd_alloc = false; - bool bwd_alloc = false; + bool recording = false; + bool fwd_alloc = false; + bool bwd_alloc = false; bool fwd_exec_init = false; bool bwd_exec_init = false; std::vector buff; - std::vector arrays; - std::vector arrays_with_in_out; + std::vector arrays; + std::vector arrays_with_in_out; std::vector array_reqs; std::vector op_states; @@ -593,59 +596,45 @@ class CachedOp { }; OpStatePtr GetCachedOpState(const Context& ctx); - bool SetForwardGraph( - const Context& default_ctx, - GraphInfo* info, - const bool recording, - const std::vector& inputs); - bool SetBackwardGraph( - GraphInfo* info, - const std::vector& reqs, - const std::vector& inputs, - bool detect_inplace_addto = false); - bool CheckDynamicShapeExists( - const Context& default_ctx, - const std::vector& inputs, - bool erase_result); - void StaticAllocMemory( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd); - void StaticInitExec( - const OpStatePtr& state_ptr, - bool recording, - bool keep_fwd); - void StaticRunOps( - const Context& default_ctx, - const nnvm::Graph& g, - const OpStatePtr& state_ptr, - const std::vector &state_arrays, - size_t start_nid, - size_t end_nid); - OpStatePtr StaticForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& 
outputs); + bool SetForwardGraph(const Context& default_ctx, + GraphInfo* info, + const bool recording, + const std::vector& inputs); + bool SetBackwardGraph(GraphInfo* info, + const std::vector& reqs, + const std::vector& inputs, + bool detect_inplace_addto = false); + bool CheckDynamicShapeExists(const Context& default_ctx, + const std::vector& inputs, + bool erase_result); + void StaticAllocMemory(const OpStatePtr& state_ptr, bool recording, bool keep_fwd); + void StaticInitExec(const OpStatePtr& state_ptr, bool recording, bool keep_fwd); + void StaticRunOps(const Context& default_ctx, + const nnvm::Graph& g, + const OpStatePtr& state_ptr, + const std::vector& state_arrays, + size_t start_nid, + size_t end_nid); + OpStatePtr StaticForward(const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs); struct DynamicRuntime; private: - OpStatePtr DynamicForward( - const Context& default_ctx, - const std::vector& inputs, - const std::vector& outputs, - bool use_naive_run = false); - void DynamicBackward( - const bool retain_graph, - const OpStatePtr& op_state, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); - void StaticBackward( - const bool retain_graph, - const OpStatePtr& state_ptr, - const std::vector& inputs, - const std::vector& reqs, - const std::vector& outputs); + OpStatePtr DynamicForward(const Context& default_ctx, + const std::vector& inputs, + const std::vector& outputs, + bool use_naive_run = false); + void DynamicBackward(const bool retain_graph, + const OpStatePtr& op_state, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); + void StaticBackward(const bool retain_graph, + const OpStatePtr& state_ptr, + const std::vector& inputs, + const std::vector& reqs, + const std::vector& outputs); size_t BwdOriginalInput(const std::vector& input_map, size_t new_i); CachedOpConfig config_; @@ -662,11 +651,11 @@ class CachedOp { bool monitor_all_{false}; std::mutex mutex_; - std::unordered_map > cached_op_states_; + std::unordered_map> cached_op_states_; friend class ::mxnet::io::LazyTransformDataset; nnvm::Symbol sym_; - std::vector > flags_; + std::vector> flags_; }; struct CachedOp::DynamicRuntime { diff --git a/src/imperative/cached_op_threadsafe.cc b/src/imperative/cached_op_threadsafe.cc index 7d93eb84bd11..9ae0cb50ffe8 100644 --- a/src/imperative/cached_op_threadsafe.cc +++ b/src/imperative/cached_op_threadsafe.cc @@ -39,9 +39,7 @@ struct CachedOpThreadSafe::DynamicRuntime { std::vector op_states; }; -OpStatePtr CachedOpThreadSafe::GetCachedOpState( - const Context& ctx) { - +OpStatePtr CachedOpThreadSafe::GetCachedOpState(const Context& ctx) { for (const auto& i : cached_op_states_[ctx]) { // only create one state per device when not using static memory if (!config_.static_alloc || i.unique()) { @@ -55,26 +53,24 @@ OpStatePtr CachedOpThreadSafe::GetCachedOpState( return state_ptr; } - -CachedOpThreadSafe::CachedOpThreadSafe(const nnvm::Symbol& sym, - const std::vector >& flags) : CachedOp(sym, flags) { +CachedOpThreadSafe::CachedOpThreadSafe( + const nnvm::Symbol& sym, + const std::vector>& flags) + : CachedOp(sym, flags) { using namespace nnvm; using namespace imperative; - static const std::vector zero_ops{Op::Get("zeros_like"), - Op::Get("_zeros")}; + static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; config_.Init(flags); if (config_.static_shape) { - CHECK(config_.static_alloc) << "static_alloc must be True when static_shape is True"; + 
CHECK(config_.static_alloc) << "static_alloc must be True when static_shape is True"; } // construct forward graph CreateForwardGraph(sym.Copy(), &fwd_graph_); SetForwardRefCounts(&fwd_graph_); - SetInputIndices(fwd_graph_, config_.param_indices, - &config_.data_indices); + SetInputIndices(fwd_graph_, config_.param_indices, &config_.data_indices); } /* @@ -88,10 +84,10 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, using namespace imperative; auto state_ptr = GetCachedOpState(default_ctx); - auto op_state = OpStatePtr::Create(); - auto &runtime = op_state.get_state(); + auto op_state = OpStatePtr::Create(); + auto& runtime = op_state.get_state(); { - auto &state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); // Need to lock the mutex on the state, this allows // for multi context push of ops to dependency engine. // SetForwardGraph runs infer passes on graphs as well @@ -104,28 +100,28 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, SetForwardGraph(default_ctx, &state.info, false, inputs); runtime.info.fwd_graph = state.info.fwd_graph; } - nnvm::Graph &g = runtime.info.fwd_graph; - const auto &idx = g.indexed_graph(); + nnvm::Graph& g = runtime.info.fwd_graph; + const auto& idx = g.indexed_graph(); size_t max_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); runtime.op_states.resize(max_nodes); - auto &states = runtime.op_states; + auto& states = runtime.op_states; // Allocate entries // This buff is thread local and used to store intermediate // nodes in the graph buff.resize(idx.num_node_entries()); states.resize(idx.num_nodes()); - std::vector arrays; + std::vector arrays; arrays.reserve(buff.size()); - for (auto &buffered_array : buff) { + for (auto& buffered_array : buff) { arrays.push_back(&buffered_array); } std::vector array_reqs(arrays.size(), kWriteTo); - const auto &dispatch_modes = g.GetAttr("dispatch_mode"); - std::vector ref_count = g.GetAttr>( - "forward_ref_count"); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + std::vector ref_count = g.GetAttr>("forward_ref_count"); for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; + if (ref_count[i] == 0) + array_reqs[i] = kNullOp; } const MemoryPlanVector& mem_plan = g.GetAttr("forward_mem_plan"); @@ -140,8 +136,16 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, // that. CreateGraphNDs(g, default_ctx, mem_plan, &array_reqs, &arrays); // Invokes operators in the graph in a topologically sorted manner - RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), - std::move(ref_count), &states, dispatch_modes, false); + RunGraph(false, + idx, + arrays, + 0, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + false); return op_state; } @@ -165,10 +169,8 @@ OpStatePtr CachedOpThreadSafe::Forward(const std::shared_ptr& op_ptr, for (size_t i = 0; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx(), default_ctx) << "CachedOp requires all inputs to live on the same context. 
But " - << idx[idx.input_nodes()[0]].source->attrs.name - << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name - << " is on " << inputs[i]->ctx(); + << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name << " is on " << inputs[i]->ctx(); } int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size); @@ -199,9 +201,9 @@ struct CachedOpThreadSafeActualState { } }; OpStatePtr CreateCachedOpThreadSafeState(const NodeAttrs& attrs, - Context ctx, - const mxnet::ShapeVector& in_shapes, - const std::vector& in_types) { + Context ctx, + const mxnet::ShapeVector& in_shapes, + const std::vector& in_types) { const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); return OpStatePtr::Create(op); } @@ -211,11 +213,11 @@ void CachedOpThreadSafeForward(const OpStatePtr& state_ptr, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { - CachedOpThreadSafeActualState &s = state_ptr.get_state(); - std::vector in_bufs = inputs; - std::vector out_bufs = outputs; - std::vector in_ptrs(in_bufs.size()); - std::vector out_ptrs(out_bufs.size()); + CachedOpThreadSafeActualState& s = state_ptr.get_state(); + std::vector in_bufs = inputs; + std::vector out_bufs = outputs; + std::vector in_ptrs(in_bufs.size()); + std::vector out_ptrs(out_bufs.size()); for (size_t i = 0; i < in_ptrs.size(); i++) in_ptrs[i] = &in_bufs[i]; for (size_t i = 0; i < out_ptrs.size(); i++) @@ -226,7 +228,7 @@ void CachedOpThreadSafeForward(const OpStatePtr& state_ptr, CHECK(!ctx.is_train) << "Only inference use case supported with thread safe cached op"; CHECK(inputs.size() > 0) << "thread safe cached op requires at least one input"; Context default_ctx = inputs[0].ctx(); - s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); + s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs, default_ctx); // The arrays in out_ptrs may be changed by CachedOp. // If it is, we need to copy data back. 
for (size_t i = 0; i < out_bufs.size(); i++) @@ -253,64 +255,71 @@ void CachedOpThreadSafeParamParser(nnvm::NodeAttrs* attrs) { CachedOpThreadSafe::~CachedOpThreadSafe() = default; NNVM_REGISTER_OP(_CachedOpThreadSafe) -.set_num_inputs([](const NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op->num_inputs(); - }) -.set_num_outputs([](const NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op->num_outputs(); - }) -.set_attr_parser(CachedOpThreadSafeParamParser) -.set_attr("FListInputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op->ListForwardInputNames(); - }) -.set_attr("FListOutputNames", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op->ListForwardOutputNames(); - }) -.set_attr("FCreateOpState", CreateCachedOpThreadSafeState) -.set_attr("FInferShape", - [](const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shapes, - mxnet::ShapeVector *out_shapes) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes); - }) -.set_attr("FInferType", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_types, - std::vector *out_types) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); - }) -.set_attr("FInferStorageType", - [](const nnvm::NodeAttrs& attrs, - const int dev_mask, - DispatchMode* dispatch_mode, - std::vector* in_stypes, - std::vector* out_stypes) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(), - dev_mask, dispatch_mode, - in_stypes, out_stypes); - }) -.set_attr("FStatefulComputeEx", CachedOpThreadSafeForward) -.set_attr("FStatefulComputeEx", CachedOpThreadSafeForward) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym()); - }) -.set_attr("FResourceRequest", - [](const nnvm::NodeAttrs& attrs) { - const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); - return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym()); - }) -.set_attr("FExecType", op::DefaultSubgraphOpExecType) -.add_argument("data", "NDArray-or-Symbol[]", "input data list"); + .set_num_inputs([](const NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); + return op->num_inputs(); + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); + return op->num_outputs(); + }) + .set_attr_parser(CachedOpThreadSafeParamParser) + .set_attr("FListInputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = + nnvm::get(attrs.parsed); + return op->ListForwardInputNames(); + }) + .set_attr("FListOutputNames", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = + nnvm::get(attrs.parsed); + return op->ListForwardOutputNames(); + }) + .set_attr("FCreateOpState", CreateCachedOpThreadSafeState) + .set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const CachedOpThreadSafePtr& op = + nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpShapeHelper( + op->GetForwardSym(), in_shapes, out_shapes); + }) + 
.set_attr( + "FInferType", + [](const nnvm::NodeAttrs& attrs, std::vector* in_types, std::vector* out_types) { + const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types); + }) + .set_attr( + "FInferStorageType", + [](const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_stypes, + std::vector* out_stypes) { + const CachedOpThreadSafePtr& op = nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpStorageTypeHelper( + op->GetForwardSym(), dev_mask, dispatch_mode, in_stypes, out_stypes); + }) + .set_attr("FStatefulComputeEx", CachedOpThreadSafeForward) + .set_attr("FStatefulComputeEx", CachedOpThreadSafeForward) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = + nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpMutableInputsHelper( + op->GetForwardSym()); + }) + .set_attr("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + const CachedOpThreadSafePtr& op = + nnvm::get(attrs.parsed); + return op::DefaultSubgraphOpResourceRequestHelper( + op->GetForwardSym()); + }) + .set_attr("FExecType", op::DefaultSubgraphOpExecType) + .add_argument("data", "NDArray-or-Symbol[]", "input data list"); } // namespace mxnet diff --git a/src/imperative/cached_op_threadsafe.h b/src/imperative/cached_op_threadsafe.h index 63521c7219e7..590f72dc12d0 100644 --- a/src/imperative/cached_op_threadsafe.h +++ b/src/imperative/cached_op_threadsafe.h @@ -30,12 +30,9 @@ #include #include "./cached_op.h" - - namespace mxnet { /*! \brief CachedOp Parameters*/ -struct CachedOpThreadSafeConfig - : public dmlc::Parameter { +struct CachedOpThreadSafeConfig : public dmlc::Parameter { // keeping the config minimal // inlining, bulking, dynamic shapes, static allocing and shaping not // supported @@ -49,21 +46,23 @@ struct CachedOpThreadSafeConfig bool static_shape; DMLC_DECLARE_PARAMETER(CachedOpThreadSafeConfig) { DMLC_DECLARE_FIELD(static_alloc) - .set_default(false) - .describe("Statically allocate memory to improve speed. " - "Memory usage may increase."); + .set_default(false) + .describe( + "Statically allocate memory to improve speed. " + "Memory usage may increase."); DMLC_DECLARE_FIELD(static_shape) - .set_default(false) - .describe("Optimize for invariant input shapes between iterations. " - "Must also set static_alloc to True. " - "Change of input shapes is still allowed but slower."); + .set_default(false) + .describe( + "Optimize for invariant input shapes between iterations. " + "Must also set static_alloc to True. 
" + "Change of input shapes is still allowed but slower."); DMLC_DECLARE_FIELD(forward_bulk_size) - .set_default(Imperative::BulkExecMaxNodeTrainFwd()) - .describe("Segment size of bulk execution during dynamic forward"); + .set_default(Imperative::BulkExecMaxNodeTrainFwd()) + .describe("Segment size of bulk execution during dynamic forward"); DMLC_DECLARE_FIELD(data_indices) .set_default(mxnet::Tuple()) .describe("Position of argument variables."); - DMLC_DECLARE_FIELD(param_indices) + DMLC_DECLARE_FIELD(param_indices) .set_default(mxnet::Tuple()) .describe("Position of parameters."); } @@ -72,33 +71,29 @@ struct CachedOpThreadSafeConfig // Thread local buff to store internal states of the graph // Used in dynamic_forward #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::vector buff; +static thread_local std::vector buff; #else - static MX_THREAD_LOCAL std::vector buff; +static MX_THREAD_LOCAL std::vector buff; #endif - - class CachedOpThreadSafe : public CachedOp { public: - CachedOpThreadSafe( - const nnvm::Symbol &sym, - const std::vector> &flags); + CachedOpThreadSafe(const nnvm::Symbol& sym, + const std::vector>& flags); ~CachedOpThreadSafe(); uint32_t num_inputs() const { - return fwd_graph_.indexed_graph().input_nodes().size(); + return fwd_graph_.indexed_graph().input_nodes().size(); } uint32_t num_outputs() const { - return fwd_graph_.outputs.size(); + return fwd_graph_.outputs.size(); } const std::unordered_set& mutable_input_nodes() const { return fwd_graph_.indexed_graph().mutable_input_nodes(); } - OpStatePtr Forward( - const std::shared_ptr& op_ptr, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_ctx); + OpStatePtr Forward(const std::shared_ptr& op_ptr, + const std::vector& inputs, + const std::vector& outputs, + const Context& default_ctx); std::vector ListForwardInputNames() const { nnvm::Symbol sym = GetForwardSym(); return sym.ListInputNames(nnvm::Symbol::kAll); @@ -114,6 +109,7 @@ class CachedOpThreadSafe : public CachedOp { } struct GraphInfo; + private: struct DynamicRuntime; diff --git a/src/imperative/eliminate_common_expr_pass.cc b/src/imperative/eliminate_common_expr_pass.cc index bee1d159057b..0ef204cdeca6 100644 --- a/src/imperative/eliminate_common_expr_pass.cc +++ b/src/imperative/eliminate_common_expr_pass.cc @@ -37,10 +37,10 @@ namespace exec { namespace { -using nnvm::Node; -using nnvm::ObjectPtr; using nnvm::Graph; using nnvm::IndexedGraph; +using nnvm::Node; +using nnvm::ObjectPtr; // NodeInput holds the sufficient subset of NodeEntry fields for Node-input equality tests using NodeInput = std::pair; @@ -61,15 +61,19 @@ std::vector ConvertInputs(const std::vector& inputs) * \brief Determine if two Nodes have equal function such that one Node can be eliminated. */ bool NodeEqual(const Node* n, const Node* m) { - if (n->is_variable() || m->is_variable()) return false; - if (n->op() != m->op()) return false; + if (n->is_variable() || m->is_variable()) + return false; + if (n->op() != m->op()) + return false; // Nodes with different attributes are considered not identical, // though this may reject Node pairs that are in fact functionally the same. 
- if (n->attrs.dict != m->attrs.dict) return false; + if (n->attrs.dict != m->attrs.dict) + return false; // Ops that mutate inputs cannot be optimized out static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - if (fmutate_inputs.get(n->op(), nullptr) != nullptr) return false; + if (fmutate_inputs.get(n->op(), nullptr) != nullptr) + return false; // Stateful ops cannot be be equal to each other static auto& fstateful = Op::GetAttr("FCreateOpState"); @@ -86,9 +90,9 @@ bool NodeEqual(const Node* n, const Node* m) { // Ops that require resource could ask for // random resource, so need to be explicitly marked // to be eligible - static auto& resource_request = Op::GetAttr("FResourceRequest"); + static auto& resource_request = Op::GetAttr("FResourceRequest"); static auto& resource_request_ex = Op::GetAttr("FResourceRequestEx"); - const auto fresource_request = resource_request.get(n->op(), nullptr); + const auto fresource_request = resource_request.get(n->op(), nullptr); if (fresource_request != nullptr) { const auto& requests = fresource_request(n->attrs); for (const auto& req : requests) { @@ -97,7 +101,8 @@ bool NodeEqual(const Node* n, const Node* m) { } } } - if (resource_request_ex.get(n->op(), nullptr) != nullptr) return false; + if (resource_request_ex.get(n->op(), nullptr) != nullptr) + return false; return true; } @@ -115,17 +120,18 @@ std::vector > GetCommonNodes(const Graph& g) { }); // Now check for identical node ops within the node groups (having identical inputs) for (const auto& pair : grouped_nodes) { - auto &node_group = pair.second; // Group of nodes that share the same vector of inputs + auto& node_group = pair.second; // Group of nodes that share the same vector of inputs if (node_group.size() > 1) { std::unordered_set visited; for (size_t i = 0; i < node_group.size(); ++i) { - if (visited.count(i)) continue; + if (visited.count(i)) + continue; for (size_t j = i + 1; j < node_group.size(); ++j) { // If the two Nodes have equal function, then one Node (called the 'replaced') can // be eliminated in favor of the other Node (the 'src'). if (NodeEqual(node_group[i]->get(), node_group[j]->get())) { visited.insert(j); - ObjectPtr src = *node_group[i]; + ObjectPtr src = *node_group[i]; ObjectPtr replaced = *node_group[j]; ret.emplace_back(src, replaced); } @@ -141,20 +147,20 @@ std::vector > GetCommonNodes(const Graph& g) { */ void EliminateCommonNodes(Graph* g, const std::vector >& common_nodes) { - for (const auto &p : common_nodes) { - std::vector nodes_to_change; - const ObjectPtr &src = p.first; - const ObjectPtr &replaced = p.second; + for (const auto& p : common_nodes) { + std::vector nodes_to_change; + const ObjectPtr& src = p.first; + const ObjectPtr& replaced = p.second; // Create a `nodes_to_change` list containing the Nodes that refer to the `replaced` Node // that is targeted for elimination. - DFSVisit(g->outputs, [replaced, &nodes_to_change](const ObjectPtr &n) { - for (const auto &dep : n->control_deps) { + DFSVisit(g->outputs, [replaced, &nodes_to_change](const ObjectPtr& n) { + for (const auto& dep : n->control_deps) { if (dep == replaced) { nodes_to_change.push_back(n); return; } } - for (const auto &inp : n->inputs) { + for (const auto& inp : n->inputs) { if (inp.node == replaced) { nodes_to_change.push_back(n); return; @@ -164,13 +170,13 @@ void EliminateCommonNodes(Graph* g, // Change references to the `replaced` Node within the `nodes_to_change` list to be // references to the equivalent `src` Node. 
- for (auto &n : nodes_to_change) { - for (auto &dep : n->control_deps) { + for (auto& n : nodes_to_change) { + for (auto& dep : n->control_deps) { if (dep == replaced) { dep = src; } } - for (auto &inp : n->inputs) { + for (auto& inp : n->inputs) { if (inp.node == replaced) { inp.node = src; } @@ -178,7 +184,7 @@ void EliminateCommonNodes(Graph* g, } // Add `replaced` Node control dependencies to those of the `src` Node. - for (const auto &n : replaced->control_deps) { + for (const auto& n : replaced->control_deps) { src->control_deps.push_back(n); } @@ -193,7 +199,7 @@ void EliminateCommonNodes(Graph* g, // insert Copy nodes as appropriate const Op* copy_op = Op::Get("_copy"); nnvm::NodeEntryMap unique_outputs; - for (auto & output : g->outputs) { + for (auto& output : g->outputs) { auto kv = unique_outputs.find(output); if (kv == unique_outputs.end()) { unique_outputs.emplace(output, 0); @@ -202,7 +208,7 @@ void EliminateCommonNodes(Graph* g, std::ostringstream os; os << kv->first.node->attrs.name << "_" << kv->second << "_copy"; kv->second++; - copy_node->attrs.op = copy_op; + copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(kv->first); output = nnvm::NodeEntry{copy_node, 0, 0}; diff --git a/src/imperative/exec_pass.h b/src/imperative/exec_pass.h index 341f5109ebdd..440fc6b937ca 100644 --- a/src/imperative/exec_pass.h +++ b/src/imperative/exec_pass.h @@ -41,33 +41,32 @@ namespace mxnet { namespace exec { template -using FAccessSubgraphAttr = std::function, - std::vector> - (const NodeAttrs& attrs)>; +using FAccessSubgraphAttr = + std::function, std::vector>( + const NodeAttrs& attrs)>; -using FAccessSubgraphShape = FAccessSubgraphAttr; -using FAccessSubgraphType = FAccessSubgraphAttr; +using FAccessSubgraphShape = FAccessSubgraphAttr; +using FAccessSubgraphType = FAccessSubgraphAttr; using FAccessSubgraphStorageType = FAccessSubgraphAttr; template -using FProvideSubgraphAttr = std::function &nodes, - const std::vector> &in_attrs, - const std::vector> &out_attrs)>; -using FProvideSubgraphShape = FProvideSubgraphAttr; -using FProvideSubgraphType = FProvideSubgraphAttr; +using FProvideSubgraphAttr = std::function& nodes, + const std::vector>& in_attrs, + const std::vector>& out_attrs)>; +using FProvideSubgraphShape = FProvideSubgraphAttr; +using FProvideSubgraphType = FProvideSubgraphAttr; using FProvideSubgraphStorageType = FProvideSubgraphAttr; -using TIsFusion = bool; +using TIsFusion = bool; using TIsFusionHelper = bool; /*! \brief reuse graph definition */ using nnvm::Graph; -const int kBadStorageID = -1; +const int kBadStorageID = -1; const int kExternalStorageID = -2; -const int kDynamicStorageID = -3; +const int kDynamicStorageID = -3; const int kNonDefaultStorage = -2; @@ -120,7 +119,7 @@ class OpExecutor { * \brief per node vector of operator executors. * \note stored under attribute "op_exec" */ -using OpExecVector = std::vector >; +using OpExecVector = std::vector>; /*! * \brief per node vector of operator states. @@ -201,7 +200,7 @@ Graph DetectInplaceAddTo(Graph g); * * \return graph with common expressions eliminated */ -Graph EliminateCommonExpr(Graph && g); +Graph EliminateCommonExpr(Graph&& g); /*! * \brief Fuse pointwise operations in the graph. @@ -241,7 +240,7 @@ Graph InferShape(Graph&& graph, * The index of ShapeVector is given by graph.indexed_graph().entry_id. 
*/ Graph InferType(Graph&& graph, - nnvm::DTypeVector&& dtype_inputs = nnvm::DTypeVector(), + nnvm::DTypeVector&& dtype_inputs = nnvm::DTypeVector(), const std::string& dtype_attr_key = ""); /*! @@ -254,7 +253,7 @@ Graph InferType(Graph&& graph, * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. */ Graph InferStorageType(Graph&& graph, - StorageTypeVector&& storage_type_inputs = StorageTypeVector(), + StorageTypeVector&& storage_type_inputs = StorageTypeVector(), const std::string& storage_type_attr_key = ""); } // namespace exec @@ -284,16 +283,16 @@ inline Graph MXGradient( std::vector xs, std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, - std::function mirror_fun = nullptr, - std::vector zero_ops = std::vector(), - std::string copy_op_str = std::string(), + std::function mirror_fun = nullptr, + std::vector zero_ops = std::vector(), + std::string copy_op_str = std::string(), mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(), - DTypeVector in_arg_dtypes = DTypeVector()) { - graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); - graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); + DTypeVector in_arg_dtypes = DTypeVector()) { + graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); + graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); - graph.attrs["in_arg_shapes"] = std::make_shared(std::move(in_arg_shapes)); - graph.attrs["in_arg_dtypes"] = std::make_shared(std::move(in_arg_dtypes)); + graph.attrs["in_arg_shapes"] = std::make_shared(std::move(in_arg_shapes)); + graph.attrs["in_arg_dtypes"] = std::make_shared(std::move(in_arg_dtypes)); if (aggregate_fun != nullptr) { graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index a42a60b919fe..0ec5ae579dce 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -25,19 +25,19 @@ #include "./cached_op.h" namespace nnvm { -ObjectPtr CreateVariableNode(const std::string &name); +ObjectPtr CreateVariableNode(const std::string& name); } namespace mxnet { #if DMLC_CXX11_THREAD_LOCAL -thread_local bool Imperative::is_train_ = false; -thread_local bool Imperative::is_recording_ = false; -thread_local bool Imperative::is_deferred_compute_ = false; +thread_local bool Imperative::is_train_ = false; +thread_local bool Imperative::is_recording_ = false; +thread_local bool Imperative::is_deferred_compute_ = false; thread_local bool Imperative::is_np_shape_thread_local_ = false; #else -MX_THREAD_LOCAL bool Imperative::is_train_ = false; -MX_THREAD_LOCAL bool Imperative::is_recording_ = false; -MX_THREAD_LOCAL bool Imperative::is_deferred_compute_ = false; +MX_THREAD_LOCAL bool Imperative::is_train_ = false; +MX_THREAD_LOCAL bool Imperative::is_recording_ = false; +MX_THREAD_LOCAL bool Imperative::is_deferred_compute_ = false; MX_THREAD_LOCAL bool Imperative::is_np_shape_thread_local_ = false; #endif @@ -46,60 +46,66 @@ Imperative* Imperative::Get() { return &inst; } -OpStatePtr Imperative::InvokeOp( - const Context& ctx, - const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& req, - const DispatchMode dispatch_mode, - OpStatePtr state) { +OpStatePtr Imperative::InvokeOp(const Context& ctx, + const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& req, + const DispatchMode dispatch_mode, + 
OpStatePtr state) { using namespace imperative; - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); - const nnvm::Op *op = attrs.op; + const nnvm::Op* op = attrs.op; std::vector read_vars, write_vars; std::vector requested; std::vector mutate_idx; - SetDependency(attrs, ctx, inputs, outputs, - &read_vars, &write_vars, &requested, &mutate_idx, dispatch_mode); + SetDependency( + attrs, ctx, inputs, outputs, &read_vars, &write_vars, &requested, &mutate_idx, dispatch_mode); - FCompute fn = common::GetFCompute(op, "FCompute", ctx); + FCompute fn = common::GetFCompute(op, "FCompute", ctx); FComputeEx fn_ex = common::GetFCompute(op, "FComputeEx", ctx); // FComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx CHECK(dispatch_mode != DispatchMode::kUndefined); bool dispatch_fcompex = dispatch_mode == DispatchMode::kFComputeEx; if (fn_ex && dispatch_fcompex) { - PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars, - requested, inputs, outputs, req); + PushFComputeEx(fn_ex, op, attrs, ctx, read_vars, write_vars, requested, inputs, outputs, req); } else if (fn) { - PushFCompute(fn, op, attrs, ctx, read_vars, write_vars, - requested, inputs, outputs, mutate_idx, req); + PushFCompute( + fn, op, attrs, ctx, read_vars, write_vars, requested, inputs, outputs, mutate_idx, req); } else if (createop.count(op) || is_layer_backward.get(op, false)) { if (!state) { state = createop[op](attrs, ctx, ret->arg_shapes, ret->arg_types); } write_vars.push_back(state.get_var()); - PushOperator(state, op, attrs, ctx, read_vars, write_vars, - requested, inputs, outputs, mutate_idx, req, dispatch_mode); + PushOperator(state, + op, + attrs, + ctx, + read_vars, + write_vars, + requested, + inputs, + outputs, + mutate_idx, + req, + dispatch_mode); } else { - LOG(FATAL) - << "Operator " << op->name << " is not implemented for " - << (ctx.dev_mask() == gpu::kDevMask ? "GPU." : "CPU."); + LOG(FATAL) << "Operator " << op->name << " is not implemented for " + << (ctx.dev_mask() == gpu::kDevMask ? "GPU." 
: "CPU."); } return state; } -OpStatePtr Imperative::Invoke( - const Context& default_ctx, - const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - const std::vector& outputs) { +OpStatePtr Imperative::Invoke(const Context& default_ctx, + const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + const std::vector& outputs) { using namespace imperative; static auto& ndfunc = nnvm::Op::GetAttr("FNDArrayFunction"); @@ -107,13 +113,14 @@ OpStatePtr Imperative::Invoke( std::vector p_inputs, p_outputs; DerefInputOutput(inputs, outputs, &p_inputs, &p_outputs); ndfunc[attrs.op](attrs, p_inputs, &p_outputs); - for (size_t i = 0; i < outputs.size(); ++i) *outputs[i] = std::move(p_outputs[i]); + for (size_t i = 0; i < outputs.size(); ++i) + *outputs[i] = std::move(p_outputs[i]); return OpStatePtr(); } // TODO(piiswrong): infer ctx DispatchMode dispatch_mode = DispatchMode::kUndefined; - Context ctx = GetContext(attrs, inputs, outputs, default_ctx); + Context ctx = GetContext(attrs, inputs, outputs, default_ctx); SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); std::vector req; SetWriteInplaceReq(inputs, outputs, &req); @@ -131,37 +138,35 @@ OpStatePtr Imperative::Invoke( // Create nnvm::NodeEntry for variables' and gradients' autograd_entry_ // attribute and associate AGInfo with it's info attribute -void Imperative::MarkVariables( - const std::vector& variables, - const std::vector& grad_reqs, - const std::vector& gradients) { +void Imperative::MarkVariables(const std::vector& variables, + const std::vector& grad_reqs, + const std::vector& gradients) { for (uint32_t i = 0; i < variables.size(); ++i) { std::string str_c(std::to_string(variable_count_++)); - variables[i]->autograd_entry_ = nnvm::NodeEntry{ - nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0}; + variables[i]->autograd_entry_ = + nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0}; AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node); info.outputs.emplace_back(variables[i]->Detach()); info.out_grads.emplace_back(gradients[i]->Detach()); info.grad_req = static_cast(grad_reqs[i]); - info.ctx = variables[i]->ctx(); + info.ctx = variables[i]->ctx(); - gradients[i]->autograd_entry_ = nnvm::NodeEntry{ - nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0}; + gradients[i]->autograd_entry_ = + nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0}; AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node); grad_info.outputs.emplace_back(gradients[i]->Detach()); grad_info.ctx = gradients[i]->ctx(); } } - -void Imperative::GetBackwardDependency( - const nnvm::ObjectPtr& node, - uint32_t num_inputs, uint32_t num_outputs, - std::vector *p_save_inputs, - std::vector *p_save_outputs) { - static auto& fgradient = nnvm::Op::GetAttr("FGradient"); - std::vector& save_inputs = *p_save_inputs; +void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node, + uint32_t num_inputs, + uint32_t num_outputs, + std::vector* p_save_inputs, + std::vector* p_save_outputs) { + static auto& fgradient = nnvm::Op::GetAttr("FGradient"); + std::vector& save_inputs = *p_save_inputs; std::vector& save_outputs = *p_save_outputs; save_inputs.resize(num_inputs); save_outputs.resize(num_outputs); @@ -189,50 +194,52 @@ void Imperative::GetBackwardDependency( } } DFSVisit(igrad_entries, [&](const nnvm::ObjectPtr& gnode) { - if (!gnode || gnode == node) return; - for (const auto& i : gnode->inputs) { - if (i.node == nullptr && i.version == 0) 
{ - save_inputs[i.index] = true; - } else if (i.node == node) { - save_outputs[i.index] = true; - } + if (!gnode || gnode == node) + return; + for (const auto& i : gnode->inputs) { + if (i.node == nullptr && i.version == 0) { + save_inputs[i.index] = true; + } else if (i.node == node) { + save_outputs[i.index] = true; } - }); + } + }); } } -void Imperative::RecordOp( - nnvm::NodeAttrs&& attrs, - const std::vector& inputs, - const std::vector& outputs, - const OpStatePtr& state, - std::vector* p_save_inputs, - std::vector* p_save_outputs) { - MXAPIThreadLocalEntry<> *local_buff = MXAPIThreadLocalStore<>::Get(); +void Imperative::RecordOp(nnvm::NodeAttrs&& attrs, + const std::vector& inputs, + const std::vector& outputs, + const OpStatePtr& state, + std::vector* p_save_inputs, + std::vector* p_save_outputs) { + MXAPIThreadLocalEntry<>* local_buff = MXAPIThreadLocalStore<>::Get(); CHECK(!is_deferred_compute()) << "Autograd recording is not supported during deferred compute mode."; for (auto output : outputs) { CHECK(AGInfo::IsNone(*output)) - << "Assigning to NDArrays that are already in a computational graph " - << "will cause undefined behavior when evaluating gradients. " - << "Please call backward first to clear the graph or do this out side of " - << "a record section. Also note that you cannot use inplace operations " - << "like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section. " - << "Issue occurred while recording op: " << attrs.name; + << "Assigning to NDArrays that are already in a computational graph " + << "will cause undefined behavior when evaluating gradients. " + << "Please call backward first to clear the graph or do this out side of " + << "a record section. Also note that you cannot use inplace operations " + << "like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section. 
" + << "Issue occurred while recording op: " << attrs.name; } bool need_grad = false; for (const auto& i : inputs) { - if (AGInfo::IsNone(*i)) continue; + if (AGInfo::IsNone(*i)) + continue; need_grad = true; break; } - if (!need_grad) return; + if (!need_grad) + return; nnvm::ObjectPtr node = nnvm::Node::Create(); - node->attrs = std::move(attrs); + node->attrs = std::move(attrs); // if node name is empty or node name is equal to op name - name it with unique name if (node->attrs.name == "" || node->attrs.op->name == node->attrs.name) { node->attrs.name = "node_" + std::to_string(node_count_++); @@ -240,39 +247,40 @@ void Imperative::RecordOp( node_count_++; } AGInfo& info = AGInfo::Create(node); - info.state = state; - info.ctx = outputs[0]->ctx(); + info.state = state; + info.ctx = outputs[0]->ctx(); if (p_save_inputs == nullptr) { - p_save_inputs = &(local_buff->save_inputs); + p_save_inputs = &(local_buff->save_inputs); p_save_outputs = &(local_buff->save_outputs); - GetBackwardDependency( - node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs); + GetBackwardDependency(node, inputs.size(), outputs.size(), p_save_inputs, p_save_outputs); } else { node->inputs.resize(inputs.size()); } - std::vector& save_inputs = *p_save_inputs; + std::vector& save_inputs = *p_save_inputs; std::vector& save_outputs = *p_save_outputs; for (size_t i = 0; i < inputs.size(); ++i) { if (AGInfo::IsNone(*(inputs[i]))) { - nnvm::NodeEntry entry{nnvm::Symbol::CreateVariable( - "null" + std::to_string(variable_count_++)).outputs[0].node, 0, 0}; + nnvm::NodeEntry entry{ + nnvm::Symbol::CreateVariable("null" + std::to_string(variable_count_++)).outputs[0].node, + 0, + 0}; AGInfo& input_info = AGInfo::Create(entry.node); - input_info.ctx = inputs[i]->ctx(); + input_info.ctx = inputs[i]->ctx(); if (save_inputs[i]) { input_info.outputs.emplace_back(*inputs[i]); } else { // Put a dummy array here since it will not be used. input_info.outputs.emplace_back(); - input_info.outputs.back().shape_ = inputs[i]->shape(); - input_info.outputs.back().dtype_ = inputs[i]->dtype(); + input_info.outputs.back().shape_ = inputs[i]->shape(); + input_info.outputs.back().dtype_ = inputs[i]->dtype(); input_info.outputs.back().storage_type_ = inputs[i]->storage_type(); } inputs[i]->autograd_entry_ = std::move(entry); // assign last to prevent cyclic reference } else if (save_inputs[i]) { - nnvm::NodeEntry& entry = inputs[i]->autograd_entry_; + nnvm::NodeEntry& entry = inputs[i]->autograd_entry_; AGInfo::Get(entry.node).outputs[entry.index] = inputs[i]->Detach(); } node->inputs[i] = inputs[i]->autograd_entry_; @@ -290,38 +298,38 @@ void Imperative::RecordOp( } else { // Put a dummy array here since it will not be used. 
info.outputs.emplace_back(); - info.outputs.back().shape_ = outputs[i]->shape(); - info.outputs.back().dtype_ = outputs[i]->dtype(); + info.outputs.back().shape_ = outputs[i]->shape(); + info.outputs.back().dtype_ = outputs[i]->dtype(); info.outputs.back().storage_type_ = outputs[i]->storage_type(); } outputs[i]->autograd_entry_ = nnvm::NodeEntry{node, i, 0}; } } -void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs, - const std::vector &inputs, - const std::vector &outputs) { +void Imperative::RecordDeferredCompute(nnvm::NodeAttrs&& attrs, + const std::vector& inputs, + const std::vector& outputs) { CHECK(!is_recording()) << "MXNetError: Autograd recording is not supported during deferred compute mode."; - for (const NDArray *input : inputs) { + for (const NDArray* input : inputs) { CHECK(!DCInfo::IsNone(*input)) << "ValueError: All inputs to deferred compute recording must be associated " << "with a symbolic variable or be the output of a deferred compute operator."; } - for (const NDArray *output : outputs) { + for (const NDArray* output : outputs) { CHECK(DCInfo::IsNone(*output)) << "NotImplementedError: Inplace operations (+=, -=, x[:]=, etc) " << "are not supported when recording in deferred compute mode."; } DispatchMode dispatch_mode = DispatchMode::kUndefined; - Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU()); + Context ctx = imperative::GetContext(attrs, inputs, outputs, Context::CPU()); imperative::SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); nnvm::ObjectPtr node = nnvm::Node::Create(); node->inputs.reserve(inputs.size()); // Get NodeEntries for inputs - for (const NDArray *array : inputs) { + for (const NDArray* array : inputs) { CHECK(array->deferredcompute_entry_.node); // Must not be nullptr node->inputs.emplace_back(array->deferredcompute_entry_); } @@ -341,10 +349,10 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs, DCInfo::Create(node, inputs, outputs); } -nnvm::Symbol Imperative::GetDeferredComputeSymbol(const std::vector &outputs) { +nnvm::Symbol Imperative::GetDeferredComputeSymbol(const std::vector& outputs) { nnvm::Symbol s; s.outputs.reserve(outputs.size()); - for (NDArray * ndoutput : outputs) { + for (NDArray* ndoutput : outputs) { CHECK(!Imperative::DCInfo::IsNone(*ndoutput)) << "ValueError: output_arrays for GetDeferredComputeSymbol " << "must have a deferred compute history associated with them."; @@ -353,16 +361,16 @@ nnvm::Symbol Imperative::GetDeferredComputeSymbol(const std::vector & return s.Copy(); } -void Imperative::SetDeferredComputeVariable(NDArrayHandle *arrays, - SymbolHandle *variables, const int num) { +void Imperative::SetDeferredComputeVariable(NDArrayHandle* arrays, + SymbolHandle* variables, + const int num) { // Sanity check all inputs for (int i = 0; i < num; i++) { - nnvm::Symbol *s = reinterpret_cast(variables[i]); - NDArray *nd = reinterpret_cast(arrays[i]); + nnvm::Symbol* s = reinterpret_cast(variables[i]); + NDArray* nd = reinterpret_cast(arrays[i]); CHECK_EQ(s->outputs.size(), 1) << "MXNDArraySetDeferredComputeVariable expects variables as input. " - << "Instead got a Symbol with " << s->outputs.size() - << " outputs as input " << i; + << "Instead got a Symbol with " << s->outputs.size() << " outputs as input " << i; CHECK(s->outputs[0].node->is_variable()) << "MXNDArraySetDeferredComputeVariable expects variables as input. 
" << "Instead got a Symbol associated with an operator as input " << i; @@ -373,22 +381,22 @@ void Imperative::SetDeferredComputeVariable(NDArrayHandle *arrays, // Store variables in DCInfo of arrays for (int i = 0; i < num; i++) { - nnvm::Symbol *s = reinterpret_cast(variables[i]); - NDArray *nd = reinterpret_cast(arrays[i]); + nnvm::Symbol* s = reinterpret_cast(variables[i]); + NDArray* nd = reinterpret_cast(arrays[i]); nd->deferredcompute_entry_ = nnvm::NodeEntry{s->outputs[0].node, 0, 0}; - std::vector inputs; - std::vector outputs; // No need to specify outputs, as we will set is_computed_ + std::vector inputs; + std::vector outputs; // No need to specify outputs, as we will set is_computed_ Imperative::DCInfo& info = Imperative::DCInfo::Create(s->outputs[0].node, inputs, outputs); - info.is_computed_ = true; + info.is_computed_ = true; } } -void Imperative::DeferredComputeClear(NDArrayHandle *arrays, const int num) { +void Imperative::DeferredComputeClear(NDArrayHandle* arrays, const int num) { std::vector outputs; outputs.reserve(num); for (int i = 0; i < num; i++) { - NDArray *nd = reinterpret_cast(arrays[i]); + NDArray* nd = reinterpret_cast(arrays[i]); outputs.emplace_back(nd->deferredcompute_entry_); } nnvm::DFSVisit(outputs, [&](const nnvm::ObjectPtr& n) { @@ -402,12 +410,12 @@ void Imperative::DeferredComputeClear(NDArrayHandle *arrays, const int num) { }); } -std::vector Imperative::Backward( - const std::vector& outputs, - const std::vector& ograds, - const std::vector& variables, - bool is_train, bool retain_graph, - bool create_graph) { +std::vector Imperative::Backward(const std::vector& outputs, + const std::vector& ograds, + const std::vector& variables, + bool is_train, + bool retain_graph, + bool create_graph) { using namespace nnvm; using namespace imperative; static const std::vector zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")}; @@ -418,10 +426,10 @@ std::vector Imperative::Backward( graph.outputs.reserve(outputs.size()); for (const auto& i : outputs) { CHECK(!AGInfo::IsNone(*i)) - << "Cannot differentiate node because it is not in a computational graph. " - << "You need to set is_recording to true or use autograd.record() to save " - << "computational graphs for backward. If you want to differentiate the same " - << "graph twice, you need to pass retain_graph=True to backward."; + << "Cannot differentiate node because it is not in a computational graph. " + << "You need to set is_recording to true or use autograd.record() to save " + << "computational graphs for backward. 
If you want to differentiate the same " + << "graph twice, you need to pass retain_graph=True to backward."; graph.outputs.emplace_back(i->autograd_entry_); } size_t num_forward_outputs = graph.outputs.size(); @@ -431,15 +439,14 @@ std::vector Imperative::Backward( ograd_entries.reserve(ograds.size()); for (size_t i = 0; i < outputs.size(); ++i) { nnvm::ObjectPtr np = Node::Create(); - np->attrs.name = "_head_grad_" + std::to_string(i); + np->attrs.name = "_head_grad_" + std::to_string(i); ograd_entries.emplace_back(NodeEntry{np, 0, 0}); AGInfo& info = AGInfo::Create(ograd_entries.back().node); - info.ctx = outputs[i]->ctx(); + info.ctx = outputs[i]->ctx(); if (ograds[i] != nullptr) { info.outputs.emplace_back(*ograds[i]); } else { - info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(), - true, outputs[i]->dtype()); + info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(), true, outputs[i]->dtype()); if (info.outputs.back().shape().Size() != 0) { info.outputs.back() = static_cast(1.0); } @@ -459,7 +466,7 @@ std::vector Imperative::Backward( for (size_t i = 0; i < variables.size(); ++i) { CHECK(!AGInfo::IsNone(*variables[i]) && AGInfo::IsVariable(variables[i]->autograd_entry_.node)) - << "Cannot differentiate with respect to the " << i+1 << "-th variable" + << "Cannot differentiate with respect to the " << i + 1 << "-th variable" << " because it does not require gradient."; xs.emplace_back(variables[i]->autograd_entry_); x_grads.push_back(new NDArray()); @@ -472,24 +479,28 @@ std::vector Imperative::Backward( x_reqs.reserve(args.size()); for (const auto& i : args) { AGInfo& info = AGInfo::Get(i); - if (info.grad_req == kNullOp) continue; + if (info.grad_req == kNullOp) + continue; xs.emplace_back(NodeEntry{i, 0, 0}); x_grads.push_back(&info.out_grads[0]); x_reqs.push_back(info.grad_req); info.fresh_out_grad = true; } - CHECK_GT(xs.size(), 0) - << "There are no inputs in computation graph that require gradients."; + CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients."; } - Graph g_graph = pass::MXGradient( - graph, graph.outputs, xs, ograd_entries, - mxnet::AggregateGradient, nullptr, - zero_ops, "_copy"); + Graph g_graph = pass::MXGradient(graph, + graph.outputs, + xs, + ograd_entries, + mxnet::AggregateGradient, + nullptr, + zero_ops, + "_copy"); CHECK_EQ(g_graph.outputs.size(), xs.size()); for (const auto& e : g_graph.outputs) { if (e.node->op() == nullptr) { - auto node = Node::Create(); + auto node = Node::Create(); node->attrs.op = copy_op; node->inputs.push_back(e); graph.outputs.emplace_back(std::move(node)); @@ -499,13 +510,13 @@ std::vector Imperative::Backward( } const auto& idx = graph.indexed_graph(); // get number of nodes used in forward pass - size_t num_forward_nodes = 0; + size_t num_forward_nodes = 0; size_t num_forward_entries = 0; for (size_t i = 0; i < num_forward_outputs; ++i) { - num_forward_nodes = std::max( - num_forward_nodes, static_cast(idx.outputs()[i].node_id + 1)); - num_forward_entries = std::max( - num_forward_entries, static_cast(idx.entry_id(idx.outputs()[i])) + 1); + num_forward_nodes = + std::max(num_forward_nodes, static_cast(idx.outputs()[i].node_id + 1)); + num_forward_entries = + std::max(num_forward_entries, static_cast(idx.entry_id(idx.outputs()[i])) + 1); } // Allocate buffer @@ -520,22 +531,23 @@ std::vector Imperative::Backward( if (create_graph) { states.resize(num_forward_nodes); nnvm::DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& n) { - AGInfo& info = AGInfo::Get(n); + AGInfo& 
info = AGInfo::Get(n); states[idx.node_id(n.get())] = info.state; for (uint32_t i = 0; i < info.outputs.size(); ++i) { CHECK(idx.exist(n.get())); - size_t nid = idx.node_id(n.get()); - size_t eid = idx.entry_id(nid, i); - buff[eid] = info.outputs[i]; + size_t nid = idx.node_id(n.get()); + size_t eid = idx.entry_id(nid, i); + buff[eid] = info.outputs[i]; buff[eid].autograd_entry_ = NodeEntry{n, i, 0}; - ref_count[eid] = 1; + ref_count[eid] = 1; } }); for (auto& ograd_entry : ograd_entries) { AGInfo& info = AGInfo::Get(ograd_entry.node); - if (!idx.exist(ograd_entry.node.get())) continue; - size_t eid = idx.entry_id(ograd_entry); - buff[eid] = info.outputs[0]; + if (!idx.exist(ograd_entry.node.get())) + continue; + size_t eid = idx.entry_id(ograd_entry); + buff[eid] = info.outputs[0]; buff[eid].autograd_entry_ = ograd_entry; } } else { @@ -544,21 +556,23 @@ std::vector Imperative::Backward( const AGInfo& info = dmlc::get(idx[i].source->info); states.emplace_back(info.state); for (size_t j = 0; j < info.outputs.size(); ++j) { - size_t eid = idx.entry_id(i, j); + size_t eid = idx.entry_id(i, j); arrays[eid] = const_cast(&(info.outputs[j])); - if (retain_graph || info.grad_req != kNullOp) ref_count[eid] = 1; + if (retain_graph || info.grad_req != kNullOp) + ref_count[eid] = 1; } } for (auto& ograd_entry : ograd_entries) { - if (!idx.exist(ograd_entry.node.get())) continue; - AGInfo& info = AGInfo::Get(ograd_entry.node); + if (!idx.exist(ograd_entry.node.get())) + continue; + AGInfo& info = AGInfo::Get(ograd_entry.node); arrays[idx.entry_id(ograd_entry)] = &info.outputs[0]; } } for (size_t i = num_forward_outputs; i < graph.outputs.size(); ++i) { - size_t eid = idx.entry_id(graph.outputs[i]); - arrays[eid] = x_grads[i - num_forward_outputs]; + size_t eid = idx.entry_id(graph.outputs[i]); + arrays[eid] = x_grads[i - num_forward_outputs]; ref_count[eid] = 1; } @@ -568,52 +582,55 @@ std::vector Imperative::Backward( // Infer shape type { std::pair node_range, entry_range; - node_range = {num_forward_nodes, idx.num_nodes()}; + node_range = {num_forward_nodes, idx.num_nodes()}; entry_range = {num_forward_entries, idx.num_node_entries()}; ShapeVector shapes; shapes.reserve(idx.num_node_entries()); bool contain_unknown = false; - for (const auto& i : arrays) shapes.emplace_back(i->shape()); - CheckAndInferShape(&graph, std::move(shapes), false, - node_range, entry_range, &contain_unknown); + for (const auto& i : arrays) + shapes.emplace_back(i->shape()); + CheckAndInferShape(&graph, std::move(shapes), false, node_range, entry_range, &contain_unknown); DTypeVector dtypes; dtypes.reserve(idx.num_node_entries()); - for (const auto& i : arrays) dtypes.emplace_back(i->dtype()); - CheckAndInferType(&graph, std::move(dtypes), false, - node_range, entry_range); + for (const auto& i : arrays) + dtypes.emplace_back(i->dtype()); + CheckAndInferType(&graph, std::move(dtypes), false, node_range, entry_range); StorageTypeVector stypes; stypes.reserve(idx.num_node_entries()); - for (const auto& i : arrays) stypes.emplace_back(i->storage_type()); + for (const auto& i : arrays) + stypes.emplace_back(i->storage_type()); exec::DevMaskVector dev_mask; dev_mask.reserve(idx.num_nodes()); - for (const auto& i : vctx) dev_mask.emplace_back(i.dev_mask()); - CheckAndInferStorageType(&graph, std::move(dev_mask), std::move(stypes), false, - node_range, entry_range); + for (const auto& i : vctx) + dev_mask.emplace_back(i.dev_mask()); + CheckAndInferStorageType( + &graph, std::move(dev_mask), std::move(stypes), false, node_range, 
entry_range); } // Calculate ref count for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { for (const auto& j : idx[i].inputs) { - ++ref_count[idx.entry_id(j)]; + ++ref_count[idx.entry_id(j)]; } } // Assign reqs std::vector array_reqs(arrays.size(), kWriteTo); for (size_t i = num_forward_entries; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; + if (ref_count[i] == 0) + array_reqs[i] = kNullOp; } for (size_t i = num_forward_outputs; i < idx.outputs().size(); ++i) { - size_t eid = idx.entry_id(idx.outputs()[i]); + size_t eid = idx.entry_id(idx.outputs()[i]); array_reqs[eid] = x_reqs[i - num_forward_outputs]; } - const auto& shapes = graph.GetAttr("shape"); - const auto& dtypes = graph.GetAttr("dtype"); - const auto& stypes = graph.GetAttr("storage_type"); + const auto& shapes = graph.GetAttr("shape"); + const auto& dtypes = graph.GetAttr("dtype"); + const auto& stypes = graph.GetAttr("storage_type"); const auto& dispatch_modes = graph.GetAttr("dispatch_mode"); for (size_t i = num_forward_nodes; i < idx.num_nodes(); ++i) { @@ -621,18 +638,16 @@ std::vector Imperative::Backward( for (size_t j = 0; j < num_outputs; ++j) { auto eid = idx.entry_id(i, j); if (arrays[eid]->is_none()) - arrays[eid]->ReInit(static_cast(stypes[eid]), - shapes[eid], vctx[i], dtypes[eid]); + arrays[eid]->ReInit( + static_cast(stypes[eid]), shapes[eid], vctx[i], dtypes[eid]); } } - for (size_t nid = num_forward_nodes; - nid < idx.num_nodes(); ++nid) { + for (size_t nid = num_forward_nodes; nid < idx.num_nodes(); ++nid) { const nnvm::NodeAttrs& attrs = idx[nid].source->attrs; for (size_t oid = 0; oid < idx[nid].source->num_outputs(); ++oid) { size_t eid = idx.entry_id(nid, oid); - arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), - attrs.name); + arrays[eid]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); } } // for (nid ∈ [num_forward_nodes, idx.num_nodes())) @@ -643,13 +658,20 @@ std::vector Imperative::Backward( // Execution bool prev_recording = set_is_recording(create_graph); - bool prev_training = set_is_training(is_train); - int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); + bool prev_training = set_is_training(is_train); + int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_); try { - RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, - is_recording()); + RunGraph(retain_graph, + idx, + arrays, + num_forward_nodes, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + is_recording()); } catch (const dmlc::Error& e) { Engine::Get()->set_bulk_size(prev_bulk_size); set_is_recording(prev_recording); @@ -675,32 +697,31 @@ std::vector Imperative::Backward( return {}; } -Imperative::DCInfo::DCInfo(const std::vector &inputs, - const std::vector &outputs) { +Imperative::DCInfo::DCInfo(const std::vector& inputs, + const std::vector& outputs) { this->inputs_.reserve(inputs.size()); this->input_handles_.reserve(inputs.size()); - for (const NDArray *arr : inputs) { + for (const NDArray* arr : inputs) { CHECK(!arr->is_none()); this->inputs_.push_back(*arr); this->input_handles_.push_back(arr); } this->outputs_.reserve(outputs.size()); - for (const NDArray *arr : outputs) { + for (const NDArray* arr : outputs) { CHECK(!arr->is_none()); this->outputs_.push_back(*arr); } } -Imperative::DCInfo & -Imperative::DCInfo::Create(const nnvm::ObjectPtr &node, - const 
std::vector &inputs, - const std::vector &outputs) { +Imperative::DCInfo& Imperative::DCInfo::Create(const nnvm::ObjectPtr& node, + const std::vector& inputs, + const std::vector& outputs) { node->info.construct(inputs, outputs); return Imperative::DCInfo::Get(node); } -void Imperative::DCInfo::Compute(const NDArray &arr) { +void Imperative::DCInfo::Compute(const NDArray& arr) { if (Imperative::DCInfo::IsComputed(arr)) { if (!shape_is_known(arr.shape())) { // We can't call arr.WaitToRead(); here, as WaitToRead calls Compute @@ -715,30 +736,29 @@ void Imperative::DCInfo::Compute(const NDArray &arr) { return; } - DCInfo &info = Imperative::DCInfo::Get(arr.deferredcompute_entry_.node); + DCInfo& info = Imperative::DCInfo::Get(arr.deferredcompute_entry_.node); info.is_computed_ = true; // We will Invoke at the end of this function. // Recursively compute input arrays - for (const NDArray &input : info.inputs_) { + for (const NDArray& input : info.inputs_) { Compute(input); } // Prepare pointers - std::vector ndinputs, ndoutputs; + std::vector ndinputs, ndoutputs; ndinputs.reserve(info.inputs_.size()); ndoutputs.reserve(info.outputs_.size()); - for (NDArray &input : info.inputs_) + for (NDArray& input : info.inputs_) ndinputs.push_back(&input); - for (NDArray &output : info.outputs_) + for (NDArray& output : info.outputs_) ndoutputs.push_back(&output); // Compute this array - Imperative::Get()->Invoke(Context::CPU(), - arr.deferredcompute_entry_.node->attrs, ndinputs, - ndoutputs); + Imperative::Get()->Invoke( + Context::CPU(), arr.deferredcompute_entry_.node->attrs, ndinputs, ndoutputs); if (!shape_is_known(arr.shape())) { - arr.WaitToRead(); - arr.SetShapeFromChunk(); + arr.WaitToRead(); + arr.SetShapeFromChunk(); } // Deallocate copies diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 9d15084003de..e3a58804d8ac 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -27,7 +27,7 @@ std::vector NodeInputs(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_inputs = node.inputs.size(); + const size_t num_inputs = node.inputs.size(); std::vector ndinputs; ndinputs.reserve(num_inputs); for (const auto& j : node.inputs) { @@ -41,7 +41,7 @@ std::vector NodeOutputs(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_outputs = node.source->num_outputs(); + const size_t num_outputs = node.source->num_outputs(); std::vector ndoutputs; ndoutputs.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { @@ -55,7 +55,7 @@ std::vector NodeReq(const nnvm::IndexedGraph& idx, const int node_idx, const std::vector& array_reqs) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; - const size_t num_outputs = node.source->num_outputs(); + const size_t num_outputs = node.source->num_outputs(); std::vector req; req.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { @@ -73,21 +73,21 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, std::vector* p_states, const std::vector& ndinputs, const std::vector& ndoutputs, - std::vector *p_req, - std::vector *p_ref_count, - std::function invoke) { - static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); - static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); - static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); - std::vector& states = 
*p_states; - std::vector &req = *p_req; - std::vector &ref_count = *p_ref_count; + std::vector* p_req, + std::vector* p_ref_count, + std::function invoke) { + static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); + static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + std::vector& states = *p_states; + std::vector& req = *p_req; + std::vector& ref_count = *p_ref_count; const nnvm::IndexedGraph::Node& node = idx[node_idx]; if (node.source->op() == bwd_cached_op && node.source->attrs.name == "_cachedop_backward") { const auto& cached_op = dmlc::get(node.source->attrs.parsed); - nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); + nnvm::Node* fwd_node = node.source->control_deps[0].get(); + auto fwd_node_id = idx.node_id(fwd_node); cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs); } else if (createop.count(node.source->op())) { mxnet::ShapeVector arg_shapes; @@ -102,7 +102,7 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, invoke(states[node_idx]); } else if (is_layer_backward.get(node.source->op(), false)) { nnvm::Node* fwd_node = node.source->control_deps[0].get(); - auto fwd_node_id = idx.node_id(fwd_node); + auto fwd_node_id = idx.node_id(fwd_node); invoke(states[fwd_node_id]); } else { invoke(OpStatePtr()); @@ -126,108 +126,107 @@ void InvokeOperator(const nnvm::IndexedGraph& idx, namespace mxnet { namespace imperative { -void RunGraph( - const bool retain_graph, - const nnvm::IndexedGraph& idx, - const std::vector& arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes, - bool recording, - mxnet::ShapeVector *shapes, - const imperative::CachedOpMonCallback& callback, - const bool monitor_all) { +void RunGraph(const bool retain_graph, + const nnvm::IndexedGraph& idx, + const std::vector& arrays, + size_t node_start, + size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector* p_states, + const DispatchModeVector& dispatch_modes, + bool recording, + mxnet::ShapeVector* shapes, + const imperative::CachedOpMonCallback& callback, + const bool monitor_all) { CHECK(shapes == nullptr); for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; if (node.source->op() == nullptr) { continue; } - std::vector ndinputs = NodeInputs(idx, i, arrays); + std::vector ndinputs = NodeInputs(idx, i, arrays); std::vector ndoutputs = NodeOutputs(idx, i, arrays); - std::vector req = NodeReq(idx, i, array_reqs); - Context ctx = ndoutputs[0]->ctx(); + std::vector req = NodeReq(idx, i, array_reqs); + Context ctx = ndoutputs[0]->ctx(); if (callback && monitor_all) { - mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback); + mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback); } - auto invoke = [&](const OpStatePtr &state) { + auto invoke = [&](const OpStatePtr& state) { const nnvm::IndexedGraph::Node& node = idx[i]; - DispatchMode dispatch_mode = dispatch_modes[i]; - Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, state); + DispatchMode dispatch_mode = dispatch_modes[i]; + Imperative::Get()->InvokeOp( + ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state); if (recording) { Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } }; - 
InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, - &req, &ref_count, invoke); + InvokeOperator( + idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke); if (callback) { - mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); + mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); } } } -void NaiveRunGraph( - const bool retain_graph, - const Context& default_ctx, - const nnvm::IndexedGraph& idx, - const std::vector& arrays, - size_t node_start, size_t node_end, - std::vector&& array_reqs, - std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes, - bool recording, - mxnet::ShapeVector *shapes, - const imperative::CachedOpMonCallback& callback, - const bool monitor_all, - const bool skip_engine) { +void NaiveRunGraph(const bool retain_graph, + const Context& default_ctx, + const nnvm::IndexedGraph& idx, + const std::vector& arrays, + size_t node_start, + size_t node_end, + std::vector&& array_reqs, + std::vector&& ref_count, + std::vector* p_states, + const DispatchModeVector& dispatch_modes, + bool recording, + mxnet::ShapeVector* shapes, + const imperative::CachedOpMonCallback& callback, + const bool monitor_all, + const bool skip_engine) { for (size_t i = node_start; i < node_end; ++i) { const nnvm::IndexedGraph::Node& node = idx[i]; if (node.source->op() == nullptr) { continue; } - std::vector ndinputs = NodeInputs(idx, i, arrays); + std::vector ndinputs = NodeInputs(idx, i, arrays); std::vector ndoutputs = NodeOutputs(idx, i, arrays); std::vector req; Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx); if (callback && monitor_all) { - mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback); + mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback); } - auto invoke = [&](const OpStatePtr &state) { + auto invoke = [&](const OpStatePtr& state) { const nnvm::IndexedGraph::Node& node = idx[i]; - DispatchMode dispatch_mode = DispatchMode::kUndefined; + DispatchMode dispatch_mode = DispatchMode::kUndefined; SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode); SetWriteInplaceReq(ndinputs, ndoutputs, &req); if (skip_engine) { auto new_attr = node.source->attrs; CHECK(new_attr.dict.find(SKIP_ENGINE) == new_attr.dict.end()); new_attr.dict[SKIP_ENGINE] = SKIP_ENGINE_SET; - Imperative::Get()->InvokeOp(ctx, new_attr, ndinputs, ndoutputs, - req, dispatch_mode, state); + Imperative::Get()->InvokeOp(ctx, new_attr, ndinputs, ndoutputs, req, dispatch_mode, state); } else { - Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, - req, dispatch_mode, state); + Imperative::Get()->InvokeOp( + ctx, node.source->attrs, ndinputs, ndoutputs, req, dispatch_mode, state); } for (size_t j = 0; j < ndoutputs.size(); ++j) { if (mxnet::op::shape_is_none(ndoutputs[j]->shape())) { ndoutputs[j]->WaitToRead(); ndoutputs[j]->SetShapeFromChunk(); } - size_t eid = idx.entry_id(i, j); - auto shape = ndoutputs[j]->shape(); + size_t eid = idx.entry_id(i, j); + auto shape = ndoutputs[j]->shape(); (*shapes)[eid] = shape; } if (recording) { Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state); } }; - InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, - &req, &ref_count, invoke); + InvokeOperator( + idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs, &req, &ref_count, invoke); if (callback) { - mxnet::common::ExecuteMonOutputCallback(idx, 
arrays, i, callback); + mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback); } } } diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index f53b0db91b2b..31c7333e08ba 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -37,12 +37,17 @@ namespace mxnet { #if MXNET_USE_ONEDNN == 1 -templateT *pntr(T &obj) { return &obj; } // NOLINT -templateT *pntr(T *obj) { return obj; } +template +T* pntr(T& obj) { // NOLINT + return &obj; +} +template +T* pntr(T* obj) { + return obj; +} -template -void InvalidateOutputs(const std::vector *pArrs, - const std::vector &reqs) { +template +void InvalidateOutputs(const std::vector* pArrs, const std::vector& reqs) { auto arrs = *pArrs; for (size_t i = 0; i < arrs.size(); i++) { if (reqs[i] == kWriteTo || reqs[i] == kNullOp) @@ -51,8 +56,8 @@ void InvalidateOutputs(const std::vector *pArrs, } // TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added -static inline void CreateDefaultInputs(const std::vector &arrs, - std::vector *out_arrs) { +static inline void CreateDefaultInputs(const std::vector& arrs, + std::vector* out_arrs) { out_arrs->clear(); for (size_t i = 0; i < arrs.size(); ++i) { if (arrs[i].IsMKLDNNData()) @@ -63,8 +68,8 @@ static inline void CreateDefaultInputs(const std::vector &arrs, } // TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added -static inline void CreateDefaultInputs(std::vector *pArrs) { - auto &&arrs = *pArrs; +static inline void CreateDefaultInputs(std::vector* pArrs) { + auto&& arrs = *pArrs; for (size_t i = 0; i < arrs.size(); ++i) arrs[i].SelfReorder2Default(); } @@ -75,14 +80,18 @@ static inline void CreateDefaultInputs(std::vector *pArrs) { // So for the case that A is holding mkldnn memory, and then copy A to B, and then copy B // back to A, we shouldn't invalidate outputs for copying B back to A, because at this time, // copying A to B may not happen, and will corrupt A's memory. -#define INVALIDATE_OUTPUTS_COND(cond, outputs, req) if (cond) INVALIDATE_OUTPUTS(outputs, req) +#define INVALIDATE_OUTPUTS_COND(cond, outputs, req) \ + if (cond) { \ + INVALIDATE_OUTPUTS(outputs, req); \ + } // add for mkldnn OP + no mkldnn OP -#define CREATE_DEFAULT_INPUTS(cond, attrs, func_call) \ - if (cond) { \ - const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); \ - if (!is_mkldnn.get(attrs.op, false)) func_call; \ - } +#define CREATE_DEFAULT_INPUTS(cond, attrs, func_call) \ + if (cond) { \ + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); \ + if (!is_mkldnn.get(attrs.op, false)) \ + func_call; \ + } #else #define INVALIDATE_OUTPUTS(outputs, ...) 
// empty macros @@ -93,15 +102,16 @@ static inline void CreateDefaultInputs(std::vector *pArrs) { namespace imperative { namespace { - static const char SKIP_ENGINE[] = "__skip_engine__"; - static const char SKIP_ENGINE_SET[] = "__true__"; - - inline bool CheckIfSkipEngine(const nnvm::NodeAttrs& attrs) { - const auto& skip_engine_attr = attrs.dict.find(SKIP_ENGINE); - if (skip_engine_attr == attrs.dict.end()) return false; - return (*skip_engine_attr).second == SKIP_ENGINE_SET; - } +static const char SKIP_ENGINE[] = "__skip_engine__"; +static const char SKIP_ENGINE_SET[] = "__true__"; + +inline bool CheckIfSkipEngine(const nnvm::NodeAttrs& attrs) { + const auto& skip_engine_attr = attrs.dict.find(SKIP_ENGINE); + if (skip_engine_attr == attrs.dict.end()) + return false; + return (*skip_engine_attr).second == SKIP_ENGINE_SET; } +} // namespace struct MemoryPlanInfo { int storage_id; @@ -122,23 +132,21 @@ struct EngineOprSeg { std::unique_ptr opr; }; -using MemoryPlanVector = std::vector; +using MemoryPlanVector = std::vector; using CachedOpMonCallback = std::function; inline Context GetContext(const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - const std::vector& outputs, - const Context& default_ctx) { + const std::vector& inputs, + const std::vector& outputs, + const Context& default_ctx) { Context ctx; if (inputs.size()) { ctx = inputs[0]->ctx(); for (size_t i = 1; i < inputs.size(); ++i) { CHECK_EQ(inputs[i]->ctx().dev_mask(), ctx.dev_mask()) - << "Operator " << attrs.op->name - << " require all inputs live on the same context. " - << "But the first argument is on " - << ctx << " while the " << i+1 << "-th argument is on " - << inputs[i]->ctx(); + << "Operator " << attrs.op->name << " require all inputs live on the same context. " + << "But the first argument is on " << ctx << " while the " << i + 1 + << "-th argument is on " << inputs[i]->ctx(); } } else if (outputs.size() && !outputs[0]->is_none()) { ctx = outputs[0]->ctx(); @@ -171,12 +179,12 @@ inline void SetShapeType(const Context& ctx, const std::vector& inputs, const std::vector& outputs, DispatchMode* dispatch_mode) { - static auto& infershape = nnvm::Op::GetAttr("FInferShape"); - static auto& infertype = nnvm::Op::GetAttr("FInferType"); - static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); - MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get(); + static auto& infershape = nnvm::Op::GetAttr("FInferShape"); + static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); + MXAPIThreadLocalEntry<>* ret = MXAPIThreadLocalStore<>::Get(); // infer shape - mxnet::ShapeVector& in_shapes = ret->arg_shapes; + mxnet::ShapeVector& in_shapes = ret->arg_shapes; in_shapes.clear(); in_shapes.reserve(inputs.size()); for (auto& i : inputs) { @@ -192,7 +200,7 @@ inline void SetShapeType(const Context& ctx, if (!is_dynamic_shape_existing) { // If any of the inputs is a deferred computed array with unknown shape, we // can't infer shapes. 
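// ---------------------------------------------------------------------------------
// Editor's note (illustrative, not part of this change): the loop below makes static
// shape inference bail out as soon as one input is a deferred-compute array whose
// shape is still unknown -- such a shape only materializes after DCInfo::Compute() has
// executed the recorded operators. The same predicate as a standalone sketch
// (hypothetical helper name, assuming the usual MXNet headers):

#include <vector>
#include <mxnet/imperative.h>
#include <mxnet/ndarray.h>

inline bool HasUnknownDeferredShape(const std::vector<mxnet::NDArray*>& inputs) {
  for (const mxnet::NDArray* in : inputs) {
    // shape_is_known() stays false until the deferred graph has actually been run
    if (!mxnet::shape_is_known(in->shape()) && !mxnet::Imperative::DCInfo::IsNone(*in))
      return true;
  }
  return false;
}
// ---------------------------------------------------------------------------------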
- for (const NDArray *i : inputs) { + for (const NDArray* i : inputs) { if (!shape_is_known(i->shape()) && !Imperative::DCInfo::IsNone(*i)) { is_dynamic_shape_existing = true; break; @@ -262,15 +270,16 @@ inline void SetShapeType(const Context& ctx, } bool infer_stype_success = false; if (inferstorage.count(attrs.op)) { - infer_stype_success = inferstorage[attrs.op](attrs, ctx.dev_mask(), dispatch_mode, - &in_storage_types, &out_storage_types); + infer_stype_success = inferstorage[attrs.op]( + attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); } else { // if infer storage attr is not present, apply the default infer storage function - infer_stype_success = common::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode, - &in_storage_types, &out_storage_types); + infer_stype_success = common::DefaultStorageType( + attrs, ctx.dev_mask(), dispatch_mode, &in_storage_types, &out_storage_types); } CHECK(infer_stype_success) << "Operator not implemented: " - << common::operator_stype_string(attrs, ctx.dev_mask(), in_storage_types, out_storage_types); + << common::operator_stype_string( + attrs, ctx.dev_mask(), in_storage_types, out_storage_types); if (*dispatch_mode == DispatchMode::kFComputeFallback) { common::LogStorageFallback(attrs, ctx.dev_mask(), &in_storage_types, &out_storage_types); } @@ -279,12 +288,12 @@ inline void SetShapeType(const Context& ctx, CHECK(*dispatch_mode != DispatchMode::kUndefined); for (size_t i = 0; i < outputs.size(); ++i) { if (outputs[i]->is_none() || (mxnet::op::shape_is_none(outputs[i]->shape()) && - Imperative::DCInfo::IsNone(*outputs[i]))) { + Imperative::DCInfo::IsNone(*outputs[i]))) { if (!is_dynamic_shape_existing) { const auto storage_type = static_cast(out_storage_types[i]); outputs[i]->ReInit(storage_type, out_shapes[i], ctx, out_types[i]); } else { - *outputs[i] = NDArray(ctx, out_types[i]); + *outputs[i] = NDArray(ctx, out_types[i]); } outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name); } else if (mxnet::op::shape_is_none(outputs[i]->shape())) { @@ -295,18 +304,18 @@ inline void SetShapeType(const Context& ctx, outputs[i]->Init(out_shapes[i]); } CHECK_EQ(outputs[i]->dtype(), out_types[i]) - << i << "-th output has invalid dtype. " - << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() - << " in operator " << attrs.op->name; + << i << "-th output has invalid dtype. " + << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() << " in operator " + << attrs.op->name; } else { CHECK_EQ(outputs[i]->shape(), out_shapes[i]) - << i << "-th output has invalid shape. " - << "Expecting " << out_shapes[i] << " got " - << outputs[i]->shape() << " in operator " << attrs.op->name; + << i << "-th output has invalid shape. " + << "Expecting " << out_shapes[i] << " got " << outputs[i]->shape() << " in operator " + << attrs.op->name; CHECK_EQ(outputs[i]->dtype(), out_types[i]) - << i << "-th output has invalid dtype. " - << "Expecting " << out_types[i] << " got " - << outputs[i]->dtype() << " in operator " << attrs.op->name; + << i << "-th output has invalid dtype. " + << "Expecting " << out_types[i] << " got " << outputs[i]->dtype() << " in operator " + << attrs.op->name; } } } @@ -316,53 +325,53 @@ inline void SetShapeType(const Context& ctx, * For inputs and outputs arguments only NDArray::var() is accessed. 
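 *
 * Editor's note (illustrative, not part of this change): for an operator whose
 * FMutateInputs attribute returns {0}, i.e. its first input is updated in place,
 * SetDependency adds inputs[0]->var() to write_vars as well, so the engine treats
 * the call as a writer of that NDArray and serializes it against other readers and
 * writers. A hypothetical call sketch (names assumed for illustration only):
 *
 *   std::vector<engine::VarHandle> read_vars, write_vars;
 *   std::vector<Resource> requested;
 *   std::vector<uint32_t> mutate_idx;
 *   SetDependency(attrs, ctx, inputs, outputs,
 *                 &read_vars, &write_vars, &requested, &mutate_idx, dispatch_mode);
 *   // mutate_idx == {0}; inputs[0]->var() now appears among the write dependencies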
*/ inline void SetDependency(const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& inputs, - const std::vector& outputs, - std::vector *p_read_vars, - std::vector *p_write_vars, - std::vector *p_requested, - std::vector *p_mutate_idx, - const DispatchMode dispatch_mode) { - static auto& fmutate = nnvm::Op::GetAttr("FMutateInputs"); - static auto& ftmp_resource = nnvm::Op::GetAttr("FResourceRequest"); + const Context& ctx, + const std::vector& inputs, + const std::vector& outputs, + std::vector* p_read_vars, + std::vector* p_write_vars, + std::vector* p_requested, + std::vector* p_mutate_idx, + const DispatchMode dispatch_mode) { + static auto& fmutate = nnvm::Op::GetAttr("FMutateInputs"); + static auto& ftmp_resource = nnvm::Op::GetAttr("FResourceRequest"); static auto& ftmp_resource_ex = nnvm::Op::GetAttr("FResourceRequestEx"); std::vector& read_vars = *p_read_vars; std::vector& write_vars = *p_write_vars; - std::vector& requested = *p_requested; - std::vector& mutate_idx = *p_mutate_idx; + std::vector& requested = *p_requested; + std::vector& mutate_idx = *p_mutate_idx; if (fmutate.count(attrs.op)) { mutate_idx = fmutate[attrs.op](attrs); } - const bool rsc_req = (ftmp_resource.count(attrs.op) != 0); + const bool rsc_req = (ftmp_resource.count(attrs.op) != 0); const bool rsc_ex_req = (ftmp_resource_ex.count(attrs.op) != 0); if (rsc_req || rsc_ex_req) { - int ntmp = 0; - auto resource_reqs = rsc_ex_req ? ftmp_resource_ex[attrs.op](attrs, - static_cast(ctx.dev_mask()), dispatch_mode) + int ntmp = 0; + auto resource_reqs = rsc_ex_req ? ftmp_resource_ex[attrs.op]( + attrs, static_cast(ctx.dev_mask()), dispatch_mode) : ftmp_resource[attrs.op](attrs); for (const auto& req : resource_reqs) { switch (req.type) { - case ResourceRequest::kTempSpace: - ++ntmp; - case ResourceRequest::kRandom: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; - case ResourceRequest::kParallelRandom: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; + case ResourceRequest::kTempSpace: + ++ntmp; + case ResourceRequest::kRandom: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; + case ResourceRequest::kParallelRandom: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; #if MXNET_USE_CUDNN == 1 - case ResourceRequest::kCuDNNDropoutDesc: - requested.push_back(ResourceManager::Get()->Request(ctx, req)); - write_vars.push_back(requested.back().var); - break; + case ResourceRequest::kCuDNNDropoutDesc: + requested.push_back(ResourceManager::Get()->Request(ctx, req)); + write_vars.push_back(requested.back().var); + break; #endif // MXNET_USE_CUDNN == 1 - default: - LOG(FATAL) << "resource type not yet supported"; + default: + LOG(FATAL) << "resource type not yet supported"; } } CHECK_LE(ntmp, 1) << "Only support 1 temp space request"; @@ -382,7 +391,7 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs, for (auto& i : outputs) { write_vars.push_back(i->var()); } - for (auto & i : mutate_idx) { + for (auto& i : mutate_idx) { write_vars.push_back(inputs[i]->var()); } Engine::Get()->DeduplicateVarHandle(&read_vars, &write_vars); @@ -394,11 +403,11 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs, * NDArray. Set to kWriteTo otherwise. 
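 *
 * Editor's note (illustrative, not part of this change): e.g. an in-place style call
 * where an output aliases one of the inputs resolves to kWriteInplace, while a fresh
 * output NDArray gets kWriteTo (hypothetical snippet, names assumed):
 *
 *   std::vector<NDArray*> inputs{&a, &b}, outputs{&a};  // "a += b" style invocation
 *   std::vector<OpReqType> req;
 *   SetWriteInplaceReq(inputs, outputs, &req);
 *   // req == {kWriteInplace}; with outputs{&c} for a fresh c it would be {kWriteTo}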
*/ inline void SetWriteInplaceReq(const std::vector& inputs, - const std::vector& outputs, - std::vector *req) { + const std::vector& outputs, + std::vector* req) { std::unordered_set in_vars; in_vars.reserve(inputs.size()); - for (auto &i : inputs) { + for (auto& i : inputs) { in_vars.insert(i->var()); } req->clear(); @@ -420,16 +429,16 @@ inline void SetWriteInplaceReq(const std::vector& inputs, * \param param_vals Array of string pointers representing the associated values * \return nnvm::NodeAttrs structure representing the parsed attributes */ -inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op *op, +inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op* op, const int num_inputs, const int num_params, - const char **param_keys, - const char **param_vals) { + const char** param_keys, + const char** param_vals) { static auto& num_args = nnvm::Op::GetAttr("key_var_num_args"); nnvm::NodeAttrs attrs; attrs.op = op; - attrs.dict.reserve(num_params+1); + attrs.dict.reserve(num_params + 1); for (int i = 0; i < num_params; ++i) { attrs.dict.emplace(param_keys[i], param_vals[i]); } @@ -451,7 +460,7 @@ inline nnvm::NodeAttrs ParseAttrs(const nnvm::Op *op, * \param infered_num_outputs The inferred number of outputs * \param num_visible_outputs The actual number of visible outputs */ -inline void SetNumOutputs(const nnvm::Op *op, +inline void SetNumOutputs(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const int& num_inputs, int* infered_num_outputs, @@ -464,8 +473,8 @@ inline void SetNumOutputs(const nnvm::Op *op, infered_num_inputs = op->num_inputs; } CHECK_EQ(num_inputs, infered_num_inputs) - << "Operator " << op->name << " expects " << infered_num_inputs - << " inputs, but got " << num_inputs << " instead."; + << "Operator " << op->name << " expects " << infered_num_inputs << " inputs, but got " + << num_inputs << " instead."; if (op->get_num_outputs != nullptr) { *infered_num_outputs = op->get_num_outputs(attrs); } else { @@ -481,33 +490,38 @@ inline void SetNumOutputs(const nnvm::Op *op, /*! 
* \brief Copy-construct NDArrays referenced by inputs and outputs to p_inputs and p_outputs */ -inline void DerefInputOutput(const std::vector& inputs, - const std::vector& outputs, +inline void DerefInputOutput(const std::vector& inputs, + const std::vector& outputs, std::vector* p_inputs, std::vector* p_outputs) { p_inputs->reserve(inputs.size()); p_outputs->reserve(outputs.size()); - for (const auto i : inputs) p_inputs->emplace_back(*i); - for (const auto i : outputs) p_outputs->emplace_back(*i); + for (const auto i : inputs) + p_inputs->emplace_back(*i); + for (const auto i : outputs) + p_outputs->emplace_back(*i); } inline void DerefInputOutput(const std::vector& inputs, const std::vector& outputs, - std::vector* p_inputs, - std::vector* p_outputs) { + std::vector* p_inputs, + std::vector* p_outputs) { p_inputs->reserve(inputs.size()); p_outputs->reserve(outputs.size()); - for (const auto i : inputs) p_inputs->emplace_back(new NDArray(*i)); - for (const auto i : outputs) p_outputs->emplace_back(new NDArray(*i)); + for (const auto i : inputs) + p_inputs->emplace_back(new NDArray(*i)); + for (const auto i : outputs) + p_outputs->emplace_back(new NDArray(*i)); } -inline void DerefInputOutputRelease(const std::vector& inputs, - const std::vector& outputs) { - for (auto i : inputs) delete i; - for (auto i : outputs) delete i; +inline void DerefInputOutputRelease(const std::vector& inputs, + const std::vector& outputs) { + for (auto i : inputs) + delete i; + for (auto i : outputs) + delete i; } - /* * \brief setup default-storage tblobs from source NDArrays. If any source NDArray has non-default * storage, it creates a temp NDArray with default storage and uses the temp tblob. The @@ -521,19 +535,19 @@ inline void DerefInputOutputRelease(const std::vector& inputs, indices are not recorded * \return true if any source NDArray need to cast storage */ -inline bool SetupDefaultBlobsIn(const std::vector& src, - const std::vector *bufs, - std::vector *blobs, - std::vector *temp_src, - std::vector *temp_dst, - std::unordered_map *idx_map) { +inline bool SetupDefaultBlobsIn(const std::vector& src, + const std::vector* bufs, + std::vector* blobs, + std::vector* temp_src, + std::vector* temp_dst, + std::unordered_map* idx_map) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = *src[i]; if (!DEFAULT_DATA(nd)) { (*idx_map)[i] = temp_dst->size(); - NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), - true, nd.dtype()); + NDArray temp = + bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); #if MXNET_USE_ONEDNN == 1 CHECK(temp.IsDefaultData()); #endif @@ -548,12 +562,12 @@ inline bool SetupDefaultBlobsIn(const std::vector& src, return require_cast; } -inline bool SetupDefaultBlobsOut(const std::vector& src, - const std::vector *bufs, - std::vector *req, - std::vector *blobs, - std::vector *temp_src, - std::vector *temp_dst) { +inline bool SetupDefaultBlobsOut(const std::vector& src, + const std::vector* bufs, + std::vector* req, + std::vector* blobs, + std::vector* temp_src, + std::vector* temp_dst) { bool require_cast = false; for (size_t i = 0; i < src.size(); i++) { const auto& nd = *src[i]; @@ -572,14 +586,14 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, if (bufs != nullptr) { temp = bufs->at(i); } else if (kAddTo == req->at(i)) { - temp = nd.IsMKLDNNData()? nd.Reorder2Default() : nd; + temp = nd.IsMKLDNNData() ? 
nd.Reorder2Default() : nd; } else { temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); } CHECK(temp.IsDefaultData()); #else - NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), - true, nd.dtype()); + NDArray temp = + bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(), true, nd.dtype()); #endif temp_src->emplace_back(nd); temp_dst->emplace_back(temp); @@ -599,25 +613,23 @@ inline bool SetupDefaultBlobsOut(const std::vector& src, * function also records the indices of non-default source NDArrays and the indices of * their corresponding temporary NDArrays in the temp array. */ -inline void SetupDefaultBlobsInOut(const std::vector &ndinputs, - const std::vector &ndoutputs, - const std::vector *in_bufs, - const std::vector *out_bufs, - std::vector *req, - std::vector *input_blobs, - std::vector *output_blobs, - std::vector *pre_temp_src, - std::vector *pre_temp_dst, - std::vector *post_temp_src, - std::vector *post_temp_dst, - std::unordered_map *in_temp_idx_map, - const std::vector &mutate_idx) { +inline void SetupDefaultBlobsInOut(const std::vector& ndinputs, + const std::vector& ndoutputs, + const std::vector* in_bufs, + const std::vector* out_bufs, + std::vector* req, + std::vector* input_blobs, + std::vector* output_blobs, + std::vector* pre_temp_src, + std::vector* pre_temp_dst, + std::vector* post_temp_src, + std::vector* post_temp_dst, + std::unordered_map* in_temp_idx_map, + const std::vector& mutate_idx) { // populate input blobs - SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, - in_temp_idx_map); + SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map); // populate output blobs - SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, - post_temp_src); + SetupDefaultBlobsOut(ndoutputs, out_bufs, req, output_blobs, post_temp_dst, post_temp_src); // add mutable inputs to post temp list for (const auto idx : mutate_idx) { auto map_iter = in_temp_idx_map->find(idx); @@ -628,30 +640,30 @@ inline void SetupDefaultBlobsInOut(const std::vector &ndinputs, } } -#define REDEFINE_INPUTS_OUTPUTS(in, out, newIn, newOut) \ - std::vector newIn, newOut; \ - DerefInputOutput(in, out, &newIn, &newOut); \ - DerefInputOutputRelease(in, out) +#define REDEFINE_INPUTS_OUTPUTS(in, out, newIn, newOut) \ + std::vector newIn, newOut; \ + DerefInputOutput(in, out, &newIn, &newOut); \ + DerefInputOutputRelease(in, out) inline void PushFCompute(const FCompute& fn, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& mutate_idx, - const std::vector& req) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& mutate_idx, + const std::vector& req) { using namespace common; static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - bool is_train = Imperative::Get()->is_training(); - bool need_grad = Imperative::Get()->is_recording(); + bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; CHECK(exec_type == ExecType::kSync); - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { std::vector input_blobs, output_blobs; @@ -662,9 +674,19 @@ inline void PushFCompute(const FCompute& fn, INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); std::vector tmp_req = req; // setup blobs - SetupDefaultBlobsInOut(inputs, outputs, nullptr, nullptr, &tmp_req, - &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, - &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); + SetupDefaultBlobsInOut(inputs, + outputs, + nullptr, + nullptr, + &tmp_req, + &input_blobs, + &output_blobs, + &pre_temp_src, + &pre_temp_dst, + &post_temp_src, + &post_temp_dst, + &in_temp_idx_map, + mutate_idx); // setup context OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; bool is_gpu = ctx.dev_mask() == gpu::kDevMask; @@ -683,82 +705,81 @@ inline void PushFCompute(const FCompute& fn, run(RunContext{ctx, nullptr, nullptr, false}); } else { Engine::Get()->PushSync( - run, ctx, read_vars, write_vars, FnProperty::kNormal, - 0, op->name.c_str()); + run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); } } inline void PushFComputeEx(const FComputeEx& fn, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& req) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& req) { static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - const bool is_train = Imperative::Get()->is_training(); - const bool need_grad = Imperative::Get()->is_recording(); - const auto exec_type = fexec_type.count(op) ? fexec_type[op](attrs) : ExecType::kSync; + const bool is_train = Imperative::Get()->is_training(); + const bool need_grad = Imperative::Get()->is_recording(); + const auto exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; const auto cross_device_copy = exec_type == ExecType::kCrossDeviceCopy; - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); const auto& run = [=](RunContext rctx) { - OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; - REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); - INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req); - CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA)); - fn(attrs, opctx, inputsA, req, outputsA); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { - rctx.get_stream()->Wait(); - } - }; + OpContext opctx{need_grad, is_train, rctx, engine::CallbackOnComplete(), requested}; + REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); + INVALIDATE_OUTPUTS_COND(!cross_device_copy, outputsA, req); + CREATE_DEFAULT_INPUTS(!cross_device_copy, attrs, CreateDefaultInputs(&inputsA)); + fn(attrs, opctx, inputsA, req, outputsA); + if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && !rctx.is_bulk) { + rctx.get_stream()->Wait(); + } + }; if (cross_device_copy || CheckIfSkipEngine(attrs)) { run(RunContext{ctx, nullptr, nullptr, false}); } else { CHECK(exec_type == ExecType::kSync); - Engine::Get()->PushSync(run, ctx, read_vars, write_vars, FnProperty::kNormal, - 0, op->name.c_str()); + Engine::Get()->PushSync( + run, ctx, read_vars, write_vars, FnProperty::kNormal, 0, op->name.c_str()); } } inline void PushOperator(const OpStatePtr& state, - const nnvm::Op* op, - const nnvm::NodeAttrs& attrs, - const Context& ctx, - const std::vector& read_vars, - const std::vector& write_vars, - const std::vector& requested, - const std::vector& p_inputs, - const std::vector& p_outputs, - const std::vector& mutate_idx, - const std::vector& req, - const DispatchMode dispatch_mode) { + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& p_inputs, + const std::vector& p_outputs, + const std::vector& mutate_idx, + const std::vector& req, + const DispatchMode dispatch_mode) { using namespace common; static auto& fexec_type = nnvm::Op::GetAttr("FExecType"); - bool is_train = Imperative::Get()->is_training(); - bool need_grad = Imperative::Get()->is_recording(); + bool is_train = Imperative::Get()->is_training(); + bool need_grad = Imperative::Get()->is_recording(); ExecType exec_type = fexec_type.count(op) ? 
fexec_type[op](attrs) : ExecType::kSync; - std::vector inputs, outputs; + std::vector inputs, outputs; DerefInputOutput(p_inputs, p_outputs, &inputs, &outputs); auto fcompute_ex = common::GetFCompute(op, "FStatefulComputeEx", ctx); if (fcompute_ex != nullptr && dispatch_mode == DispatchMode::kFComputeEx) { - const auto& run = [=](RunContext rctx, - engine::CallbackOnComplete on_complete) { + const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; REDEFINE_INPUTS_OUTPUTS(inputs, outputs, inputsA, outputsA); - INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", - outputsA, req); + INVALIDATE_OUTPUTS_COND( + exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", outputsA, req); CREATE_DEFAULT_INPUTS(exec_type != ExecType::kCrossDeviceCopy && op->name != "_CachedOp", - attrs, CreateDefaultInputs(&inputsA)); + attrs, + CreateDefaultInputs(&inputsA)); fcompute_ex(state, opctx, inputsA, req, outputsA); - if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync - && rctx.get_stream() && !rctx.is_bulk) { + if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync && + rctx.get_stream() && !rctx.is_bulk) { rctx.get_stream()->Wait(); } }; @@ -769,15 +790,17 @@ inline void PushOperator(const OpStatePtr& state, RunContext rctx{ctx, nullptr, nullptr, false}; run(rctx, engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync( - [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, - ctx, read_vars, write_vars, FnProperty::kNormal, 0, - op->name.c_str()); + Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, + ctx, + read_vars, + write_vars, + FnProperty::kNormal, + 0, + op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); - Engine::Get()->PushAsync(run, ctx, read_vars, write_vars, - FnProperty::kAsync, 0, - op->name.c_str()); + Engine::Get()->PushAsync( + run, ctx, read_vars, write_vars, FnProperty::kAsync, 0, op->name.c_str()); } } else { auto fcompute = common::GetFCompute(op, "FStatefulCompute", ctx); @@ -786,57 +809,68 @@ inline void PushOperator(const OpStatePtr& state, << "for stateful operator " << op->name; const auto& run = [=](RunContext rctx, engine::CallbackOnComplete on_complete) { - OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; - - std::vector input_blobs, output_blobs; - // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays - std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; - // mapping from index in input_blobs to index in pre_temp_dst - std::unordered_map in_temp_idx_map; - INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); - - std::vector tmp_req = req; - // populate input blobs and output blobs - SetupDefaultBlobsInOut(inputs, outputs, nullptr, nullptr, &tmp_req, - &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst, - &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx); - // setup contexts - const bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask; - // pre-fcompute fallback - CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); - fcompute(state, opctx, input_blobs, tmp_req, output_blobs); - // post-fcompute fallback, cast to original storage type, if necessary - CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); - if (is_gpu && exec_type == ExecType::kSync - && 
rctx.get_stream() && !rctx.is_bulk) { - rctx.get_stream()->Wait(); - } - DerefInputOutputRelease(inputs, outputs); - }; + OpContext opctx{need_grad, is_train, rctx, on_complete, requested}; + + std::vector input_blobs, output_blobs; + // pre-fcompute and post-fcompute storage fallback src NDArrays and dst NDArrays + std::vector pre_temp_src, pre_temp_dst, post_temp_dst, post_temp_src; + // mapping from index in input_blobs to index in pre_temp_dst + std::unordered_map in_temp_idx_map; + INVALIDATE_OUTPUTS_COND(exec_type != ExecType::kCrossDeviceCopy, outputs, req); + + std::vector tmp_req = req; + // populate input blobs and output blobs + SetupDefaultBlobsInOut(inputs, + outputs, + nullptr, + nullptr, + &tmp_req, + &input_blobs, + &output_blobs, + &pre_temp_src, + &pre_temp_dst, + &post_temp_src, + &post_temp_dst, + &in_temp_idx_map, + mutate_idx); + // setup contexts + const bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask; + // pre-fcompute fallback + CastNonDefaultStorage(pre_temp_src, pre_temp_dst, opctx, is_gpu); + fcompute(state, opctx, input_blobs, tmp_req, output_blobs); + // post-fcompute fallback, cast to original storage type, if necessary + CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu); + if (is_gpu && exec_type == ExecType::kSync && rctx.get_stream() && !rctx.is_bulk) { + rctx.get_stream()->Wait(); + } + DerefInputOutputRelease(inputs, outputs); + }; if (exec_type == ExecType::kSubgraphExec || CheckIfSkipEngine(attrs)) { RunContext rctx{ctx, nullptr, nullptr, false}; run(rctx, engine::CallbackOnComplete()); } else if (exec_type == ExecType::kSync) { - Engine::Get()->PushSync( - [=](RunContext rctx) { - run(rctx, engine::CallbackOnComplete()); - }, ctx, read_vars, write_vars, FnProperty::kNormal, - 0, op->name.c_str()); + Engine::Get()->PushSync([=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); }, + ctx, + read_vars, + write_vars, + FnProperty::kNormal, + 0, + op->name.c_str()); } else { CHECK(exec_type == ExecType::kAsync); Engine::Get()->PushAsync( - run, ctx, read_vars, write_vars, FnProperty::kAsync, - 0, op->name.c_str()); + run, ctx, read_vars, write_vars, FnProperty::kAsync, 0, op->name.c_str()); } } } -inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, +inline bool CheckAndInferShape(nnvm::Graph* p_g, + mxnet::ShapeVector&& shapes, bool use_inputs, - std::pair node_range = {0, 0}, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}, - bool *contain_unknown = nullptr) { + bool* contain_unknown = nullptr) { using namespace nnvm; if (contain_unknown != nullptr) { *contain_unknown = false; @@ -852,13 +886,16 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, for (size_t i = 0; i < shapes.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= shapes.size()) break; + if (i >= shapes.size()) + break; } - if (shapes[i] == prev_shapes[i]) continue; + if (shapes[i] == prev_shapes[i]) + continue; match = false; break; } - if (match) return true; + if (match) + return true; } } g.attrs.erase("shape"); @@ -870,7 +907,7 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes, g = exec::InferShape(std::move(g), std::move(shapes)); } else { g.attrs["shape"] = std::make_shared(std::move(shapes)); - g = exec::InferShape(std::move(g)); + g = exec::InferShape(std::move(g)); } if (contain_unknown == nullptr) { CHECK_EQ(g.GetAttr("shape_num_unknown_nodes"), 0U); @@ -880,16 +917,16 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, 
mxnet::ShapeVector&& shapes, return false; } - -inline bool CheckAndInferType(nnvm::Graph* p_g, nnvm::DTypeVector&& dtypes, +inline bool CheckAndInferType(nnvm::Graph* p_g, + nnvm::DTypeVector&& dtypes, bool use_inputs, - std::pair node_range = {0, 0}, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}) { using namespace nnvm; nnvm::Graph& g = *p_g; if (use_inputs) { - if (g.attrs.count("dtype_inputs") && - g.GetAttr("dtype_inputs") == dtypes) return true; + if (g.attrs.count("dtype_inputs") && g.GetAttr("dtype_inputs") == dtypes) + return true; } else if (g.attrs.count("dtype")) { const auto& prev_dtypes = g.GetAttr("dtype"); CHECK_EQ(prev_dtypes.size(), dtypes.size()); @@ -897,13 +934,16 @@ inline bool CheckAndInferType(nnvm::Graph* p_g, nnvm::DTypeVector&& dtypes, for (size_t i = 0; i < dtypes.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= dtypes.size()) break; + if (i >= dtypes.size()) + break; } - if (dtypes[i] == prev_dtypes[i]) continue; + if (dtypes[i] == prev_dtypes[i]) + continue; match = false; break; } - if (match) return true; + if (match) + return true; } g.attrs.erase("dtype"); g.attrs.erase("dtype_inputs"); @@ -917,28 +957,31 @@ inline bool CheckAndInferType(nnvm::Graph* p_g, nnvm::DTypeVector&& dtypes, g = exec::InferType(std::move(g), std::move(dtypes)); } else { g.attrs["dtype"] = std::make_shared(std::move(dtypes)); - g = exec::InferType(std::move(g)); + g = exec::InferType(std::move(g)); } CHECK_EQ(g.GetAttr("dtype_num_unknown_nodes"), 0U); return false; } -inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev_mask, - StorageTypeVector&& storage_types, bool use_inputs, - std::pair node_range = {0, 0}, +inline bool CheckAndInferStorageType(nnvm::Graph* p_g, + exec::DevMaskVector&& dev_mask, + StorageTypeVector&& storage_types, + bool use_inputs, + std::pair node_range = {0, 0}, std::pair entry_range = {0, 0}) { using namespace nnvm; nnvm::Graph& g = *p_g; - bool dev_match = g.attrs.count("dev_mask") && - g.GetAttr("dev_mask") == dev_mask; + bool dev_match = + g.attrs.count("dev_mask") && g.GetAttr("dev_mask") == dev_mask; if (!dev_match) { g.attrs["dev_mask"] = std::make_shared(std::move(dev_mask)); } if (dev_match && use_inputs) { if (g.attrs.count("storage_type_inputs") && - g.GetAttr("storage_type_inputs") == storage_types) return true; + g.GetAttr("storage_type_inputs") == storage_types) + return true; } else if (dev_match && g.attrs.count("storage_type")) { const auto& prev_storage_types = g.GetAttr("storage_type"); CHECK_EQ(prev_storage_types.size(), storage_types.size()); @@ -946,13 +989,16 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev for (size_t i = 0; i < storage_types.size(); ++i) { if (i == entry_range.first) { i = entry_range.second; - if (i >= storage_types.size()) break; + if (i >= storage_types.size()) + break; } - if (storage_types[i] == prev_storage_types[i]) continue; + if (storage_types[i] == prev_storage_types[i]) + continue; match = false; break; } - if (match) return true; + if (match) + return true; } g.attrs.erase("dispatch_mode"); g.attrs.erase("storage_type"); @@ -964,18 +1010,17 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev g = exec::InferStorageType(std::move(g), std::move(storage_types)); } else { g.attrs["storage_type"] = std::make_shared(std::move(storage_types)); - g = exec::InferStorageType(std::move(g)); + g = exec::InferStorageType(std::move(g)); } 
CHECK_EQ(g.GetAttr("storage_type_num_unknown_nodes"), 0U); return false; } - inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { static const auto& _copyto = Op::Get("_copyto"); - std::vector vctx( - idx.num_nodes(), Context::Create(static_cast(-1), 0)); + std::vector vctx(idx.num_nodes(), + Context::Create(static_cast(-1), 0)); // forward pass for (size_t i = 0; i < idx.num_nodes(); ++i) { if (!idx[i].source->info.empty()) { @@ -990,7 +1035,8 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { vctx[i] = vctx[idx[i].control_deps[0]]; } else { for (const auto& in : idx[i].inputs) { - if (vctx[in.node_id].dev_type == static_cast(-1)) continue; + if (vctx[in.node_id].dev_type == static_cast(-1)) + continue; vctx[i] = vctx[in.node_id]; break; } @@ -998,10 +1044,12 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { } // backward pass for (int i = idx.num_nodes() - 1; i >= 0; --i) { - if (vctx[i].dev_type == static_cast(-1)) continue; + if (vctx[i].dev_type == static_cast(-1)) + continue; if (idx[i].source->op() == _copyto) { auto in_nid = idx[i].inputs[0].node_id; - if (vctx[in_nid].dev_type != static_cast(-1)) continue; + if (vctx[in_nid].dev_type != static_cast(-1)) + continue; CHECK_GT(idx[i].source->control_deps.size(), 0); auto fwd_nid = idx.node_id(idx[i].source->control_deps[0].get()); CHECK_EQ(idx[fwd_nid].source->op(), _copyto); @@ -1009,7 +1057,8 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { continue; } for (const auto& j : idx[i].inputs) { - if (vctx[j.node_id].dev_type != static_cast(-1)) continue; + if (vctx[j.node_id].dev_type != static_cast(-1)) + continue; vctx[j.node_id] = vctx[i]; } } @@ -1024,32 +1073,31 @@ inline std::vector PlaceDevice(const nnvm::IndexedGraph& idx) { return vctx; } - -inline MemoryPlanVector MXPlanMemory( - nnvm::Graph* p_g, - nnvm::StorageVector&& storage, - const std::vector& ref_count, - const std::string& storage_plan, - const std::pair& node_range = {0, 0}, - const std::pair& entry_range = {0, 0}, - bool detect_inplace_addto = false) { +inline MemoryPlanVector MXPlanMemory(nnvm::Graph* p_g, + nnvm::StorageVector&& storage, + const std::vector& ref_count, + const std::string& storage_plan, + const std::pair& node_range = {0, 0}, + const std::pair& entry_range = {0, 0}, + bool detect_inplace_addto = false) { using namespace nnvm; - nnvm::Graph& g = *p_g; + nnvm::Graph& g = *p_g; const auto& idx = g.indexed_graph(); if (node_range.second > node_range.first) { g.attrs["node_range"] = std::make_shared(node_range); } g.attrs["ref_count"] = std::make_shared(ref_count); - g.attrs["storage"] = std::make_shared(std::move(storage)); - g = nnvm::ApplyPass(g, "MXPlanMemory"); - if (detect_inplace_addto) g = exec::DetectInplaceAddTo(g); + g.attrs["storage"] = std::make_shared(std::move(storage)); + g = nnvm::ApplyPass(g, "MXPlanMemory"); + if (detect_inplace_addto) + g = exec::DetectInplaceAddTo(g); - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); + const auto& dtypes = g.GetAttr("dtype"); + const auto& shapes = g.GetAttr("shape"); const auto& storage_inplace = g.GetAttr >("storage_inplace_index"); - g.attrs[storage_plan] = std::make_shared(storage_inplace); - const auto& storage_ids = g.GetAttr("storage_id"); - uint32_t entry_start = entry_range.first; + g.attrs[storage_plan] = std::make_shared(storage_inplace); + const auto& storage_ids = g.GetAttr("storage_id"); + uint32_t entry_start = entry_range.first; uint32_t entry_end = entry_range.second > entry_start ? 
entry_range.second : idx.num_node_entries(); MemoryPlanVector mem_plan(idx.num_node_entries()); @@ -1061,29 +1109,28 @@ inline MemoryPlanVector MXPlanMemory( } else if (!sid_to_root.count(storage_ids[i])) { CHECK_LT(storage_inplace[i], 0); sid_to_root[storage_ids[i]] = i; - mem_plan[i] = {storage_ids[i], i, - mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), - false}; + mem_plan[i] = { + storage_ids[i], i, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size(), false}; } else { uint32_t root = sid_to_root[storage_ids[i]]; - mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; - mem_plan[root].size = std::max(mem_plan[root].size, - mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); + mem_plan[i] = {storage_ids[i], root, 0, storage_inplace[i] >= 0}; + mem_plan[root].size = + std::max(mem_plan[root].size, mshadow::mshadow_sizeof(dtypes[i]) * shapes[i].Size()); } } return mem_plan; } - inline std::multimap AllocateMemory( const nnvm::Graph& g, const nnvm::IndexedGraph& idx, const Context& default_ctx, - const uint32_t entry_start, const uint32_t entry_end, + const uint32_t entry_start, + const uint32_t entry_end, const MemoryPlanVector& mem_plan, const std::vector& arrays, - std::vector *array_reqs, + std::vector* array_reqs, std::multimap&& pool = std::multimap()) { using namespace nnvm; const auto& dtypes = g.GetAttr("dtype"); @@ -1102,18 +1149,19 @@ inline std::multimap AllocateMemory( continue; } data_entry_profiler_scopes[eid - entry_start] = profiler_scope; - data_entry_names[eid - entry_start] = idx[nid].source->attrs.name; + data_entry_names[eid - entry_start] = idx[nid].source->attrs.name; } } - const NDArray *pntr; + const NDArray* pntr; for (uint32_t i = entry_start; i < entry_end; ++i) { - const auto &plan = mem_plan[i]; - if (plan.storage_id == exec::kExternalStorageID) continue; + const auto& plan = mem_plan[i]; + if (plan.storage_id == exec::kExternalStorageID) + continue; CHECK(arrays[i]->is_none()); if (plan.storage_id == exec::kDynamicStorageID) { - *arrays[i] = NDArray(static_cast(stypes[i]), - shapes[i], default_ctx, true, dtypes[i]); + *arrays[i] = NDArray( + static_cast(stypes[i]), shapes[i], default_ctx, true, dtypes[i]); arrays[i]->AssignStorageInfo(data_entry_profiler_scopes[i - entry_start], data_entry_names[i - entry_start]); continue; @@ -1126,7 +1174,9 @@ inline std::multimap AllocateMemory( pool.erase(iter); } else { NDArray buff(mxnet::TShape({static_cast(plan.size)}), - default_ctx, true, mshadow::kUint8); + default_ctx, + true, + mshadow::kUint8); buff.AssignStorageInfo(data_entry_profiler_scopes[i - entry_start], data_entry_names[i - entry_start]); pntr = &new_pool.insert({plan.size, buff})->second; @@ -1143,13 +1193,12 @@ inline std::multimap AllocateMemory( return new_pool; } -inline void SetupOpExec( - const nnvm::Graph& g, - size_t nid, - const std::shared_ptr& exec, - const std::vector arrays, - const std::vector array_reqs) { - const auto& idx = g.indexed_graph(); +inline void SetupOpExec(const nnvm::Graph& g, + size_t nid, + const std::shared_ptr& exec, + const std::vector arrays, + const std::vector array_reqs) { + const auto& idx = g.indexed_graph(); const auto& inode = idx[nid]; CHECK_EQ(exec->in_array.size(), 0U); CHECK_EQ(exec->out_array.size(), 0U); @@ -1195,24 +1244,25 @@ inline Engine::OprHandle CreateEngineOp( // dedup vars Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars); - bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask; + bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask; bool is_async = 
execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync; - auto exec_fun = [execs, is_async, is_gpu] ( - RunContext ctx, Engine::CallbackOnComplete on_complete) { + auto exec_fun = [execs, is_async, is_gpu](RunContext ctx, + Engine::CallbackOnComplete on_complete) { if (is_async) { execs[0]->op_ctx.async_on_complete = on_complete; } - for (const auto& exec : execs) exec->Run(ctx, is_gpu); + for (const auto& exec : execs) + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { - #if MXNET_USE_CUDA +#if MXNET_USE_CUDA // Wait GPU kernel to finish. ctx.get_stream()->Wait(); - #else +#else LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - #endif +#endif } on_complete(); } @@ -1222,26 +1272,27 @@ inline Engine::OprHandle CreateEngineOp( exec_fun, use_vars, mutate_vars, FnProperty::kNormal, opr_names); } -inline void CreateEngineOpSeg( - const nnvm::IndexedGraph& idx, - const Context default_ctx, - const size_t start_nid, - const size_t end_nid, - const size_t bulk_size, - const std::vector >& execs, - const std::vector skip_plus_node, - std::vector *opr_segs) { +inline void CreateEngineOpSeg(const nnvm::IndexedGraph& idx, + const Context default_ctx, + const size_t start_nid, + const size_t end_nid, + const size_t bulk_size, + const std::vector >& execs, + const std::vector skip_plus_node, + std::vector* opr_segs) { size_t seg_start = start_nid; std::vector > seg_execs; std::string opr_names; for (size_t nid = start_nid; nid < end_nid; ++nid) { const auto& node = idx[nid]; - if (node.source->is_variable()) continue; - if (skip_plus_node.size() && skip_plus_node[nid]) continue; - auto& exec = execs[nid]; - const auto &op_name = node.source->op()->name; - bool is_async = exec->exec_type() != ExecType::kSync; - bool valid = exec->out_array.size() > 0; + if (node.source->is_variable()) + continue; + if (skip_plus_node.size() && skip_plus_node[nid]) + continue; + auto& exec = execs[nid]; + const auto& op_name = node.source->op()->name; + bool is_async = exec->exec_type() != ExecType::kSync; + bool valid = exec->out_array.size() > 0; // Stop at async nodes and invalid node (due to input/output is not allocated) bool stop = is_async || !valid || seg_execs.size() >= bulk_size; @@ -1261,7 +1312,8 @@ inline void CreateEngineOpSeg( } seg_execs.push_back(exec); - if (opr_names.size()) opr_names += ","; + if (opr_names.size()) + opr_names += ","; opr_names += op_name; auto& seg = (*opr_segs)[nid]; @@ -1290,34 +1342,35 @@ inline void CreateEngineOpSeg( } } - void RunGraph(const bool retain_graph, const nnvm::IndexedGraph& idx, const std::vector& arrays, - size_t node_start, size_t node_end, + size_t node_start, + size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes, + std::vector* p_states, + const DispatchModeVector& dispatch_modes, bool recording, - mxnet::ShapeVector *shapes = nullptr, + mxnet::ShapeVector* shapes = nullptr, const CachedOpMonCallback& callback = nullptr, - const bool monitor_all_ = false); + const bool monitor_all_ = false); void NaiveRunGraph(const bool retain_graph, const Context& default_ctx, const nnvm::IndexedGraph& idx, const std::vector& arrays, - size_t node_start, size_t node_end, + size_t node_start, + size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, - std::vector *p_states, - const DispatchModeVector &dispatch_modes, + std::vector* p_states, + const DispatchModeVector& dispatch_modes, bool recording, - mxnet::ShapeVector 
*shapes,
+                   mxnet::ShapeVector* shapes,
                    const CachedOpMonCallback& callback = nullptr,
-                   const bool monitor_all_ = false,
-                   const bool skip_engine = false);
+                   const bool monitor_all_ = false,
+                   const bool skip_engine = false);
 
 }  // namespace imperative
 }  // namespace mxnet
diff --git a/src/imperative/infer_graph_attr_pass.cc b/src/imperative/infer_graph_attr_pass.cc
index d5d969618f87..13d0e07c174c 100644
--- a/src/imperative/infer_graph_attr_pass.cc
+++ b/src/imperative/infer_graph_attr_pass.cc
@@ -32,7 +32,7 @@
 namespace mxnet {
 namespace exec {
 
-template
+template 
 bool ApplyOpInferAttr(const nnvm::Graph& g,
                       const FInfer& finfer,
                       const NodeAttrs& attrs,
@@ -43,7 +43,7 @@ bool ApplyOpInferAttr(const nnvm::Graph& g,
   return finfer(attrs, in_attrs, out_attrs);
 }
 
-template<>
+template <>
 bool ApplyOpInferAttr(const nnvm::Graph& g,
                       const FInferStorageType& finfer,
                       const NodeAttrs& attrs,
@@ -63,18 +63,17 @@ bool ApplyOpInferAttr(const nnvm::Graph& g,
   return true;
 }
 
-template
+template 
 inline void GetAttrFromForwardNode(const uint32_t nid,
-                                   const nnvm::IndexedGraph &idx,
+                                   const nnvm::IndexedGraph& idx,
                                    std::vector* rshape_ptr,
                                    std::vector* inference_finished,
                                    IsNone fis_none) {
-  std::vector& rshape = *rshape_ptr;
+  std::vector& rshape = *rshape_ptr;
   const nnvm::IndexedGraph::Node& inode = idx[nid];
   // gradient function, used to get node correspondence.
-  static auto& fgrad =
-      Op::GetAttr("FGradient");
-  nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
+  static auto& fgrad = Op::GetAttr("FGradient");
+  nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0];
   const nnvm::IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
   // use gradient function to find out the correspondence.
   std::vector ograd(fwd_ptr->num_outputs());
@@ -83,8 +82,8 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
   }
   // input gradient list
   const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
-  const nnvm::Node* igrad_node = nullptr;
-  bool all_attrs_known = true;
+  const nnvm::Node* igrad_node = nullptr;
+  bool all_attrs_known = true;
   // Input gradient assignement
   for (size_t i = 0; i < igrad.size(); ++i) {
     if (igrad[i].node->op() == inode.source->op()) {
@@ -110,8 +109,8 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
     }
   }
   // out grad entries
-  CHECK(igrad_node != nullptr)
-      << "Cannot find matching backward op for " << inode.source->attrs.name;
+  CHECK(igrad_node != nullptr) << "Cannot find matching backward op for "
+                               << inode.source->attrs.name;
   for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
     const nnvm::NodeEntry& e = igrad_node->inputs[i];
     if (e.node == nullptr) {
@@ -128,7 +127,7 @@ inline void GetAttrFromForwardNode(const uint32_t nid,
   (*inference_finished)[nid] = all_attrs_known;
 }
 
-template
+template 
 void GetAttrFromFusedNode(uint32_t nid,
                           const nnvm::IndexedGraph& idx,
                           std::vector* rshape_ptr,
@@ -136,20 +135,18 @@ void GetAttrFromFusedNode(uint32_t nid,
                           IsNone fis_none,
                           const std::string& infer_fusion_name) {
   std::vector& rshape = *rshape_ptr;
-  const auto& inode = idx[nid];
+  const auto& inode = idx[nid];
   // gradient function, used to get node correspondence.
- static auto& fgrad = - Op::GetAttr("FGradient"); - nnvm::ObjectPtr fused_fwd_ptr = inode.source->control_deps[0]; - static auto& finfer_fused_shape = - Op::GetAttr(infer_fusion_name); - auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr); - CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name << - " is marked as Fusion but does not allow accessing attributes"; + static auto& fgrad = Op::GetAttr("FGradient"); + nnvm::ObjectPtr fused_fwd_ptr = inode.source->control_deps[0]; + static auto& finfer_fused_shape = Op::GetAttr(infer_fusion_name); + auto finfer = finfer_fused_shape.get(fused_fwd_ptr->op(), nullptr); + CHECK(finfer != nullptr) << "Operator " << fused_fwd_ptr->attrs.name + << " is marked as Fusion but does not allow accessing attributes"; const auto& inferred_attrs = finfer(fused_fwd_ptr->attrs); - const auto& fwd_ptr = std::get<0>(inferred_attrs); - const auto& input_attrs = std::get<1>(inferred_attrs); - const auto& output_attrs = std::get<2>(inferred_attrs); + const auto& fwd_ptr = std::get<0>(inferred_attrs); + const auto& input_attrs = std::get<1>(inferred_attrs); + const auto& output_attrs = std::get<2>(inferred_attrs); // use gradient function to find out the correspondence. std::vector ograd(fwd_ptr->num_outputs()); @@ -158,8 +155,8 @@ void GetAttrFromFusedNode(uint32_t nid, } // input gradient list const std::vector& igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); - const nnvm::Node* igrad_node = nullptr; - bool all_attrs_known = true; + const nnvm::Node* igrad_node = nullptr; + bool all_attrs_known = true; // Set the attributes of output gradients // using attributes of forward node inputs for (size_t i = 0; i < igrad.size(); ++i) { @@ -188,8 +185,8 @@ void GetAttrFromFusedNode(uint32_t nid, // Set the attributes of input gradients // using attributes of forward node outputs - CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; + CHECK(igrad_node != nullptr) << "Cannot find matching backward op for " + << inode.source->attrs.name; for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { const nnvm::NodeEntry& e = igrad_node->inputs[i]; if (e.node == nullptr) { @@ -217,9 +214,9 @@ void ProvideAttrToFusion(const uint32_t nid, for (const auto& dep_node : inode.source->control_deps) { in_attrs.push_back({}); out_attrs.push_back({}); - auto ¤t_in_attrs = in_attrs.back(); - auto ¤t_out_attrs = out_attrs.back(); - uint32_t dep_node_id = idx.node_id(dep_node.get()); + auto& current_in_attrs = in_attrs.back(); + auto& current_out_attrs = out_attrs.back(); + uint32_t dep_node_id = idx.node_id(dep_node.get()); for (const auto& e : idx[dep_node_id].inputs) { current_in_attrs.push_back(rshape[idx.entry_id(e)]); } @@ -228,10 +225,10 @@ void ProvideAttrToFusion(const uint32_t nid, } } auto provide = - Op::GetAttr(provide_fusion_name).get(inode.source->op(), nullptr); - CHECK(provide != nullptr) << - "Encountered Fusion operator that does not implement providing subgraph attr " << - provide_fusion_name << "."; + Op::GetAttr(provide_fusion_name).get(inode.source->op(), nullptr); + CHECK(provide != nullptr) + << "Encountered Fusion operator that does not implement providing subgraph attr " + << provide_fusion_name << "."; provide(inode.source->attrs, inode.source->control_deps, in_attrs, out_attrs); } @@ -263,9 +260,13 @@ void ProvideAttrToFusion(const uint32_t nid, * \param default_mode_val default value of the dispatch mode attribute on the node. 
Used * for storage type inference */ -template -nnvm::Graph InferAttr(nnvm::Graph &&ret, +template +nnvm::Graph InferAttr(nnvm::Graph&& ret, const AttrType empty_val, const char* infer_name, const char* infer_fusion_name, @@ -281,15 +282,13 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, const DispatchMode default_mode_val = DispatchMode::kUndefined) { using nnvm::IndexedGraph; using nnvm::Op; - using AttrVector = std::vector; + using AttrVector = std::vector; using NodeAttrVector = std::vector; using dmlc::any; - const IndexedGraph& idx = ret.indexed_graph(); - static auto& finfer_shape = - Op::GetAttr(infer_name); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = Op::GetAttr(infer_name); + static auto& is_backward = Op::GetAttr("TIsBackward"); // reshape shape vector AttrVector rshape; // vector holding information which operators @@ -316,7 +315,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, std::string shape_hints_key = std::string(attr_name) + "_hints"; if (ret.attrs.count(shape_hints_key)) { nnvm::NodeEntryMap shape_hints = - ret.GetAttr>(shape_hints_key); + ret.GetAttr>(shape_hints_key); for (const auto& kv : shape_hints) { nnvm::NodeEntry e = kv.first; if (idx.exist(e.node.get())) { @@ -335,18 +334,18 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, // limit inference to part of the graph uint32_t node_start = 0, node_end = idx.num_nodes(); if (ret.attrs.count("node_range")) { - const auto& range = ret.GetAttr >("node_range"); - node_start = range.first; - node_end = range.second; + const auto& range = ret.GetAttr>("node_range"); + node_start = range.first; + node_end = range.second; CHECK_GE(node_start, 0); CHECK_LE(node_end, idx.num_nodes()); ret.attrs.erase("node_range"); } uint32_t entry_start = 0, entry_end = idx.num_node_entries(); if (ret.attrs.count("entry_range")) { - const auto& range = ret.GetAttr >("entry_range"); - entry_start = range.first; - entry_end = range.second; + const auto& range = ret.GetAttr>("entry_range"); + entry_start = range.first; + entry_end = range.second; CHECK_GE(entry_start, 0); CHECK_LE(entry_end, idx.num_node_entries()); ret.attrs.erase("entry_range"); @@ -365,9 +364,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, // inference step function for nid auto infer_step = [&](uint32_t nid, bool last_iter) { - if (inference_finished[nid]) return; - const auto& inode = idx[nid]; - const uint32_t num_inputs = inode.inputs.size(); + if (inference_finished[nid]) + return; + const auto& inode = idx[nid]; + const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_outputs = inode.source->num_outputs(); if (inode.source->is_variable()) { // Variable node. No operator. Only one output entry. 
@@ -388,12 +388,12 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, if (dispatch_mode_name != nullptr) { op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); } - } else if (is_backward.get(inode.source->op(), false) && - inode.source->control_deps.size() && bwd_identity_assign) { + } else if (is_backward.get(inode.source->op(), false) && inode.source->control_deps.size() && + bwd_identity_assign) { CHECK(dispatch_mode_name == nullptr) - << "Backward inference for node attributes is not available"; + << "Backward inference for node attributes is not available"; CHECK_GE(inode.source->control_deps.size(), 1U) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; @@ -401,8 +401,8 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, if (!is_fusion_helper.get(fwd_ptr->op(), false)) { GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none); } else { - GetAttrFromFusedNode(nid, idx, &rshape, &inference_finished, - fis_none, infer_fusion_name); + GetAttrFromFusedNode( + nid, idx, &rshape, &inference_finished, fis_none, infer_fusion_name); } } else { DispatchMode* dispatch_mode = nullptr; @@ -426,14 +426,15 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, if (is_fusion.get(inode.source->op(), false)) { ProvideAttrToFusion(nid, idx, rshape, provide_fusion_name); } - ApplyOpInferAttr(ret, finfer, inode.source->attrs, - nid, &ishape, &oshape, dispatch_mode); + ApplyOpInferAttr(ret, finfer, inode.source->attrs, nid, &ishape, &oshape, dispatch_mode); bool finished = true; for (const auto& attr : ishape) { - if (fis_none(attr)) finished = false; + if (fis_none(attr)) + finished = false; } for (const auto& attr : oshape) { - if (fis_none(attr)) finished = false; + if (fis_none(attr)) + finished = false; } inference_finished[nid] = finished; } catch (const std::exception& e) { @@ -455,10 +456,9 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, } inference_finished[nid] = all_attrs_known; if (!all_attrs_known) { - CHECK(!last_iter) - << "Attribute " << infer_name - << " is not registered by op " << inode.source->op()->name - << ". We are not able to complete the inference because of this"; + CHECK(!last_iter) << "Attribute " << infer_name << " is not registered by op " + << inode.source->op()->name + << ". We are not able to complete the inference because of this"; } } // Save to the result map. @@ -473,11 +473,11 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, size_t last_num_unknown; size_t num_unknown_dispatch_mode = dispatch_mode_name ? 
node_end - node_start : 0; - size_t num_unknown_entry_attr = entry_end - entry_start; - size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode; - bool last_iter = false; - bool do_next_iteration = true; - int i = 0; + size_t num_unknown_entry_attr = entry_end - entry_start; + size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode; + bool last_iter = false; + bool do_next_iteration = true; + int i = 0; do { if (i % 2 == 0) { for (uint32_t nid = node_start; nid < node_end; ++nid) { @@ -490,7 +490,7 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, } } last_num_unknown = num_unknown; - num_unknown = 0; + num_unknown = 0; for (size_t j = entry_start; j < entry_end; ++j) { if (fis_none(rshape[j])) { ++num_unknown; @@ -498,7 +498,8 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, } if (dispatch_mode_name) { for (size_t i = node_start; i < node_end; i++) { - if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown; + if (dispatch_modes[i] == DispatchMode::kUndefined) + ++num_unknown; } } do_next_iteration = num_unknown > 0 && last_num_unknown > num_unknown; @@ -549,8 +550,8 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, * \param default_mode_val default value of the dispatch mode attribute on the node. Used * for storage type inference */ -template -nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, +template +nnvm::Graph InferShapeAttr(nnvm::Graph&& ret, const mxnet::TShape empty_val, const char* infer_name, const char* input_name, @@ -565,16 +566,14 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, const DispatchMode default_mode_val = DispatchMode::kUndefined) { using nnvm::IndexedGraph; using nnvm::Op; - using AttrType = mxnet::TShape; - using FInferType = mxnet::FInferShape; - using AttrVector = std::vector; + using AttrType = mxnet::TShape; + using FInferType = mxnet::FInferShape; + using AttrVector = std::vector; using NodeAttrVector = std::vector; using dmlc::any; - const IndexedGraph& idx = ret.indexed_graph(); - static auto& finfer_shape = - Op::GetAttr(infer_name); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = Op::GetAttr(infer_name); + static auto& is_backward = Op::GetAttr("TIsBackward"); // reshape shape vector AttrVector rshape; // vector holding information which operators @@ -601,7 +600,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, std::string shape_hints_key = std::string(attr_name) + "_hints"; if (ret.attrs.count(shape_hints_key)) { nnvm::NodeEntryMap shape_hints = - ret.GetAttr>(shape_hints_key); + ret.GetAttr>(shape_hints_key); for (const auto& kv : shape_hints) { nnvm::NodeEntry e = kv.first; if (idx.exist(e.node.get())) { @@ -620,18 +619,18 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, // limit inference to part of the graph uint32_t node_start = 0, node_end = idx.num_nodes(); if (ret.attrs.count("node_range")) { - const auto& range = ret.GetAttr >("node_range"); - node_start = range.first; - node_end = range.second; + const auto& range = ret.GetAttr>("node_range"); + node_start = range.first; + node_end = range.second; CHECK_GE(node_start, 0); CHECK_LE(node_end, idx.num_nodes()); ret.attrs.erase("node_range"); } uint32_t entry_start = 0, entry_end = idx.num_node_entries(); if (ret.attrs.count("entry_range")) { - const auto& range = ret.GetAttr >("entry_range"); - entry_start = range.first; - entry_end = range.second; + const auto& range = ret.GetAttr>("entry_range"); + entry_start = range.first; + entry_end = range.second; CHECK_GE(entry_start, 0); 
CHECK_LE(entry_end, idx.num_node_entries()); ret.attrs.erase("entry_range"); @@ -657,10 +656,11 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, // inference step function for nid auto infer_step = [&](uint32_t nid, bool last_iter) { - if (inference_finished[nid]) return; - const auto& inode = idx[nid]; - const std::string name = inode.source->attrs.name; - const uint32_t num_inputs = inode.inputs.size(); + if (inference_finished[nid]) + return; + const auto& inode = idx[nid]; + const std::string name = inode.source->attrs.name; + const uint32_t num_inputs = inode.inputs.size(); const uint32_t num_outputs = inode.source->num_outputs(); if (inode.source->is_variable()) { @@ -685,12 +685,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, if (dispatch_mode_name != nullptr) { op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val); } - } else if (is_backward.get(inode.source->op(), false) && - inode.source->control_deps.size() && bwd_identity_assign) { + } else if (is_backward.get(inode.source->op(), false) && inode.source->control_deps.size() && + bwd_identity_assign) { CHECK(dispatch_mode_name == nullptr) - << "Backward inference for node attributes is not available"; + << "Backward inference for node attributes is not available"; CHECK_GE(inode.source->control_deps.size(), 1U) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; nnvm::ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; @@ -698,10 +698,8 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, if (!is_fusion_helper.get(fwd_ptr->op(), false)) { GetAttrFromForwardNode(nid, idx, &rshape, &inference_finished, fis_none); } else { - GetAttrFromFusedNode(nid, idx, &rshape, - &inference_finished, - fis_none, - "FAccessSubgraphShape"); + GetAttrFromFusedNode( + nid, idx, &rshape, &inference_finished, fis_none, "FAccessSubgraphShape"); } } else { DispatchMode* dispatch_mode = nullptr; @@ -734,17 +732,18 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, try { static auto& is_fusion = Op::GetAttr("TIsFusion"); if (is_fusion.get(inode.source->op(), false)) { - ProvideAttrToFusion(nid, idx, rshape, - "FProvideSubgraphShape"); + ProvideAttrToFusion( + nid, idx, rshape, "FProvideSubgraphShape"); } - ApplyOpInferAttr(ret, finfer, inode.source->attrs, - nid, &ishape, &oshape, dispatch_mode); + ApplyOpInferAttr(ret, finfer, inode.source->attrs, nid, &ishape, &oshape, dispatch_mode); bool finished = true; for (const auto& attr : ishape) { - if (fis_none(attr)) finished = false; + if (fis_none(attr)) + finished = false; } for (const auto& attr : oshape) { - if (fis_none(attr)) finished = false; + if (fis_none(attr)) + finished = false; } inference_finished[nid] = finished; } catch (const std::exception& e) { @@ -762,8 +761,8 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, }; size_t last_num_unknown; - size_t num_unknown = static_cast(-1); // Infinity - bool last_iter = false; + size_t num_unknown = static_cast(-1); // Infinity + bool last_iter = false; bool do_next_iteration = true; int i = 0; @@ -780,7 +779,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret, } } last_num_unknown = num_unknown; - num_unknown = 0; + num_unknown = 0; for (size_t j = entry_start; j < entry_end; ++j) { if (fis_none(rshape[j])) { num_unknown += fnum_unknown(rshape[j]); @@ -827,9 +826,13 @@ nnvm::Graph InferShape(nnvm::Graph&& graph, graph.attrs["shape_attr_key"] = std::make_shared(shape_attr_key); } return InferShapeAttr( - 
std::move(graph), mxnet::TShape(), - "FInferShape", "shape_inputs", "shape_attr_key", - "shape", "shape_num_unknown_nodes", + std::move(graph), + mxnet::TShape(), + "FInferShape", + "shape_inputs", + "shape_attr_key", + "shape", + "shape_num_unknown_nodes", [](const mxnet::TShape& s) { return !mxnet::shape_is_known(s); }, [](const mxnet::TShape& s) { if (!mxnet::ndim_is_known(s)) { @@ -843,7 +846,9 @@ nnvm::Graph InferShape(nnvm::Graph&& graph, } return ret; }, - nullptr, true, nullptr); + nullptr, + true, + nullptr); } nnvm::Graph InferType(nnvm::Graph&& graph, @@ -856,13 +861,20 @@ nnvm::Graph InferType(nnvm::Graph&& graph, if (dtype_attr_key.length() != 0) { graph.attrs["dtype_attr_key"] = std::make_shared(dtype_attr_key); } - return InferAttr( - std::move(graph), -1, - "FInferType", "FAccessSubgraphType", "FProvideSubgraphType", - "dtype_inputs", "dtype_attr_key", "dtype", "dtype_num_unknown_nodes", + return InferAttr( + std::move(graph), + -1, + "FInferType", + "FAccessSubgraphType", + "FProvideSubgraphType", + "dtype_inputs", + "dtype_attr_key", + "dtype", + "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, - common::SameType, true, nullptr); + common::SameType, + true, + nullptr); } nnvm::Graph InferStorageType(nnvm::Graph&& graph, @@ -885,19 +897,30 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph, CHECK_GT(graph.attrs.count("context"), 0); DevMaskVector dev_masks(graph.indexed_graph().num_nodes()); const ContextVector& vctx = graph.GetAttr("context"); - for (size_t i = 0; i < vctx.size(); i++) dev_masks[i] = vctx[i].dev_mask(); + for (size_t i = 0; i < vctx.size(); i++) + dev_masks[i] = vctx[i].dev_mask(); graph.attrs["dev_mask"] = std::make_shared(std::move(dev_masks)); } // for storage type, the backward attr is not necessarily the same as it's correspondence - nnvm::Graph ret = InferAttr( - std::move(graph), -1, - "FInferStorageType", "FAccessSubgraphStorageType", "FProvideSubgraphStorageType", - "storage_type_inputs", "storage_type_attr_key", "storage_type", + std::move(graph), + -1, + "FInferStorageType", + "FAccessSubgraphStorageType", + "FProvideSubgraphStorageType", + "storage_type_inputs", + "storage_type_attr_key", + "storage_type", "storage_type_num_unknown_nodes", [](const int t) { return t == -1; }, - common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable); + common::DefaultStorageType, + false, + "dispatch_mode", + DispatchMode::kVariable); // log the storage types and dispatch modes of the graph static bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false); diff --git a/src/imperative/inplace_addto_detect_pass.cc b/src/imperative/inplace_addto_detect_pass.cc index 4af2dcd66306..a087b210849c 100644 --- a/src/imperative/inplace_addto_detect_pass.cc +++ b/src/imperative/inplace_addto_detect_pass.cc @@ -33,12 +33,11 @@ namespace mxnet { namespace exec { Graph DetectInplaceAddTo(Graph g) { - nnvm::StorageVector storage_id = - g.MoveCopyAttr("storage_id"); + nnvm::StorageVector storage_id = g.MoveCopyAttr("storage_id"); std::vector storage_inplace_index = g.MoveCopyAttr >("storage_inplace_index"); static const Op* ewise_plus_op = Op::Get("_grad_add"); - auto& idx = g.indexed_graph(); + auto& idx = g.indexed_graph(); // reference cont. 
  std::vector ref_count(idx.num_node_entries(), 0);
   std::vector addto_entry(idx.num_node_entries(), 0);
@@ -48,35 +47,41 @@ Graph DetectInplaceAddTo(Graph g) {
     ++ref_count[idx.entry_id(e)];
   }
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
-    for (auto &e : idx[nid].inputs) {
+    for (auto& e : idx[nid].inputs) {
       ++ref_count[idx.entry_id(e)];
     }
   }
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
-    if (inode.source->op() != ewise_plus_op) continue;
+    if (inode.source->op() != ewise_plus_op)
+      continue;
     int sid = storage_id[idx.entry_id(inode.inputs[0])];
-    if (sid != storage_id[idx.entry_id(nid, 0)]) continue;
-    if (idx[inode.inputs[0].node_id].source->is_variable()) continue;
-    if (idx[inode.inputs[1].node_id].source->is_variable()) continue;
-    uint32_t eid_rhs = idx.entry_id(inode.inputs[1]);
-    if (ref_count[eid_rhs] != 1) continue;
-    if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue;
+    if (sid != storage_id[idx.entry_id(nid, 0)])
+      continue;
+    if (idx[inode.inputs[0].node_id].source->is_variable())
+      continue;
+    if (idx[inode.inputs[1].node_id].source->is_variable())
+      continue;
+    uint32_t eid_rhs = idx.entry_id(inode.inputs[1]);
+    if (ref_count[eid_rhs] != 1)
+      continue;
+    if (inode.inputs[0].node_id >= inode.inputs[1].node_id)
+      continue;
     // TODO(haibin) support inplace addto for Dynamic Storage
-    if (storage_id[eid_rhs] == kDynamicStorageID) continue;
+    if (storage_id[eid_rhs] == kDynamicStorageID)
+      continue;
     CHECK_NE(storage_id[eid_rhs], sid);
-    storage_id[eid_rhs] = sid;
-    addto_entry[eid_rhs] = 1;
+    storage_id[eid_rhs] = sid;
+    addto_entry[eid_rhs] = 1;
     storage_inplace_index[eid_rhs] = -1;
-    skip_plus_node[nid] = 1;
+    skip_plus_node[nid] = 1;
   }
-  g.attrs["storage_id"] = std::make_shared(std::move(storage_id));
-  g.attrs["storage_inplace_index"] = std::make_shared(
-      std::move(storage_inplace_index));
-  g.attrs["addto_entry"] = std::make_shared(std::move(addto_entry));
-  g.attrs["skip_plus_node"] = std::make_shared(std::move(skip_plus_node));
+  g.attrs["storage_id"] = std::make_shared(std::move(storage_id));
+  g.attrs["storage_inplace_index"] = std::make_shared(std::move(storage_inplace_index));
+  g.attrs["addto_entry"] = std::make_shared(std::move(addto_entry));
+  g.attrs["skip_plus_node"] = std::make_shared(std::move(skip_plus_node));
   return g;
 }
diff --git a/src/imperative/naive_cached_op.cc b/src/imperative/naive_cached_op.cc
index 8ca0b09cd109..851501d427f5 100644
--- a/src/imperative/naive_cached_op.cc
+++ b/src/imperative/naive_cached_op.cc
@@ -25,28 +25,23 @@
 #include "../operator/operator_common.h"
 #include "../operator/subgraph/common.h"
 
-
 namespace mxnet {
-OpStatePtr NaiveCachedOp::Forward(
-    const std::shared_ptr& op_ptr,
-    const std::vector& inputs,
-    const std::vector& outputs,
-    const Context& default_ctx) {
-
+OpStatePtr NaiveCachedOp::Forward(const std::shared_ptr& op_ptr,
+                                  const std::vector& inputs,
+                                  const std::vector& outputs,
+                                  const Context& default_ctx) {
   CHECK_EQ(inputs.size(), num_inputs());
   {
     auto state_ptr = GetCachedOpState(default_ctx);
-    auto& state = state_ptr.get_state();
+    auto& state = state_ptr.get_state();
     const auto& idx = state.info.fwd_graph.indexed_graph();
     for (size_t i = 0; i < inputs.size(); ++i) {
       CHECK_EQ(inputs[i]->ctx(), default_ctx)
           << "CachedOp requires all inputs to live on the same context. 
But " - << idx[idx.input_nodes()[0]].source->attrs.name - << " is on " << default_ctx << " while " - << idx[idx.input_nodes()[i]].source->attrs.name - << " is on " << inputs[i]->ctx(); + << idx[idx.input_nodes()[0]].source->attrs.name << " is on " << default_ctx << " while " + << idx[idx.input_nodes()[i]].source->attrs.name << " is on " << inputs[i]->ctx(); } } @@ -54,20 +49,20 @@ OpStatePtr NaiveCachedOp::Forward( try { // Initialize bool recording = false; - op_state = OpStatePtr::Create(); - auto& runtime = op_state.get_state(); + op_state = OpStatePtr::Create(); + auto& runtime = op_state.get_state(); { auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto& state = state_ptr.get_state(); std::lock_guard lock(state.mutex); SetForwardGraph(default_ctx, &state.info, recording, inputs); runtime.info.fwd_graph = state.info.fwd_graph; runtime.info.input_map = state.info.input_map; } - nnvm::Graph& g = runtime.info.fwd_graph; + nnvm::Graph& g = runtime.info.fwd_graph; const auto& idx = g.indexed_graph(); - auto& buff = runtime.buff; - auto& states = runtime.op_states; + auto& buff = runtime.buff; + auto& states = runtime.op_states; // Allocate entries buff.resize(idx.num_node_entries()); @@ -78,22 +73,35 @@ OpStatePtr NaiveCachedOp::Forward( arrays.push_back(&buffered_array); } std::vector array_reqs(arrays.size(), kWriteTo); - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); const std::string& graph_type = recording ? FULL : FORWARD; std::vector ref_count = - g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); + g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; + if (ref_count[i] == 0) + array_reqs[i] = kNullOp; } CollectInputOutputNDRefs(g, inputs, runtime.info.input_map, outputs, &arrays); mxnet::ShapeVector shapes = g.GetAttr("shape"); - imperative::NaiveRunGraph(false, default_ctx, idx, arrays, 0, idx.num_nodes(), - std::move(array_reqs), std::move(ref_count), &states, - dispatch_modes, false, &shapes, nullptr, false, true); + imperative::NaiveRunGraph(false, + default_ctx, + idx, + arrays, + 0, + idx.num_nodes(), + std::move(array_reqs), + std::move(ref_count), + &states, + dispatch_modes, + false, + &shapes, + nullptr, + false, + true); { - auto state_ptr = GetCachedOpState(default_ctx); - auto& state = state_ptr.get_state(); + auto state_ptr = GetCachedOpState(default_ctx); + auto& state = state_ptr.get_state(); auto copied_shape = shapes; std::lock_guard lock(state.mutex); state.info.fwd_graph.attrs["shape"] = std::make_shared(std::move(copied_shape)); @@ -105,5 +113,4 @@ OpStatePtr NaiveCachedOp::Forward( return op_state; } - } // namespace mxnet diff --git a/src/imperative/naive_cached_op.h b/src/imperative/naive_cached_op.h index f762f0bcc92e..cd1365508f97 100644 --- a/src/imperative/naive_cached_op.h +++ b/src/imperative/naive_cached_op.h @@ -30,41 +30,36 @@ #include #include "./cached_op.h" - - namespace mxnet { /*! \brief NaiveCachedOp which does not involve engine which is useful when executed in parallel. It does not support advanced features of CachedOp, including backward/recording, etc... 
*/
 class NaiveCachedOp : public CachedOp {
  public:
-  NaiveCachedOp(
-      const nnvm::Symbol &sym,
-      const std::vector> &flags) : CachedOp(sym, flags) {}
+  NaiveCachedOp(const nnvm::Symbol& sym,
+                const std::vector>& flags)
+      : CachedOp(sym, flags) {}
   virtual ~NaiveCachedOp() {}
-  OpStatePtr Forward(
-      const std::shared_ptr& op_ptr,
-      const std::vector& inputs,
-      const std::vector& outputs,
-      const Context& default_ctx) override;
-  void Backward(
-      const bool retain_graph,
-      const OpStatePtr& state,
-      const std::vector& inputs,
-      const std::vector& reqs,
-      const std::vector& outputs) override {
-    LOG(FATAL) << "Backward is not supported in NaiveCachedOp.";
-  }
+  OpStatePtr Forward(const std::shared_ptr& op_ptr,
+                     const std::vector& inputs,
+                     const std::vector& outputs,
+                     const Context& default_ctx) override;
+  void Backward(const bool retain_graph,
+                const OpStatePtr& state,
+                const std::vector& inputs,
+                const std::vector& reqs,
+                const std::vector& outputs) override {
+    LOG(FATAL) << "Backward is not supported in NaiveCachedOp.";
+  }
   // backward storage type inference
-  bool BackwardStorageType(
-      const nnvm::NodeAttrs& attrs,
-      const int dev_mask,
-      DispatchMode* dispatch_mode,
-      std::vector *in_attrs,
-      std::vector *out_attrs) override {
-    LOG(FATAL) << "Backward is not supported in NaiveCachedOp.";
-    return false;
-  }
+  bool BackwardStorageType(const nnvm::NodeAttrs& attrs,
+                           const int dev_mask,
+                           DispatchMode* dispatch_mode,
+                           std::vector* in_attrs,
+                           std::vector* out_attrs) override {
+    LOG(FATAL) << "Backward is not supported in NaiveCachedOp.";
+    return false;
+  }
 };  // NaiveCachedOp
 
 using NaiveCachedOpPtr = std::shared_ptr;
diff --git a/src/imperative/pointwise_fusion_pass.cc b/src/imperative/pointwise_fusion_pass.cc
index 860b77f2617c..1b47e9270c71 100644
--- a/src/imperative/pointwise_fusion_pass.cc
+++ b/src/imperative/pointwise_fusion_pass.cc
@@ -49,8 +49,8 @@ void WarnFusionNotSupported() {
                  << "Unset env var MXNET_USE_FUSION=1 to quiet this message.";
 #else
     LOG(WARNING) << "Omitting dynamic fused op creation- needs MXNet lib built with "
-                 << "USE_CUDA=1. 
Unset env var MXNET_USE_FUSION=1 " + << "to quiet this message."; #endif // defined(_WIN32) } } @@ -68,24 +68,21 @@ bool IsFusionCompatible(const nnvm::Node* n) { return true; if (slice_ops.count(op_name)) return false; - if (std::find(variable_io_ops.begin(), - variable_io_ops.end(), - op_name) != - variable_io_ops.end()) + if (std::find(variable_io_ops.begin(), variable_io_ops.end(), op_name) != variable_io_ops.end()) return true; if (op_name == "LeakyReLU") { - std::string act_type = n->attrs.dict.at("act_type"); - if (LeakyReLU_ops.count(act_type)) - return true; - else - return false; + std::string act_type = n->attrs.dict.at("act_type"); + if (LeakyReLU_ops.count(act_type)) + return true; + else + return false; } if (op_name == "_backward_LeakyReLU") { - std::string act_type = n->attrs.dict.at("act_type"); - if (LeakyReLU_bwd_ops.count(act_type)) - return true; - else - return false; + std::string act_type = n->attrs.dict.at("act_type"); + if (LeakyReLU_bwd_ops.count(act_type)) + return true; + else + return false; } return false; } @@ -100,8 +97,7 @@ bool IsInputsOnlyCompatible(const nnvm::Node* n) { // slice with non-default step attribute is not supported // currently if (n->attrs.dict.count("step") && - !(n->attrs.dict.at("step") == "()" || - n->attrs.dict.at("step") == "[]")) { + !(n->attrs.dict.at("step") == "()" || n->attrs.dict.at("step") == "[]")) { return false; } } @@ -116,9 +112,9 @@ void CreateSubgraphNode(const nnvm::Graph& subgraph, static const Op* fused_op_ptr = Op::Get("_FusedOp"); subgraph_node->attrs.subgraphs.emplace_back(std::make_shared()); subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs; - subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size); - subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); - subgraph_node->attrs.op = fused_op_ptr; + subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size); + subgraph_node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); + subgraph_node->attrs.op = fused_op_ptr; subgraph_node->op()->attr_parser(&(subgraph_node->attrs)); } @@ -127,8 +123,7 @@ struct EntryInfo { int index; }; -inline int SetInsert(const EntryInfo& new_elem, - std::vector* elements) { +inline int SetInsert(const EntryInfo& new_elem, std::vector* elements) { for (size_t i = 0; i < elements->size(); ++i) { if ((new_elem.source_node == elements->at(i).source_node) && (new_elem.index == elements->at(i).index)) { @@ -152,7 +147,7 @@ inline int SetInsert(const EntryInfo& new_elem, * \param num_subgraphs number of subgraphs. * \param create_subgraph_node function used to prepare the subgraph node. */ -template +template Graph CopyAndReplaceSubgraphs(const Graph& g, const std::vector& subgraph_assignment, const int num_subgraphs, @@ -165,8 +160,8 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, const auto& idx = g.indexed_graph(); - CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) << - "Every node in the graph needs to be included in subgraph assignment."; + CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) + << "Every node in the graph needs to be included in subgraph assignment."; std::vector new_nodes; new_nodes.reserve(idx.num_nodes()); @@ -190,49 +185,43 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, // subgraph. Variables are not copied. 
if (idx[i].source->op() != nullptr) { new_nodes.emplace_back(nnvm::Node::Create()); - auto& node_copy = new_nodes.back(); + auto& node_copy = new_nodes.back(); node_copy->attrs = idx[i].source->attrs; - node_copy->info = idx[i].source->info; + node_copy->info = idx[i].source->info; } else { new_nodes.emplace_back(idx[i].weak_ref.lock()); continue; } - auto& node_copy = new_nodes.back(); + auto& node_copy = new_nodes.back(); const int subgraph_id = subgraph_assignment[i]; if (subgraph_id != -1) { auto& info = subgraphs[subgraph_id]; for (const auto& input : idx[i].inputs) { const int their_subgraph = subgraph_assignment[input.node_id]; if (their_subgraph == subgraph_id) { - node_copy->inputs.emplace_back(new_nodes[input.node_id], - input.index, - input.version); + node_copy->inputs.emplace_back(new_nodes[input.node_id], input.index, input.version); } else { int input_num; int output_num; if (their_subgraph == -1) { - input_num = SetInsert({static_cast(input.node_id), - static_cast(input.index)}, &(info.inputs)); + input_num = SetInsert({static_cast(input.node_id), static_cast(input.index)}, + &(info.inputs)); } else { auto& their_subgraph_info = subgraphs[their_subgraph]; - output_num = SetInsert({static_cast(input.node_id), - static_cast(input.index)}, + output_num = SetInsert({static_cast(input.node_id), static_cast(input.index)}, &(their_subgraph_info.outputs)); - input_num = SetInsert({static_cast(idx.num_nodes() + their_subgraph), - output_num}, + input_num = SetInsert({static_cast(idx.num_nodes() + their_subgraph), output_num}, &(info.inputs)); } if (static_cast(input_num) == info.input_nodes.size()) { info.input_nodes.emplace_back(nnvm::Node::Create()); info.input_nodes.back()->attrs.name = "input_" + std::to_string(input_num); if (their_subgraph == -1) { - info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id], - input.index, - input.version); + info.subgraph_node->inputs.emplace_back( + new_nodes[input.node_id], input.index, input.version); } else { - info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node, - output_num, - input.version); + info.subgraph_node->inputs.emplace_back( + subgraphs[their_subgraph].subgraph_node, output_num, input.version); } } node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0); @@ -242,17 +231,12 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, for (const auto& input : idx[i].inputs) { const int subgraph_id = subgraph_assignment[input.node_id]; if (subgraph_id == -1) { - node_copy->inputs.emplace_back(new_nodes[input.node_id], - input.index, - input.version); + node_copy->inputs.emplace_back(new_nodes[input.node_id], input.index, input.version); } else { - auto& info = subgraphs[subgraph_id]; - const int output_num = SetInsert({static_cast(input.node_id), - static_cast(input.index)}, - &(info.outputs)); - node_copy->inputs.emplace_back(info.subgraph_node, - output_num, - input.version); + auto& info = subgraphs[subgraph_id]; + const int output_num = SetInsert( + {static_cast(input.node_id), static_cast(input.index)}, &(info.outputs)); + node_copy->inputs.emplace_back(info.subgraph_node, output_num, input.version); } } } @@ -269,25 +253,19 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, for (const auto& output : idx.outputs()) { const int subgraph_id = subgraph_assignment[output.node_id]; if (subgraph_id == -1) { - ret.outputs.emplace_back(new_nodes[output.node_id], - output.index, - output.version); + ret.outputs.emplace_back(new_nodes[output.node_id], output.index, output.version); } else { - const int 
output_num = SetInsert({static_cast(output.node_id), - static_cast(output.index)}, - &(subgraphs[subgraph_id].outputs)); - ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node, - output_num, - output.version); + const int output_num = + SetInsert({static_cast(output.node_id), static_cast(output.index)}, + &(subgraphs[subgraph_id].outputs)); + ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node, output_num, output.version); } } for (auto& info : subgraphs) { info.graph.outputs.reserve(info.outputs.size()); for (const auto& entry_info : info.outputs) { - info.graph.outputs.emplace_back(new_nodes[entry_info.source_node], - entry_info.index, - 0); + info.graph.outputs.emplace_back(new_nodes[entry_info.source_node], entry_info.index, 0); } create_subgraph_node(info.graph, info.inputs.size(), info.subgraph_node.get()); } @@ -296,46 +274,41 @@ Graph CopyAndReplaceSubgraphs(const Graph& g, // Add _FusedOpHelper nodes const int subgraph_id = subgraph_assignment[i]; for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) { - const auto& dep = idx[i].control_deps[dep_num]; + const auto& dep = idx[i].control_deps[dep_num]; const int their_subgraph_id = subgraph_assignment[dep]; if (subgraph_id != -1 && their_subgraph_id == -1) { // Not in any subgraph, use FusedOpOutHelper - auto& info = subgraphs[subgraph_id]; + auto& info = subgraphs[subgraph_id]; size_t node_id = info.subgraph_node->control_deps.size(); info.subgraph_node->control_deps.emplace_back(new_nodes[dep]); - auto helper_node = op::MakeNode("_FusedOpOutHelper", + auto helper_node = op::MakeNode("_FusedOpOutHelper", "FusedOp_" + new_nodes[i]->attrs.name + "_outhelper", nullptr, nullptr, nullptr); - helper_node->attrs.parsed = - FusedOpHelperParamPtr(new FusedOpHelperParam( - nnvm::get(info.subgraph_node->attrs.parsed), - node_id)); + helper_node->attrs.parsed = FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(info.subgraph_node->attrs.parsed), node_id)); new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num, std::move(helper_node)); - } else if (their_subgraph_id != subgraph_id && - their_subgraph_id != -1) { - auto& info = subgraphs[their_subgraph_id]; + } else if (their_subgraph_id != subgraph_id && their_subgraph_id != -1) { + auto& info = subgraphs[their_subgraph_id]; const auto& subgraph_idx = info.graph.indexed_graph(); - uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get()); - auto helper_node = op::MakeNode("_FusedOpHelper", - info.subgraph_node->attrs.name + "_" - + idx[i].source->attrs.name + "_helper", - nullptr, - nullptr, - nullptr); - helper_node->attrs.parsed = - FusedOpHelperParamPtr(new FusedOpHelperParam( - nnvm::get(info.subgraph_node->attrs.parsed), - node_id)); + uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get()); + auto helper_node = op::MakeNode( + "_FusedOpHelper", + info.subgraph_node->attrs.name + "_" + idx[i].source->attrs.name + "_helper", + nullptr, + nullptr, + nullptr); + helper_node->attrs.parsed = FusedOpHelperParamPtr(new FusedOpHelperParam( + nnvm::get(info.subgraph_node->attrs.parsed), node_id)); new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + dep_num, std::move(helper_node)); } } } for (auto& info : subgraphs) { - const auto& idx = info.graph.indexed_graph(); + const auto& idx = info.graph.indexed_graph(); const auto& input_nodes = idx.input_nodes(); std::vector subgraph_inputs; subgraph_inputs.reserve(info.subgraph_node->inputs.size()); @@ -359,19 +332,18 @@ Graph CopyAndReplaceSubgraphs(const 
Graph& g, return ret; } -Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) { - auto start = std::chrono::steady_clock::now(); - auto [subset_assignment, num_subsets] = GetCompatibleSubsets(g, num_forward_outputs, // NOLINT(*) +Graph FusePointwise(const Graph& g, const size_t num_forward_outputs) { + auto start = std::chrono::steady_clock::now(); + auto [subset_assignment, num_subsets] = GetCompatibleSubsets(g, // NOLINT(*) + num_forward_outputs, // NOLINT(*) IsFusionCompatible, IsInputsOnlyCompatible); - Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets, - CreateSubgraphNode); - auto end = std::chrono::steady_clock::now(); + Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets, CreateSubgraphNode); + auto end = std::chrono::steady_clock::now(); if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) { auto diff = end - start; LOG(INFO) << "Pointwise fusion graph pass took: " - << std::chrono::duration(diff).count() - << "ms."; + << std::chrono::duration(diff).count() << "ms."; } return ret; } @@ -379,4 +351,3 @@ Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) { } // namespace exec } // namespace mxnet - diff --git a/src/imperative/simple_partition_pass.cc b/src/imperative/simple_partition_pass.cc index 941959d4bb45..fc5e3f002d2f 100644 --- a/src/imperative/simple_partition_pass.cc +++ b/src/imperative/simple_partition_pass.cc @@ -36,10 +36,9 @@ namespace detail { const IntervalVec* LargerSet(const IntervalVec* const first, const IntervalVec* const second) noexcept { const IntervalVec* ret = nullptr; - auto first_iter = first->begin(); - auto second_iter = second->begin(); - while (first_iter != first->end() && - second_iter != second->end()) { + auto first_iter = first->begin(); + auto second_iter = second->begin(); + while (first_iter != first->end() && second_iter != second->end()) { if (*first_iter == *second_iter) { ++first_iter; ++second_iter; @@ -65,8 +64,7 @@ const IntervalVec* LargerSet(const IntervalVec* const first, continue; } // Entry in first set fully encloses the entry in the second set - if (first_iter->first <= second_iter->first && - first_iter->second >= second_iter->second) { + if (first_iter->first <= second_iter->first && first_iter->second >= second_iter->second) { if (ret == first || ret == nullptr) { ret = first; ++second_iter; @@ -76,8 +74,7 @@ const IntervalVec* LargerSet(const IntervalVec* const first, continue; } // Entry in second set fully encloses the entry in the first set - if (second_iter->first <= first_iter->first && - second_iter->second >= first_iter->second) { + if (second_iter->first <= first_iter->first && second_iter->second >= first_iter->second) { if (ret == second || ret == nullptr) { ret = second; ++first_iter; @@ -117,13 +114,12 @@ void MergeSets(const IntervalVec** const my_set, *my_set = larger_set; return; } - auto my_iter = (*my_set)->cbegin(); + auto my_iter = (*my_set)->cbegin(); auto other_iter = other_set->cbegin(); - auto new_set = IntervalVec(); - int last_end = -10; // less than -1 - while (my_iter != (*my_set)->cend() && - other_iter != other_set->cend()) { - const auto& mine = *my_iter; + auto new_set = IntervalVec(); + int last_end = -10; // less than -1 + while (my_iter != (*my_set)->cend() && other_iter != other_set->cend()) { + const auto& mine = *my_iter; const auto& other = *other_iter; if (other.second < mine.first - 1) { // other interval is before ours @@ -145,8 +141,7 @@ void MergeSets(const IntervalVec** const my_set, ++my_iter; } else { // Intervals 
can be merged together - Interval n(std::min(mine.first, other.first), - std::max(mine.second, other.second)); + Interval n(std::min(mine.first, other.first), std::max(mine.second, other.second)); if (last_end >= n.first - 1) { new_set.back().second = n.second; } else { @@ -162,10 +157,10 @@ void MergeSets(const IntervalVec** const my_set, } } auto remaining_iter = my_iter == (*my_set)->cend() ? other_iter : my_iter; - auto remaining_end = my_iter == (*my_set)->cend() ? other_set->cend() : (*my_set)->cend(); + auto remaining_end = my_iter == (*my_set)->cend() ? other_set->cend() : (*my_set)->cend(); // Add the rest of entries for (; remaining_iter != remaining_end; ++remaining_iter) { - auto& mine = new_set.back(); + auto& mine = new_set.back(); const auto& other = *remaining_iter; if (other.second < mine.first - 1) { // other interval is before ours, should never happen @@ -175,7 +170,7 @@ void MergeSets(const IntervalVec** const my_set, new_set.emplace_back(other); } else { // Intervals can be merged together - mine.first = std::min(mine.first, other.first); + mine.first = std::min(mine.first, other.first); mine.second = std::max(mine.second, other.second); } } @@ -183,12 +178,10 @@ void MergeSets(const IntervalVec** const my_set, *my_set = storage->back().get(); } -bool Intersect(const IntervalVec& checked_sets, - const IntervalVec& excluded_sets) noexcept { +bool Intersect(const IntervalVec& checked_sets, const IntervalVec& excluded_sets) noexcept { size_t current_interval = 0, current_other_interval = 0; - while (current_interval < checked_sets.size() && - current_other_interval < excluded_sets.size()) { - const auto& mine = checked_sets[current_interval]; + while (current_interval < checked_sets.size() && current_other_interval < excluded_sets.size()) { + const auto& mine = checked_sets[current_interval]; const auto& other = excluded_sets[current_other_interval]; if (other.second < mine.first) { // other interval is before ours @@ -204,23 +197,23 @@ bool Intersect(const IntervalVec& checked_sets, return false; } -void AddSet(const IntervalVec** const sets, const int set_to_add, +void AddSet(const IntervalVec** const sets, + const int set_to_add, std::vector>* const storage) noexcept { if (*sets != nullptr && (*sets)->size() != 0) { for (auto& interval : (**sets)) { - if (set_to_add >= interval.first && - set_to_add <= interval.second) { + if (set_to_add >= interval.first && set_to_add <= interval.second) { return; } } } - storage->emplace_back( - std::make_unique(1, std::make_pair(set_to_add, set_to_add))); + storage->emplace_back(std::make_unique(1, std::make_pair(set_to_add, set_to_add))); MergeSets(sets, storage->back().get(), storage); } int GetSetMapping(const int set, std::vector* const set_mapping) noexcept { - if (set == -1) return -1; + if (set == -1) + return -1; int temp = set; while ((*set_mapping)[temp] != temp) { temp = (*set_mapping)[temp]; @@ -229,17 +222,17 @@ int GetSetMapping(const int set, std::vector* const set_mapping) noexcept { return temp; } -void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_excluded_sets_ptr, - const IntervalVec* const new_excluded_sets, - std::vector* const excluded_sets_ptr, - const int set_id, - const int first_node_in_set, - const size_t new_node_id, - const std::vector& set_assignment, - std::vector* const set_mapping_ptr, - const IntervalVec& inverse_set_mapping, - std::vector>* const - storage) noexcept { +void CheckAndUpdateCombinedExcludedSets( + const IntervalVec** const combined_excluded_sets_ptr, + const 
IntervalVec* const new_excluded_sets, + std::vector* const excluded_sets_ptr, + const int set_id, + const int first_node_in_set, + const size_t new_node_id, + const std::vector& set_assignment, + std::vector* const set_mapping_ptr, + const IntervalVec& inverse_set_mapping, + std::vector>* const storage) noexcept { const auto* previous_excluded_sets = *combined_excluded_sets_ptr; MergeSets(combined_excluded_sets_ptr, new_excluded_sets, storage); if (new_excluded_sets != nullptr) { @@ -250,8 +243,7 @@ void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_exclu auto& excluded_sets = *excluded_sets_ptr; for (size_t j = first_node_in_set; j < new_node_id; ++j) { if (GetSetMapping(set_assignment[j], set_mapping_ptr) == set_id || - (excluded_sets[j] != nullptr && - Intersect(inverse_set_mapping, *excluded_sets[j]))) { + (excluded_sets[j] != nullptr && Intersect(inverse_set_mapping, *excluded_sets[j]))) { MergeSets(&excluded_sets[j], *combined_excluded_sets_ptr, storage); } } diff --git a/src/imperative/simple_partition_pass.h b/src/imperative/simple_partition_pass.h index 5b28c1796094..1d3825f3b630 100644 --- a/src/imperative/simple_partition_pass.h +++ b/src/imperative/simple_partition_pass.h @@ -43,14 +43,14 @@ namespace exec { namespace detail { -using Interval = std::pair; +using Interval = std::pair; using IntervalVec = std::vector; /* \brief Return the set that fully contains the other set, or nullptr * if neither set is a subset of another. */ -const IntervalVec* LargerSet(const IntervalVec* const first, - const IntervalVec* const second) noexcept; +const IntervalVec* LargerSet(const IntervalVec* const first, + const IntervalVec* const second) noexcept; /* \brief Compute the sum of the 2 sets and store it in my_set. */ @@ -61,12 +61,12 @@ void MergeSets(const IntervalVec** const my_set, /* \brief Returns true if there is non-empty intersection * between the 2 sets. */ -bool Intersect(const IntervalVec& checked_sets, - const IntervalVec& excluded_sets) noexcept; +bool Intersect(const IntervalVec& checked_sets, const IntervalVec& excluded_sets) noexcept; /* \brief Add a single entry to the sets. */ -void AddSet(const IntervalVec** const sets, const int set_to_add, +void AddSet(const IntervalVec** const sets, + const int set_to_add, std::vector>* const storage) noexcept; /* \brief Get the true mapping of the set (which could change @@ -78,8 +78,7 @@ int GetSetMapping(const int set, std::vector* const set_mapping) noexcept; * (so either both on the FWD side or the BWD side). 
*/ inline bool IsSamePass(const int my_id, const int their_id, const int cutoff) noexcept { - return (my_id > cutoff && their_id > cutoff) || - (my_id <= cutoff && their_id <= cutoff); + return (my_id > cutoff && their_id > cutoff) || (my_id <= cutoff && their_id <= cutoff); } /* \brief Check if adding a new node to the set changes the excluded set of the future @@ -97,21 +96,20 @@ inline bool IsSamePass(const int my_id, const int their_id, const int cutoff) no * \param inverse_set_mapping inverse mapping of the set * \param storage memory storage */ -void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_excluded_sets_ptr, - const IntervalVec* const new_excluded_sets, - std::vector* const excluded_sets_ptr, - const int set_id, - const int first_node_in_set, - const size_t new_node_id, - const std::vector& set_assignment, - std::vector* const set_mapping_ptr, - const IntervalVec& inverse_set_mapping, - std::vector>* const - storage) noexcept; +void CheckAndUpdateCombinedExcludedSets( + const IntervalVec** const combined_excluded_sets_ptr, + const IntervalVec* const new_excluded_sets, + std::vector* const excluded_sets_ptr, + const int set_id, + const int first_node_in_set, + const size_t new_node_id, + const std::vector& set_assignment, + std::vector* const set_mapping_ptr, + const IntervalVec& inverse_set_mapping, + std::vector>* const storage) noexcept; } // namespace detail - /* \brief Get all subsets of nodes, where: * - graph constructed from nodes in each subset is a connected graph * - every node fulfills a predicate is_compatible @@ -131,13 +129,12 @@ void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const combined_exclu * need to be excluded). * \return tuple (subset assignment, number of found subsets) */ -template +template std::tuple, int> GetCompatibleSubsets( const Graph& g, const size_t num_forward_outputs, FCompatible is_compatible, FInputOnlyCompatible is_input_only_compatible) { - using namespace detail; const auto& idx = g.indexed_graph(); std::vector set_assignment(idx.num_nodes(), -1); @@ -158,7 +155,7 @@ std::tuple, int> GetCompatibleSubsets( int num_sets = 0; for (size_t i = 0; i < idx.num_nodes(); ++i) { - const auto& node = idx[i]; + const auto& node = idx[i]; auto& my_excluded_sets = excluded_sets[i]; for (const auto& input : node.inputs) { MergeSets(&my_excluded_sets, excluded_sets[input.node_id], &storage); @@ -167,11 +164,10 @@ std::tuple, int> GetCompatibleSubsets( int my_set = -1; for (const auto& input : node.inputs) { int their_set = GetSetMapping(set_assignment[input.node_id], &set_mapping); - if (their_set != -1 && - their_set != my_set && + if (their_set != -1 && their_set != my_set && IsSamePass(i, input.node_id, last_forward_node) && (my_excluded_sets == nullptr || - !Intersect(*inverse_set_mapping[their_set], *my_excluded_sets))) { + !Intersect(*inverse_set_mapping[their_set], *my_excluded_sets))) { if (my_set == -1) { my_set = their_set; CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]), @@ -185,12 +181,10 @@ std::tuple, int> GetCompatibleSubsets( *(inverse_set_mapping[their_set]), &storage); } else { - MergeSets(&inverse_set_mapping[my_set], - inverse_set_mapping[their_set], - &storage); + MergeSets(&inverse_set_mapping[my_set], inverse_set_mapping[their_set], &storage); set_mapping[their_set] = my_set; - first_node_in_set[my_set] = std::min(first_node_in_set[my_set], - first_node_in_set[their_set]); + first_node_in_set[my_set] = + std::min(first_node_in_set[my_set], first_node_in_set[their_set]); 
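Illustrative aside on the set bookkeeping being reindented here: set_mapping acts like the parent array of a union-find, and GetSetMapping resolves a set id to its representative by chasing the mapping until it reaches a fixed point, with -1 ("no set") passed through unchanged. A self-contained sketch under that reading; FindRoot is an illustrative name, and the path-compression write-back is an assumption beyond what the hunk shows:

#include <vector>

// Resolve a set id to its representative by following parent links until a
// fixed point is reached. -1 means "not assigned to any set" and is returned as-is.
int FindRoot(int set, std::vector<int>* parent) {
  if (set == -1)
    return -1;
  int root = set;
  while ((*parent)[root] != root)
    root = (*parent)[root];
  (*parent)[set] = root;  // optional path compression (assumed, see note above)
  return root;
}

// Merging two sets then reduces to redirecting one representative to the other:
//   parent[FindRoot(b, &parent)] = FindRoot(a, &parent);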
CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]), combined_excluded_sets[my_set], &excluded_sets, @@ -208,9 +202,8 @@ std::tuple, int> GetCompatibleSubsets( set_mapping.emplace_back(num_sets); combined_excluded_sets.emplace_back(my_excluded_sets); first_node_in_set.emplace_back(i); - storage.emplace_back(std::make_unique>( - 1, std::make_pair(num_sets, - num_sets))); + storage.emplace_back( + std::make_unique>(1, std::make_pair(num_sets, num_sets))); inverse_set_mapping.emplace_back(storage.back().get()); my_set = num_sets++; } @@ -222,14 +215,12 @@ std::tuple, int> GetCompatibleSubsets( AddSet(&my_excluded_sets, their_set, &storage); } } - if ((is_input_only_compatible != nullptr) && - is_input_only_compatible(node.source)) { + if ((is_input_only_compatible != nullptr) && is_input_only_compatible(node.source)) { set_mapping.emplace_back(num_sets); combined_excluded_sets.emplace_back(my_excluded_sets); first_node_in_set.emplace_back(i); - storage.emplace_back(std::make_unique>( - 1, std::make_pair(num_sets, - num_sets))); + storage.emplace_back( + std::make_unique>(1, std::make_pair(num_sets, num_sets))); inverse_set_mapping.emplace_back(storage.back().get()); set_assignment[i] = num_sets++; } diff --git a/src/initialize.cc b/src/initialize.cc index 6be13e61ae9e..b1a472bf9d3b 100644 --- a/src/initialize.cc +++ b/src/initialize.cc @@ -1,4 +1,4 @@ - /* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -32,17 +32,16 @@ * \brief Retrieve the system error message for the last-error code * \param err string that gets the error message */ -void win_err(char **err) { +void win_err(char** err) { uint32_t dw = GetLastError(); FormatMessage( - FORMAT_MESSAGE_ALLOCATE_BUFFER | - FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - nullptr, - dw, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - reinterpret_cast(err), - 0, nullptr); + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, + dw, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(err), + 0, + nullptr); } #else #include @@ -64,8 +63,6 @@ void win_err(char **err) { #include "common/utils.h" #include "engine/openmp.h" - - namespace mxnet { // pthread_atfork handlers, delegated to LibraryInitializer members. 
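Illustrative aside on the pthread_atfork_prepare/parent/child trio declared above: on POSIX these three callbacks are registered once through pthread_atfork() and then run around every subsequent fork(), which is how LibraryInitializer gets a chance to reset state (e.g. atfork_child() above restarting CustomOperator) in forked workers. A minimal standalone demonstration of the mechanism; the handler bodies and names below are placeholders, not MXNet's actual logic:

#include <pthread.h>
#include <unistd.h>
#include <cstdio>

static void on_prepare() { std::puts("before fork: quiesce worker threads"); }
static void on_parent()  { std::puts("after fork, in parent: resume"); }
static void on_child()   { std::puts("after fork, in child: rebuild engine state"); }

int main() {
  // Register once; the three handlers then bracket every later fork().
  pthread_atfork(&on_prepare, &on_parent, &on_child);
  pid_t pid = fork();
  if (pid == 0) {
    _exit(0);  // child exits immediately in this toy example
  }
  return 0;
}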
@@ -88,12 +85,12 @@ void pthread_atfork_child() { // LibraryInitializer member functions LibraryInitializer::LibraryInitializer() - : original_pid_(common::current_process_id()), - mp_worker_nthreads_(dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1)), - cpu_worker_nthreads_(dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1)), - mp_cv_num_threads_(dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0)) { + : original_pid_(common::current_process_id()), + mp_worker_nthreads_(dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1)), + cpu_worker_nthreads_(dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 1)), + mp_cv_num_threads_(dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0)) { dmlc::InitLogging("mxnet"); - engine::OpenMP::Get(); // force OpenMP initialization + engine::OpenMP::Get(); // force OpenMP initialization install_pthread_atfork_handlers(); } @@ -109,14 +106,14 @@ bool LibraryInitializer::lib_is_loaded(const std::string& path) const { * \return handle a pointer for the loaded library, throws dmlc::error if library can't be loaded */ void* LibraryInitializer::lib_load(const char* path) { - void *handle = nullptr; + void* handle = nullptr; // check if library was already loaded if (!lib_is_loaded(path)) { // if not, load it #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) handle = LoadLibrary(path); if (!handle) { - char *err_msg = nullptr; + char* err_msg = nullptr; win_err(&err_msg); LOG(FATAL) << "Error loading library: '" << path << "'\n" << err_msg; LocalFree(err_msg); @@ -161,7 +158,7 @@ void LibraryInitializer::lib_close(void* handle) { #else if (dlclose(handle)) { LOG(WARNING) << "LibraryInitializer::lib_close: couldn't close library at address: " << handle - << " loaded from: '" << libpath << "': " << dlerror(); + << " loaded from: '" << libpath << "': " << dlerror(); } #endif // _WIN32 or _WIN64 or __WINDOWS__ loaded_libs.erase(libpath); @@ -177,7 +174,7 @@ void LibraryInitializer::get_sym(void* handle, void** func, const char* name) { #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) *func = GetProcAddress((HMODULE)handle, name); if (!(*func)) { - char *err_msg = nullptr; + char* err_msg = nullptr; win_err(&err_msg); LOG(FATAL) << "Error getting function '" << name << "' from library\n" << err_msg; LocalFree(err_msg); @@ -220,7 +217,6 @@ void LibraryInitializer::atfork_child() { CustomOperator::Get()->Start(); } - void LibraryInitializer::install_pthread_atfork_handlers() { #ifndef _WIN32 engine::OpenMP::Get()->initialize_process(); // force omp to set its atfork handler first @@ -228,20 +224,15 @@ void LibraryInitializer::install_pthread_atfork_handlers() { #endif } - - - #if MXNET_USE_SIGNAL_HANDLER && DMLC_LOG_STACK_TRACE -static inline void printStackTrace(FILE *out = stderr, - const unsigned int max_frames = 63) { - +static inline void printStackTrace(FILE* out = stderr, const unsigned int max_frames = 63) { #if !defined(_WIN32) && !defined(_WIN64) && !defined(__WINDOWS__) // storage array for stack trace address data - void* addrlist[max_frames+1]; + void* addrlist[max_frames + 1]; // retrieve current stack addresses - size_t addrlen = backtrace(addrlist, sizeof(addrlist)/sizeof(void*)); + size_t addrlen = backtrace(addrlist, sizeof(addrlist) / sizeof(void*)); if (addrlen < 5) { return; @@ -250,7 +241,6 @@ static inline void printStackTrace(FILE *out = stderr, } fprintf(out, "Stack trace:\n"); - // resolve addresses into strings containing "filename(function+address)", // Actually it will be ## program address function + offset // this array must be free()-ed @@ -261,7 +251,7 @@ 
static inline void printStackTrace(FILE *out = stderr, // iterate over the returned symbol lines. skip the first, it is the // address of this function. - for (unsigned int i = 4; i < addrlen ; i++) { + for (unsigned int i = 4; i < addrlen; i++) { char* begin_name = nullptr; char* begin_offset = nullptr; char* end_offset = nullptr; @@ -269,40 +259,37 @@ static inline void printStackTrace(FILE *out = stderr, // find parentheses and +address offset surrounding the mangled name #ifdef DARWIN // OSX style stack trace - for (char *p = symbollist[i]; *p; ++p) { - if (*p == '_' && *(p-1) == ' ') { - begin_name = p-1; + for (char* p = symbollist[i]; *p; ++p) { + if (*p == '_' && *(p - 1) == ' ') { + begin_name = p - 1; } else if (*p == '+') { - begin_offset = p-1; + begin_offset = p - 1; } } if (begin_name && begin_offset && begin_name < begin_offset) { - *begin_name++ = '\0'; + *begin_name++ = '\0'; *begin_offset++ = '\0'; // mangled name is now in [begin_name, begin_offset) and caller // offset in [begin_offset, end_offset). now apply // __cxa_demangle(): int status; - char* ret = abi::__cxa_demangle(begin_name, &funcname[0], - &funcnamesize, &status); + char* ret = abi::__cxa_demangle(begin_name, &funcname[0], &funcnamesize, &status); if (status == 0) { funcname = ret; // use possibly realloc()-ed string - fprintf(out, " %-30s %-40s %s\n", - symbollist[i], funcname, begin_offset); + fprintf(out, " %-30s %-40s %s\n", symbollist[i], funcname, begin_offset); } else { // demangling failed. Output function name as a C function with // no arguments. - fprintf(out, " %-30s %-38s() %s\n", - symbollist[i], begin_name, begin_offset); + fprintf(out, " %-30s %-38s() %s\n", symbollist[i], begin_name, begin_offset); } } else { - // couldn't parse the line? print the whole line. - fprintf(out, " %-40s\n", symbollist[i]); + // couldn't parse the line? print the whole line. + fprintf(out, " %-40s\n", symbollist[i]); } #else - for (char *p = symbollist[i]; *p; ++p) { + for (char* p = symbollist[i]; *p; ++p) { if (*p == '(') { begin_name = p; } else if (*p == '+') { @@ -323,24 +310,22 @@ static inline void printStackTrace(FILE *out = stderr, // offset in [begin_offset, end_offset). now apply // __cxa_demangle(): - int status = 0; - char* ret = abi::__cxa_demangle(begin_name, funcname, - &funcnamesize, &status); + int status = 0; + char* ret = abi::__cxa_demangle(begin_name, funcname, &funcnamesize, &status); char* fname = begin_name; if (status == 0) { fname = ret; } if (begin_offset) { - fprintf(out, " %-30s ( %-40s + %-6s) %s\n", - symbollist[i], fname, begin_offset, end_offset); + fprintf( + out, " %-30s ( %-40s + %-6s) %s\n", symbollist[i], fname, begin_offset, end_offset); } else { - fprintf(out, " %-30s ( %-40s %-6s) %s\n", - symbollist[i], fname, "", end_offset); + fprintf(out, " %-30s ( %-40s %-6s) %s\n", symbollist[i], fname, "", end_offset); } } else { - // couldn't parse the line? print the whole line. - fprintf(out, " %-40s\n", symbollist[i]); + // couldn't parse the line? print the whole line. 
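Illustrative aside on the demangling step in printStackTrace above: the names returned by backtrace_symbols() are mangled, and abi::__cxa_demangle() from <cxxabi.h> turns them back into readable C++ signatures when it reports status 0. A tiny standalone example of that one call, separate from the stack-walking code:

#include <cxxabi.h>
#include <cstdio>
#include <cstdlib>
#include <typeinfo>
#include <vector>

int main() {
  // typeid(...).name() yields a mangled name with GCC/Clang, much like the
  // function names embedded in backtrace_symbols() output.
  const char* mangled = typeid(std::vector<int>).name();
  int status = 0;
  char* readable = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
  std::printf("%s -> %s\n", mangled, status == 0 ? readable : "(demangling failed)");
  std::free(readable);  // __cxa_demangle allocates the result with malloc
  return 0;
}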
+ fprintf(out, " %-40s\n", symbollist[i]); } #endif // !DARWIN - but is posix } @@ -348,32 +333,33 @@ static inline void printStackTrace(FILE *out = stderr, #endif } -#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, IS_FATAL) \ -std::shared_ptr HANDLER_NAME( \ - signal(SIGNAL, [](int signum) { \ - if (IS_FATAL) { \ - printf("\nFatal Error: %s\n", strsignal(SIGNAL)); \ - printStackTrace(); \ - signal(signum, SIG_DFL); \ - raise(signum); \ - } else { \ - switch (signum) { \ - case SIGSEGV: \ - LOG(FATAL) << "InternalError: " << strsignal(SIGNAL); \ - break; \ - case SIGFPE: \ - LOG(FATAL) << "FloatingPointError: " << strsignal(SIGNAL); \ - break; \ - case SIGBUS: \ - LOG(FATAL) << "IOError: " << strsignal(SIGNAL); \ - break; \ - default: \ - LOG(FATAL) << "RuntimeError: " << strsignal(SIGNAL); \ - break; \ - } \ - } \ - }), \ - [](auto f) { signal(SIGNAL, f); }); +#define SIGNAL_HANDLER(SIGNAL, HANDLER_NAME, IS_FATAL) \ + std::shared_ptr HANDLER_NAME( \ + signal(SIGNAL, \ + [](int signum) { \ + if (IS_FATAL) { \ + printf("\nFatal Error: %s\n", strsignal(SIGNAL)); \ + printStackTrace(); \ + signal(signum, SIG_DFL); \ + raise(signum); \ + } else { \ + switch (signum) { \ + case SIGSEGV: \ + LOG(FATAL) << "InternalError: " << strsignal(SIGNAL); \ + break; \ + case SIGFPE: \ + LOG(FATAL) << "FloatingPointError: " << strsignal(SIGNAL); \ + break; \ + case SIGBUS: \ + LOG(FATAL) << "IOError: " << strsignal(SIGNAL); \ + break; \ + default: \ + LOG(FATAL) << "RuntimeError: " << strsignal(SIGNAL); \ + break; \ + } \ + } \ + }), \ + [](auto f) { signal(SIGNAL, f); }); SIGNAL_HANDLER(SIGSEGV, SIGSEGVHandler, true); SIGNAL_HANDLER(SIGFPE, SIGFPEHandler, false); diff --git a/src/initialize.h b/src/initialize.h index d792613aefb2..5ee650c81654 100644 --- a/src/initialize.h +++ b/src/initialize.h @@ -1,4 +1,4 @@ - /* +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file * distributed with this work for additional information @@ -28,14 +28,11 @@ #include #include "dmlc/io.h" - #ifndef MXNET_INITIALIZE_H_ #define MXNET_INITIALIZE_H_ namespace mxnet { - - void pthread_atfork_prepare(); void pthread_atfork_parent(); void pthread_atfork_child(); @@ -64,7 +61,6 @@ class LibraryInitializer { */ bool was_forked() const; - // Library loading bool lib_is_loaded(const std::string& path) const; void* lib_load(const char* path); @@ -113,8 +109,8 @@ class LibraryInitializer { * \param func_name function name to search for in the library * \return func a function pointer */ -template -T get_func(void *lib, const char *func_name) { +template +T get_func(void* lib, const char* func_name) { T func; LibraryInitializer::Get()->get_sym(lib, reinterpret_cast(&func), func_name); if (!func) diff --git a/src/io/batchify.cc b/src/io/batchify.cc index 01d93f5cad8f..acdcf4c74f57 100644 --- a/src/io/batchify.cc +++ b/src/io/batchify.cc @@ -38,58 +38,58 @@ namespace mxnet { namespace io { - #define tostr(s) #s #ifdef _MSC_VER - #if _MSC_VER < 1925 - #define omp_parallel(t) __pragma(omp parallel for num_threads(t)) - #else - #define omp_parallel(t) _Pragma(tostr(omp parallel for num_threads( ## t ## ))) - #endif +#if _MSC_VER < 1925 +#define omp_parallel(t) __pragma(omp parallel for num_threads(t)) +#else +#define omp_parallel(t) _Pragma(tostr(omp parallel for num_threads( ## t ## ))) +#endif #else - #define omp_parallel(t) _Pragma(tostr(omp parallel for num_threads(t))) +#define omp_parallel(t) _Pragma(tostr(omp parallel for num_threads(t))) #endif struct GroupBatchifyParam : public dmlc::Parameter { mxnet::Tuple functions; // declare parameters DMLC_DECLARE_PARAMETER(GroupBatchifyParam) { - DMLC_DECLARE_FIELD(functions) - .describe("Internal sequentially applied batchify functions. " - "The number of functions must match output of dataset items."); + DMLC_DECLARE_FIELD(functions).describe( + "Internal sequentially applied batchify functions. 
" + "The number of functions must match output of dataset items."); } }; // struct GroupBatchifyParam DMLC_REGISTER_PARAMETER(GroupBatchifyParam); class GroupBatchify : public BatchifyFunction { public: - explicit GroupBatchify(const std::vector >& kwargs) { + explicit GroupBatchify(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); fs_.reserve(param_.functions.ndim()); for (int i = 0; i < param_.functions.ndim(); ++i) { - fs_.emplace_back(*static_cast( - reinterpret_cast(param_.functions[i]))); + fs_.emplace_back( + *static_cast(reinterpret_cast(param_.functions[i]))); } } - bool Batchify(const std::vector >& inputs, - std::vector* outputs) override { + bool Batchify(const std::vector>& inputs, + std::vector* outputs) override { auto bs = inputs.size(); CHECK_GT(bs, 0) << "BatchifyFunction should handle at lease 1 sample"; auto out_size = inputs[0].size(); - CHECK_EQ(out_size, fs_.size()) << "In GroupBatchifyFunction, Elem size " - << out_size << " and batchify function size " << fs_.size() << " must match"; + CHECK_EQ(out_size, fs_.size()) << "In GroupBatchifyFunction, Elem size " << out_size + << " and batchify function size " << fs_.size() << " must match"; outputs->resize(out_size); for (size_t i = 0; i < out_size; ++i) { - std::vector > inp; + std::vector> inp; inp.reserve(inputs.size()); - for (const auto & input : inputs) { - std::vector curr({input[i]}); - inp.emplace_back(curr); + for (const auto& input : inputs) { + std::vector curr({input[i]}); + inp.emplace_back(curr); } std::vector tmp; - if (!fs_[i]->Batchify(inp, &tmp)) return false; + if (!fs_[i]->Batchify(inp, &tmp)) + return false; (*outputs)[i] = tmp[0]; } return true; @@ -103,20 +103,19 @@ class GroupBatchify : public BatchifyFunction { }; // class GroupBatchify MXNET_REGISTER_IO_BATCHIFY_FUNCTION(GroupBatchify) - .describe(R"code(Returns the GroupBatchify function. + .describe(R"code(Returns the GroupBatchify function. )code" ADD_FILELINE) - .add_arguments(GroupBatchifyParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new GroupBatchify(kwargs); -}); + .add_arguments(GroupBatchifyParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new GroupBatchify(kwargs); + }); struct StackBatchifyParam : public dmlc::Parameter { /*! \brief Length of the sequence. 
*/ int use_shared_mem; // declare parameters DMLC_DECLARE_PARAMETER(StackBatchifyParam) { - DMLC_DECLARE_FIELD(use_shared_mem).set_default(0) - .describe("If 1, use shared memory."); + DMLC_DECLARE_FIELD(use_shared_mem).set_default(0).describe("If 1, use shared memory."); } }; // struct StackBatchifyParam @@ -124,90 +123,87 @@ DMLC_REGISTER_PARAMETER(StackBatchifyParam); class StackBatchify : public BatchifyFunction { public: - explicit StackBatchify(const std::vector >& kwargs) { + explicit StackBatchify(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); } - bool Batchify(const std::vector >& inputs, - std::vector* outputs) override { + bool Batchify(const std::vector>& inputs, + std::vector* outputs) override { auto out_size = SanityCheck(inputs); - auto bs = inputs.size(); + auto bs = inputs.size(); outputs->resize(out_size); for (size_t i = 0; i < out_size; ++i) { - // Process i-th output - mxnet::TShape ashape = inputs[0][i].shape(); - CHECK_GE(ashape.ndim(), 0) << "Data dim must be larger than 0"; - // check if all shapes are same - for (size_t j = 1; j < bs; ++j) { - CHECK_EQ(ashape, inputs[j][i].shape()) - << "StackBatchify requires all data along batch dim to be the same, " - << "mismatch " << ashape << " vs. " << inputs[j][i].shape(); - } + // Process i-th output + mxnet::TShape ashape = inputs[0][i].shape(); + CHECK_GE(ashape.ndim(), 0) << "Data dim must be larger than 0"; + // check if all shapes are same + for (size_t j = 1; j < bs; ++j) { + CHECK_EQ(ashape, inputs[j][i].shape()) + << "StackBatchify requires all data along batch dim to be the same, " + << "mismatch " << ashape << " vs. " << inputs[j][i].shape(); + } - // calculate output ndarray size - TShape sshape(ashape.ndim() + 1, 0); - sshape[0] = bs; - for (int k = 0; k < ashape.ndim(); ++k) { - sshape[k + 1] = ashape[k]; - } + // calculate output ndarray size + TShape sshape(ashape.ndim() + 1, 0); + sshape[0] = bs; + for (int k = 0; k < ashape.ndim(); ++k) { + sshape[k + 1] = ashape[k]; + } - int dtype = inputs[0][i].dtype(); - if (!(*outputs)[i].is_none() && (*outputs)[i].ctx() == mxnet::Context::CPU(0) && - (*outputs)[i].dtype() == dtype && - (*outputs)[i].storage_type() == kDefaultStorage) { - if ((*outputs)[i].shape() != sshape) { - // realloc - (*outputs)[i].ReshapeAndAlloc(sshape); - } - } else { - (*outputs)[i] = NDArray(sshape, mxnet::Context::CPU(0), false, inputs[0][i].dtype()); + int dtype = inputs[0][i].dtype(); + if (!(*outputs)[i].is_none() && (*outputs)[i].ctx() == mxnet::Context::CPU(0) && + (*outputs)[i].dtype() == dtype && (*outputs)[i].storage_type() == kDefaultStorage) { + if ((*outputs)[i].shape() != sshape) { + // realloc + (*outputs)[i].ReshapeAndAlloc(sshape); } - int sbs = static_cast(bs); - MSHADOW_TYPE_SWITCH_WITH_BOOL(dtype, DType, { - omp_parallel(bs) - for (int j = 0; j < sbs; ++j) { - omp_exc_.Run([&] { - // inputs[j][i].WaitToRead(); - DType *ptr = (*outputs)[i].data().dptr(); - auto asize = ashape.Size(); - RunContext rctx{(*outputs)[i].ctx(), nullptr, nullptr, false}; - auto dst = TBlob( - ptr + asize * j, inputs[j][i].data().shape_, cpu::kDevMask, dtype, 0); - mxnet::ndarray::Copy( + } else { + (*outputs)[i] = NDArray(sshape, mxnet::Context::CPU(0), false, inputs[0][i].dtype()); + } + int sbs = static_cast(bs); + MSHADOW_TYPE_SWITCH_WITH_BOOL(dtype, DType, { + omp_parallel(bs) for (int j = 0; j < sbs; ++j) { + omp_exc_.Run([&] { + // inputs[j][i].WaitToRead(); + DType* ptr = (*outputs)[i].data().dptr(); + auto asize = ashape.Size(); + RunContext rctx{(*outputs)[i].ctx(), 
nullptr, nullptr, false}; + auto dst = TBlob(ptr + asize * j, inputs[j][i].data().shape_, cpu::kDevMask, dtype, 0); + mxnet::ndarray::Copy( inputs[j][i].data(), &dst, Context::CPU(), Context::CPU(), rctx); - }); - } - omp_exc_.Rethrow(); - }) + }); + } + omp_exc_.Rethrow(); + }) } return true; } + private: /*! \brief parameters */ StackBatchifyParam param_; /*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/ dmlc::OMPException omp_exc_; - std::size_t SanityCheck(const std::vector >& inputs) { + std::size_t SanityCheck(const std::vector>& inputs) { auto bs = inputs.size(); CHECK_GT(bs, 0) << "BatchifyFunction should handle at lease 1 sample"; auto out_size = inputs[0].size(); // sanity check: each input has same size for (size_t i = 1; i < bs; ++i) { - CHECK_EQ(inputs[i].size(), out_size) - << i << "-th input size does not match " << out_size; + CHECK_EQ(inputs[i].size(), out_size) << i << "-th input size does not match " << out_size; } return out_size; } }; // class StackBatchify MXNET_REGISTER_IO_BATCHIFY_FUNCTION(StackBatchify) - .describe(R"code(Returns the StackBatchify function. + .describe(R"code(Returns the StackBatchify function. )code" ADD_FILELINE) - .add_arguments(StackBatchifyParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new StackBatchify(kwargs); -}); + .add_arguments(StackBatchifyParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new StackBatchify(kwargs); + }); struct PadBatchifyParam : public dmlc::Parameter { int use_shared_mem; @@ -216,14 +212,12 @@ struct PadBatchifyParam : public dmlc::Parameter { int round_to; // declare parameters DMLC_DECLARE_PARAMETER(PadBatchifyParam) { - DMLC_DECLARE_FIELD(use_shared_mem).set_default(0) - .describe("If 1, use shared memory."); - DMLC_DECLARE_FIELD(pad_val).set_default(0) - .describe("The filled values, default to 0."); - DMLC_DECLARE_FIELD(dtype).set_default(-1) - .describe("If not -1, force to use dtype as output type, otherwise use input type."); - DMLC_DECLARE_FIELD(round_to).set_default(-1) - .describe("If > 0, the padded dimension will be rounded to be multiple of this value."); + DMLC_DECLARE_FIELD(use_shared_mem).set_default(0).describe("If 1, use shared memory."); + DMLC_DECLARE_FIELD(pad_val).set_default(0).describe("The filled values, default to 0."); + DMLC_DECLARE_FIELD(dtype).set_default(-1).describe( + "If not -1, force to use dtype as output type, otherwise use input type."); + DMLC_DECLARE_FIELD(round_to).set_default(-1).describe( + "If > 0, the padded dimension will be rounded to be multiple of this value."); } }; // struct PadBatchifyParam @@ -231,128 +225,136 @@ DMLC_REGISTER_PARAMETER(PadBatchifyParam); class PadBatchify : public BatchifyFunction { public: - explicit PadBatchify(const std::vector >& kwargs) { + explicit PadBatchify(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); } - bool Batchify(const std::vector >& inputs, - std::vector* outputs) override { + bool Batchify(const std::vector>& inputs, + std::vector* outputs) override { auto bs = inputs.size(); CHECK_GT(bs, 0) << "BatchifyFunction should handle at lease 1 sample"; auto out_size = inputs[0].size(); outputs->resize(out_size); for (size_t i = 0; i < out_size; ++i) { - // Process i-th output - mxnet::TShape ashape = inputs[0][i].shape(); - CHECK_GE(ashape.ndim(), 0) << "Data dim must be larger than 0"; - // find the maximum size in each dim - for (size_t j = 1; j < bs; ++j) { - mxnet::TShape other_shape = inputs[j][i].shape(); - CHECK_EQ(ashape.ndim(), 
other_shape.ndim()) - << "PadBatchify expects all inputs to have same dimensionality: given " - << ashape.ndim() << " vs. " << other_shape.ndim(); - for (dim_t k = 0; k < ashape.ndim(); ++k) { - ashape[k] = std::max(ashape[k], other_shape[k]); - } - } + // Process i-th output + mxnet::TShape ashape = inputs[0][i].shape(); + CHECK_GE(ashape.ndim(), 0) << "Data dim must be larger than 0"; + // find the maximum size in each dim + for (size_t j = 1; j < bs; ++j) { + mxnet::TShape other_shape = inputs[j][i].shape(); + CHECK_EQ(ashape.ndim(), other_shape.ndim()) + << "PadBatchify expects all inputs to have same dimensionality: given " << ashape.ndim() + << " vs. " << other_shape.ndim(); for (dim_t k = 0; k < ashape.ndim(); ++k) { - // pad to multiple of round_to - if (param_.round_to > 0) { - ashape[k] = param_.round_to * static_cast( - std::ceil(static_cast(ashape[k] / param_.round_to))); - } + ashape[k] = std::max(ashape[k], other_shape[k]); } - - // calculate output ndarray size - TShape sshape(ashape.ndim() + 1, 0); - sshape[0] = bs; - for (int k = 0; k < ashape.ndim(); ++k) { - sshape[k + 1] = ashape[k]; + } + for (dim_t k = 0; k < ashape.ndim(); ++k) { + // pad to multiple of round_to + if (param_.round_to > 0) { + ashape[k] = param_.round_to * + static_cast(std::ceil(static_cast(ashape[k] / param_.round_to))); } + } - int dtype = param_.dtype > -1 ? param_.dtype : inputs[0][i].dtype(); - if (!(*outputs)[i].is_none() && - (*outputs)[i].ctx() == mxnet::Context::CPU(0) && - (*outputs)[i].dtype() == dtype && - (*outputs)[i].storage_type() == kDefaultStorage) { - if ((*outputs)[i].shape() != sshape) { - // realloc - (*outputs)[i].ReshapeAndAlloc(sshape); - } - } else { - (*outputs)[i] = NDArray(sshape, mxnet::Context::CPU(0), false, inputs[0][i].dtype()); + // calculate output ndarray size + TShape sshape(ashape.ndim() + 1, 0); + sshape[0] = bs; + for (int k = 0; k < ashape.ndim(); ++k) { + sshape[k + 1] = ashape[k]; + } + + int dtype = param_.dtype > -1 ? 
param_.dtype : inputs[0][i].dtype(); + if (!(*outputs)[i].is_none() && (*outputs)[i].ctx() == mxnet::Context::CPU(0) && + (*outputs)[i].dtype() == dtype && (*outputs)[i].storage_type() == kDefaultStorage) { + if ((*outputs)[i].shape() != sshape) { + // realloc + (*outputs)[i].ReshapeAndAlloc(sshape); } - MSHADOW_TYPE_SWITCH_WITH_BOOL(dtype, DType, { - // fill pad value first - std::fill((*outputs)[i].data().dptr(), - (*outputs)[i].data().dptr() + sshape.Size(), - static_cast(param_.pad_val)); - DType *ptr = (*outputs)[i].data().dptr(); - auto asize = ashape.Size(); - int sbs = static_cast(bs); - omp_parallel(bs) - for (int j = 0; j < sbs; ++j) { - using namespace mshadow::expr; - auto compact_shapes = CompactShapes(ashape, inputs[j][i].shape()); - // inputs[j][i].WaitToRead(); - auto& fshape = compact_shapes.first; - auto& cshape = compact_shapes.second; - switch (fshape.size()) { - case 1U: { - mshadow::Tensor dst = TBlob( - ptr + asize * j, ashape, cpu::kDevMask, dtype, 0).get_with_shape( - mshadow::Shape1(fshape[0])); - mshadow::Tensor src = inputs[j][i].data().get_with_shape< - cpu, 1, DType>(mshadow::Shape1(cshape[0])); - slice<0>(dst, 0, cshape[0]) = src; - break; - } - case 2U: { - mshadow::Tensor dst = TBlob( - ptr + asize * j, ashape, cpu::kDevMask, dtype, 0).get_with_shape( - mshadow::Shape2(fshape[0], fshape[1])); - mshadow::Tensor src = inputs[j][i].data().get_with_shape< - cpu, 2, DType>(mshadow::Shape2(cshape[0], cshape[1])); - slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]) = src; - break; - } - case 3U: { - mshadow::Tensor dst = TBlob( - ptr + asize * j, ashape, cpu::kDevMask, dtype, 0).get_with_shape( - mshadow::Shape3(fshape[0], fshape[1], fshape[2])); - mshadow::Tensor src = inputs[j][i].data().get_with_shape< - cpu, 3, DType>(mshadow::Shape3(cshape[0], cshape[1], cshape[2])); - slice<2>(slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]), 0, cshape[2]) = src; - break; - } - case 4U: { - mshadow::Tensor dst = TBlob( - ptr + asize * j, ashape, cpu::kDevMask, dtype, 0).get_with_shape( - mshadow::Shape4(fshape[0], fshape[1], fshape[2], fshape[3])); - mshadow::Tensor src = inputs[j][i].data().get_with_shape< - cpu, 4, DType>(mshadow::Shape4(cshape[0], cshape[1], cshape[2], cshape[3])); - slice<3>(slice<2>(slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]), - 0, cshape[2]), 0, cshape[3]) = src; - break; - } - case 5U: { - mshadow::Tensor dst = TBlob( - ptr + asize * j, ashape, cpu::kDevMask, dtype, 0).get_with_shape( - mshadow::Shape5(fshape[0], fshape[1], fshape[2], fshape[3], fshape[4])); - mshadow::Tensor src = inputs[j][i].data().get_with_shape< - cpu, 5, DType>(mshadow::Shape5( - cshape[0], cshape[1], cshape[2], cshape[3], cshape[4])); - slice<4>(slice<3>(slice<2>(slice<1>(slice<0>( - dst, 0, cshape[0]), 0, cshape[1]), 0, cshape[2]), - 0, cshape[3]), 0, cshape[4]) = src; - break; - } - default: { - LOG(FATAL) << "# dim to pad: " << cshape.size() << " exceeds limit of 5."; - } + } else { + (*outputs)[i] = NDArray(sshape, mxnet::Context::CPU(0), false, inputs[0][i].dtype()); + } + MSHADOW_TYPE_SWITCH_WITH_BOOL(dtype, DType, { + // fill pad value first + std::fill((*outputs)[i].data().dptr(), + (*outputs)[i].data().dptr() + sshape.Size(), + static_cast(param_.pad_val)); + DType* ptr = (*outputs)[i].data().dptr(); + auto asize = ashape.Size(); + int sbs = static_cast(bs); + omp_parallel(bs) for (int j = 0; j < sbs; ++j) { + using namespace mshadow::expr; + auto compact_shapes = CompactShapes(ashape, inputs[j][i].shape()); + // inputs[j][i].WaitToRead(); + auto& fshape = 
compact_shapes.first; + auto& cshape = compact_shapes.second; + switch (fshape.size()) { + case 1U: { + mshadow::Tensor dst = + TBlob(ptr + asize * j, ashape, cpu::kDevMask, dtype, 0) + .get_with_shape(mshadow::Shape1(fshape[0])); + mshadow::Tensor src = + inputs[j][i].data().get_with_shape(mshadow::Shape1(cshape[0])); + slice<0>(dst, 0, cshape[0]) = src; + break; + } + case 2U: { + mshadow::Tensor dst = + TBlob(ptr + asize * j, ashape, cpu::kDevMask, dtype, 0) + .get_with_shape(mshadow::Shape2(fshape[0], fshape[1])); + mshadow::Tensor src = + inputs[j][i].data().get_with_shape( + mshadow::Shape2(cshape[0], cshape[1])); + slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]) = src; + break; + } + case 3U: { + mshadow::Tensor dst = + TBlob(ptr + asize * j, ashape, cpu::kDevMask, dtype, 0) + .get_with_shape( + mshadow::Shape3(fshape[0], fshape[1], fshape[2])); + mshadow::Tensor src = + inputs[j][i].data().get_with_shape( + mshadow::Shape3(cshape[0], cshape[1], cshape[2])); + slice<2>(slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]), 0, cshape[2]) = src; + break; + } + case 4U: { + mshadow::Tensor dst = + TBlob(ptr + asize * j, ashape, cpu::kDevMask, dtype, 0) + .get_with_shape( + mshadow::Shape4(fshape[0], fshape[1], fshape[2], fshape[3])); + mshadow::Tensor src = + inputs[j][i].data().get_with_shape( + mshadow::Shape4(cshape[0], cshape[1], cshape[2], cshape[3])); + slice<3>(slice<2>(slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]), 0, cshape[2]), + 0, + cshape[3]) = src; + break; + } + case 5U: { + mshadow::Tensor dst = + TBlob(ptr + asize * j, ashape, cpu::kDevMask, dtype, 0) + .get_with_shape( + mshadow::Shape5(fshape[0], fshape[1], fshape[2], fshape[3], fshape[4])); + mshadow::Tensor src = + inputs[j][i].data().get_with_shape( + mshadow::Shape5(cshape[0], cshape[1], cshape[2], cshape[3], cshape[4])); + slice<4>( + slice<3>( + slice<2>(slice<1>(slice<0>(dst, 0, cshape[0]), 0, cshape[1]), 0, cshape[2]), + 0, + cshape[3]), + 0, + cshape[4]) = src; + break; + } + default: { + LOG(FATAL) << "# dim to pad: " << cshape.size() << " exceeds limit of 5."; } } - }) + } + }) } return true; } @@ -403,11 +405,11 @@ class PadBatchify : public BatchifyFunction { }; // class PadBatchify MXNET_REGISTER_IO_BATCHIFY_FUNCTION(PadBatchify) - .describe(R"code(Returns the StackBatchify function. + .describe(R"code(Returns the StackBatchify function. 
)code" ADD_FILELINE) - .add_arguments(PadBatchifyParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new PadBatchify(kwargs); -}); + .add_arguments(PadBatchifyParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new PadBatchify(kwargs); + }); } // namespace io } // namespace mxnet diff --git a/src/io/dataloader.cc b/src/io/dataloader.cc index 47754470453c..a3808e2128bc 100644 --- a/src/io/dataloader.cc +++ b/src/io/dataloader.cc @@ -45,22 +45,19 @@ struct ThreadedDataLoaderParam : public dmlc::Parameter int pin_device_id; // declare parameters DMLC_DECLARE_PARAMETER(ThreadedDataLoaderParam) { - DMLC_DECLARE_FIELD(num_workers).set_default(0) - .describe("Number of thread workers."); - DMLC_DECLARE_FIELD(dataset) - .describe("Pointer to shared Dataset."); - DMLC_DECLARE_FIELD(sampler) - .describe("Pointer to Sampler."); - DMLC_DECLARE_FIELD(batchify_fn) - .describe("Pointer to Batchify function."); - DMLC_DECLARE_FIELD(pin_device_id).set_default(-1) - .describe("If not negative, will move data to pinned memory."); + DMLC_DECLARE_FIELD(num_workers).set_default(0).describe("Number of thread workers."); + DMLC_DECLARE_FIELD(dataset).describe("Pointer to shared Dataset."); + DMLC_DECLARE_FIELD(sampler).describe("Pointer to Sampler."); + DMLC_DECLARE_FIELD(batchify_fn).describe("Pointer to Batchify function."); + DMLC_DECLARE_FIELD(pin_device_id) + .set_default(-1) + .describe("If not negative, will move data to pinned memory."); } }; // struct ThreadedDataLoaderParam DMLC_REGISTER_PARAMETER(ThreadedDataLoaderParam); -template +template class ThreadedDataLoader : public IIterator { public: ThreadedDataLoader() = default; @@ -70,20 +67,18 @@ class ThreadedDataLoader : public IIterator { void Init(const std::vector >& kwargs) override { param_.InitAllowUnknown(kwargs); int maxthread, threadget; - #pragma omp parallel +#pragma omp parallel { // be conservative, set number of real cores maxthread = std::max(omp_get_num_procs(), 1); } param_.num_workers = std::min(maxthread, param_.num_workers); - #pragma omp parallel num_threads(param_.num_workers) - { - threadget = omp_get_num_threads(); - } +#pragma omp parallel num_threads(param_.num_workers) + { threadget = omp_get_num_threads(); } param_.num_workers = std::max(1, threadget); - dataset_ = *static_cast*>(reinterpret_cast(param_.dataset)); + dataset_ = *static_cast*>(reinterpret_cast(param_.dataset)); dataset_len_ = dataset_->GetLen(); - sampler_ = static_cast* >(reinterpret_cast(param_.sampler)); + sampler_ = static_cast*>(reinterpret_cast(param_.sampler)); batchify_fn_ = *static_cast(reinterpret_cast(param_.batchify_fn)); this->BeforeFirst(); } @@ -98,12 +93,12 @@ class ThreadedDataLoader : public IIterator { bool Next() override { bool has_next = sampler_->Next(); - if (!has_next) return false; - auto samples = sampler_->Value(); - auto batch_size = samples.data[0].shape().Size(); - int real_batch_size = batch_size - samples.num_batch_padd; - const int64_t *idx_ptr = static_cast( - samples.data[0].data().dptr_); + if (!has_next) + return false; + auto samples = sampler_->Value(); + auto batch_size = samples.data[0].shape().Size(); + int real_batch_size = batch_size - samples.num_batch_padd; + const int64_t* idx_ptr = static_cast(samples.data[0].data().dptr_); std::vector idx_ptrs; idx_ptrs.assign(idx_ptr, idx_ptr + real_batch_size); @@ -114,12 +109,11 @@ class ThreadedDataLoader : public IIterator { if (profiling) { profiler::CustomOpProfiler::Get()->OnCustomBegin("MXThreadedDataLoaderGetItems"); } - 
#pragma omp parallel for num_threads(param_.num_workers) +#pragma omp parallel for num_threads(param_.num_workers) for (int i = 0; i < real_batch_size; ++i) { omp_exc_.Run([&] { auto idx = idx_ptrs[i]; - CHECK(dataset_->GetItem(idx, &inputs[i])) - << "Error getting data # " << idx; + CHECK(dataset_->GetItem(idx, &inputs[i])) << "Error getting data # " << idx; }); } if (profiling) { @@ -137,7 +131,7 @@ class ThreadedDataLoader : public IIterator { profiler::CustomOpProfiler::Get()->OnCustomBegin("MXThreadedDataLoaderBatchify"); } CHECK(batchify_fn_->Batchify(inputs, &batched_buffer_)) - << "Error call batchify inside dataloader"; + << "Error call batchify inside dataloader"; if (profiling) { profiler::CustomOpProfiler::Get()->OnCustomEnd(); } @@ -150,7 +144,7 @@ class ThreadedDataLoader : public IIterator { return true; } - const TBlobBatch &Value() const override { + const TBlobBatch& Value() const override { return out_; } @@ -166,7 +160,7 @@ class ThreadedDataLoader : public IIterator { /*! \brief dataset length */ int64_t dataset_len_; /*! \brief pointer to sampler iterator */ - IIterator *sampler_; + IIterator* sampler_; /*! \brief pointer to batchify function */ BatchifyFunctionPtr batchify_fn_; /*! \brief OMPException obj to store and rethrow exceptions from omp blocks*/ @@ -174,13 +168,10 @@ class ThreadedDataLoader : public IIterator { }; // class ThreadedDataLoader MXNET_REGISTER_IO_ITER(ThreadedDataLoader) -.describe(R"code(Returns a threaded data loader iterator. + .describe(R"code(Returns a threaded data loader iterator. )code" ADD_FILELINE) -.add_arguments(ThreadedDataLoaderParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.set_body([]() { - return new PrefetcherIter( - new ThreadedDataLoader()); - }); + .add_arguments(ThreadedDataLoaderParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .set_body([]() { return new PrefetcherIter(new ThreadedDataLoader()); }); } // namespace io } // namespace mxnet diff --git a/src/io/dataset.cc b/src/io/dataset.cc index 31bffed88460..4690d340f06c 100644 --- a/src/io/dataset.cc +++ b/src/io/dataset.cc @@ -52,10 +52,8 @@ struct RecordFileDatasetParam : public dmlc::Parameter { std::string idx_file; // declare parameters DMLC_DECLARE_PARAMETER(RecordFileDatasetParam) { - DMLC_DECLARE_FIELD(rec_file) - .describe("The absolute path of record file."); - DMLC_DECLARE_FIELD(idx_file) - .describe("The path of the idx file."); + DMLC_DECLARE_FIELD(rec_file).describe("The absolute path of record file."); + DMLC_DECLARE_FIELD(idx_file).describe("The path of the idx file."); } }; // struct RecordFileDatasetParam @@ -63,11 +61,11 @@ DMLC_REGISTER_PARAMETER(RecordFileDatasetParam); class RecordFileDataset final : public Dataset { public: - explicit RecordFileDataset(const std::vector >& kwargs) { - std::vector > kwargs_left; + explicit RecordFileDataset(const std::vector>& kwargs) { + std::vector> kwargs_left; param_.InitAllowUnknown(kwargs); // read and process idx file - dmlc::Stream *idx_stream = dmlc::Stream::Create(param_.idx_file.c_str(), "r"); + dmlc::Stream* idx_stream = dmlc::Stream::Create(param_.idx_file.c_str(), "r"); dmlc::istream is(idx_stream); size_t key, idx; while (is >> key >> idx) { @@ -94,15 +92,20 @@ class RecordFileDataset final : public Dataset { reader->Seek(pos); static thread_local std::string read_buff; if (reader->NextRecord(&read_buff)) { - const char *buf = read_buff.c_str(); + const char* buf = read_buff.c_str(); const size_t size = read_buff.size(); out = 
NDArray(TShape({static_cast(size)}), Context::CPU(), false, mshadow::kInt8); TBlob dst = out.data(); RunContext rctx{Context::CPU(), nullptr, nullptr, false}; - mxnet::ndarray::Copy( - TBlob(const_cast(reinterpret_cast(buf)), - out.shape(), cpu::kDevMask, out.dtype(), 0), - &dst, Context::CPU(), Context::CPU(), rctx); + mxnet::ndarray::Copy(TBlob(const_cast(reinterpret_cast(buf)), + out.shape(), + cpu::kDevMask, + out.dtype(), + 0), + &dst, + Context::CPU(), + Context::CPU(), + rctx); } return true; } @@ -115,11 +118,11 @@ class RecordFileDataset final : public Dataset { }; MXNET_REGISTER_IO_DATASET(RecordFileDataset) - .describe("MXNet Record File Dataset") - .add_arguments(RecordFileDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new RecordFileDataset(kwargs); -}); + .describe("MXNet Record File Dataset") + .add_arguments(RecordFileDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new RecordFileDataset(kwargs); + }); struct ImageRecordFileDatasetParam : public dmlc::Parameter { std::string rec_file; @@ -127,21 +130,19 @@ struct ImageRecordFileDatasetParam : public dmlc::Parameter -void SwapImageChannels(const cv::Mat &img, NDArray* arr) { - int swap_indices[n_channels]; // NOLINT(*) +template +void SwapImageChannels(const cv::Mat& img, NDArray* arr) { + int swap_indices[n_channels]; // NOLINT(*) if (n_channels == 1) { swap_indices[0] = 0; } else if (n_channels == 3) { @@ -165,7 +166,7 @@ void SwapImageChannels(const cv::Mat &img, NDArray* arr) { // swap channels while copying elements into buffer for (int i = 0; i < img.rows; ++i) { const uint8_t* im_data = img.ptr(i); - uint8_t* buffer_data = ptr + i * img.cols * n_channels; + uint8_t* buffer_data = ptr + i * img.cols * n_channels; for (int j = 0; j < img.cols; ++j) { for (int k = 0; k < n_channels; ++k) { buffer_data[k] = im_data[swap_indices[k]]; @@ -188,8 +189,8 @@ struct IRHeader { class ImageRecordFileDataset : public Dataset { public: - explicit ImageRecordFileDataset(const std::vector >& kwargs) { - std::vector > kwargs_left; + explicit ImageRecordFileDataset(const std::vector>& kwargs) { + std::vector> kwargs_left; param_.InitAllowUnknown(kwargs); base_ = std::make_shared(kwargs); } @@ -201,9 +202,10 @@ class ImageRecordFileDataset : public Dataset { bool GetItem(uint64_t idx, std::vector* ret) override { CHECK_LT(idx, GetLen()); std::vector raw; - if (!base_->GetItem(idx, &raw)) return false; + if (!base_->GetItem(idx, &raw)) + return false; CHECK_EQ(raw.size(), 1U) << "RecordFileDataset should return size 1 NDArray vector"; - uint8_t *s = reinterpret_cast(raw[0].data().dptr_); + uint8_t* s = reinterpret_cast(raw[0].data().dptr_); size_t size = raw[0].shape().Size(); CHECK_GT(size, sizeof(IRHeader)) << "Invalid size of bytes from Record File"; IRHeader header; @@ -217,14 +219,17 @@ class ImageRecordFileDataset : public Dataset { label.ReshapeAndAlloc(label_shape); TBlob dst = label.data(); mxnet::ndarray::Copy( - TBlob(reinterpret_cast(s), label.shape(), cpu::kDevMask, label.dtype(), 0), - &dst, Context::CPU(), Context::CPU(), rctx); + TBlob(reinterpret_cast(s), label.shape(), cpu::kDevMask, label.dtype(), 0), + &dst, + Context::CPU(), + Context::CPU(), + rctx); s += sizeof(float) * header.flag; size -= sizeof(float) * header.flag; } else { // label is a scalar with ndim() == 0 label.ReshapeAndAlloc(TShape(0, 1)); - TBlob dst = label.data(); + TBlob dst = label.data(); *(dst.dptr()) = header.label; } ret->resize(2); @@ -243,9 +248,9 @@ class 
ImageRecordFileDataset : public Dataset { } return true; #else - LOG(FATAL) << "Opencv is needed for image decoding."; + LOG(FATAL) << "Opencv is needed for image decoding."; #endif - return false; // should not reach here + return false; // should not reach here } private: @@ -256,11 +261,11 @@ class ImageRecordFileDataset : public Dataset { }; MXNET_REGISTER_IO_DATASET(ImageRecordFileDataset) - .describe("MXNet Image Record File Dataset") - .add_arguments(ImageRecordFileDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new ImageRecordFileDataset(kwargs); -}); + .describe("MXNet Image Record File Dataset") + .add_arguments(ImageRecordFileDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new ImageRecordFileDataset(kwargs); + }); struct ImageSequenceDatasetParam : public dmlc::Parameter { /*! \brief the list of absolute image paths, separated by \0 characters */ @@ -268,18 +273,17 @@ struct ImageSequenceDatasetParam : public dmlc::Parameter >& kwargs) { - std::vector > kwargs_left; + explicit ImageSequenceDataset(const std::vector>& kwargs) { + std::vector> kwargs_left; param_.InitAllowUnknown(kwargs); img_list_ = dmlc::Split(param_.img_list, param_.path_sep); } @@ -300,7 +304,7 @@ class ImageSequenceDataset final : public Dataset { bool GetItem(uint64_t idx, std::vector* ret) override { #if MXNET_USE_OPENCV CHECK_LT(idx, img_list_.size()) - << "GetItem index: " << idx << " out of bound: " << img_list_.size(); + << "GetItem index: " << idx << " out of bound: " << img_list_.size(); cv::Mat res = cv::imread(img_list_[idx], param_.flag); CHECK(!res.empty()) << "Decoding failed. Invalid image file."; const int n_channels = res.channels(); @@ -314,9 +318,9 @@ class ImageSequenceDataset final : public Dataset { } return true; #else - LOG(FATAL) << "Opencv is needed for image decoding."; + LOG(FATAL) << "Opencv is needed for image decoding."; #endif - return false; + return false; } private: @@ -327,19 +331,18 @@ class ImageSequenceDataset final : public Dataset { }; MXNET_REGISTER_IO_DATASET(ImageSequenceDataset) - .describe("Image Sequence Dataset") - .add_arguments(ImageSequenceDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new ImageSequenceDataset(kwargs); -}); + .describe("Image Sequence Dataset") + .add_arguments(ImageSequenceDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new ImageSequenceDataset(kwargs); + }); struct NDArrayDatasetParam : public dmlc::Parameter { /*! 
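ImageSequenceDataset in this hunk stores its file list as one separator-joined string and splits it on path_sep at construction (via dmlc::Split). A rough standalone equivalent using a plain std::string splitter; the separator and sample paths are illustrative only.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Split a delimiter-joined list such as "a.jpg|b.jpg|c.jpg" into its parts.
std::vector<std::string> SplitList(const std::string& joined, char sep) {
  std::vector<std::string> parts;
  std::string item;
  std::istringstream is(joined);
  while (std::getline(is, item, sep)) {
    if (!item.empty()) parts.push_back(item);
  }
  return parts;
}

int main() {
  auto img_list = SplitList("cat.jpg|dog.jpg|bird.jpg", '|');
  for (const auto& path : img_list) std::cout << path << "\n";  // one path per item
}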
\brief the source ndarray */ std::intptr_t arr; // declare parameters DMLC_DECLARE_PARAMETER(NDArrayDatasetParam) { - DMLC_DECLARE_FIELD(arr) - .describe("Pointer to NDArray."); + DMLC_DECLARE_FIELD(arr).describe("Pointer to NDArray."); } }; // struct NDArrayDatasetParam @@ -347,7 +350,7 @@ DMLC_REGISTER_PARAMETER(NDArrayDatasetParam); class NDArrayDataset final : public Dataset { public: - explicit NDArrayDataset(const std::vector >& kwargs) { + explicit NDArrayDataset(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); data_ = *(static_cast(reinterpret_cast(param_.arr))); if (data_.shape().ndim() < 1) { @@ -361,11 +364,10 @@ class NDArrayDataset final : public Dataset { } bool GetItem(uint64_t idx, std::vector* rets) override { - CHECK_LT(idx, size_) - << "GetItem index: " << idx << " out of bound: " << size_; + CHECK_LT(idx, size_) << "GetItem index: " << idx << " out of bound: " << size_; rets->resize(1); auto& ret = (*rets)[0]; - ret = data_.Slice(idx, idx + 1); + ret = data_.Slice(idx, idx + 1); if (ret.shape().ndim() > 1) { // remove first dim to be consistent with numpy TShape new_shape; @@ -391,19 +393,18 @@ class NDArrayDataset final : public Dataset { }; // class NDArrayDataset MXNET_REGISTER_IO_DATASET(NDArrayDataset) - .describe("Single NDArray Dataset") - .add_arguments(NDArrayDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new NDArrayDataset(kwargs); -}); + .describe("Single NDArray Dataset") + .add_arguments(NDArrayDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new NDArrayDataset(kwargs); + }); struct GroupDatasetParam : public dmlc::Parameter { /*! \brief the source ndarray */ Tuple datasets; // declare parameters DMLC_DECLARE_PARAMETER(GroupDatasetParam) { - DMLC_DECLARE_FIELD(datasets) - .describe("A small set of pointers to other c++ datasets."); + DMLC_DECLARE_FIELD(datasets).describe("A small set of pointers to other c++ datasets."); } }; // struct GroupDatasetParam @@ -411,8 +412,8 @@ DMLC_REGISTER_PARAMETER(GroupDatasetParam); class GroupDataset final : public Dataset { public: - explicit GroupDataset(const std::vector >& kwargs) { - std::vector > kwargs_left; + explicit GroupDataset(const std::vector>& kwargs) { + std::vector> kwargs_left; param_.InitAllowUnknown(kwargs); auto childs = param_.datasets; childs_.reserve(childs.ndim()); @@ -422,9 +423,8 @@ class GroupDataset final : public Dataset { if (child_cnt == 0) { size_ = d->GetLen(); } else { - CHECK_EQ(size_, d->GetLen()) - << "All child dataset of GroupDataset must be identical " - << "Given mismatch: " << size_ << " vs " << d->GetLen(); + CHECK_EQ(size_, d->GetLen()) << "All child dataset of GroupDataset must be identical " + << "Given mismatch: " << size_ << " vs " << d->GetLen(); } childs_.emplace_back(d); child_cnt++; @@ -436,12 +436,12 @@ class GroupDataset final : public Dataset { } bool GetItem(uint64_t idx, std::vector* ret) override { - CHECK_LT(idx, size_) - << "GetItem index: " << idx << " out of bound: " << size_; + CHECK_LT(idx, size_) << "GetItem index: " << idx << " out of bound: " << size_; ret->clear(); for (const auto& child : childs_) { std::vector temp_ret; - if (!child->GetItem(idx, &temp_ret)) return false; + if (!child->GetItem(idx, &temp_ret)) + return false; ret->insert(ret->end(), temp_ret.begin(), temp_ret.end()); } return true; @@ -454,14 +454,14 @@ class GroupDataset final : public Dataset { std::vector> childs_; /*! 
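GroupDataset above requires every child dataset to report the same length and builds item idx by concatenating the children's outputs in order. A condensed sketch of that contract with a toy dataset interface standing in for the MXNet one; the class names and values here are placeholders, not the real types.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Toy stand-in for the dataset interface: an item is a vector of ints.
struct ToyDataset {
  virtual ~ToyDataset() = default;
  virtual uint64_t GetLen() const = 0;
  virtual bool GetItem(uint64_t idx, std::vector<int>* out) const = 0;
};

struct ConstantDataset : ToyDataset {
  int value; uint64_t len;
  ConstantDataset(int v, uint64_t n) : value(v), len(n) {}
  uint64_t GetLen() const override { return len; }
  bool GetItem(uint64_t idx, std::vector<int>* out) const override {
    if (idx >= len) return false;
    *out = {value};
    return true;
  }
};

// Grouping: all children must have identical length; item i of the group is
// the concatenation of every child's item i.
struct GroupToyDataset : ToyDataset {
  std::vector<std::shared_ptr<ToyDataset>> children;
  uint64_t GetLen() const override { return children.empty() ? 0 : children[0]->GetLen(); }
  bool GetItem(uint64_t idx, std::vector<int>* out) const override {
    out->clear();
    for (const auto& child : children) {
      std::vector<int> part;
      if (!child->GetItem(idx, &part)) return false;
      out->insert(out->end(), part.begin(), part.end());
    }
    return true;
  }
};

int main() {
  GroupToyDataset group;
  group.children = {std::make_shared<ConstantDataset>(1, 4),
                    std::make_shared<ConstantDataset>(2, 4)};
  std::vector<int> item;
  assert(group.GetItem(0, &item) && item.size() == 2);
  std::cout << item[0] << " " << item[1] << "\n";  // prints "1 2"
}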
\brief overall dataset size, equals to all child datasets */ uint64_t size_; -}; // class GroupDataset +}; // class GroupDataset MXNET_REGISTER_IO_DATASET(GroupDataset) - .describe("Grouped Dataset that combine a bunch of datasets") - .add_arguments(GroupDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new GroupDataset(kwargs); -}); + .describe("Grouped Dataset that combine a bunch of datasets") + .add_arguments(GroupDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new GroupDataset(kwargs); + }); struct IndexedDatasetParam : public dmlc::Parameter { /*! \brief the base dataset */ @@ -470,10 +470,10 @@ struct IndexedDatasetParam : public dmlc::Parameter { Tuple indices; // declare parameters DMLC_DECLARE_PARAMETER(IndexedDatasetParam) { - DMLC_DECLARE_FIELD(base) - .describe("Pointer to the internal c++ dataset that is going to be indexed."); - DMLC_DECLARE_FIELD(indices) - .describe("The indices for the internal dataset. Output[i] will be base[indices[i]]."); + DMLC_DECLARE_FIELD(base).describe( + "Pointer to the internal c++ dataset that is going to be indexed."); + DMLC_DECLARE_FIELD(indices).describe( + "The indices for the internal dataset. Output[i] will be base[indices[i]]."); } }; // struct IndexedDatasetParam @@ -481,7 +481,7 @@ DMLC_REGISTER_PARAMETER(IndexedDatasetParam); class IndexedDataset final : public Dataset { public: - explicit IndexedDataset(const std::vector >& kwargs) { + explicit IndexedDataset(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); base_data_ = *static_cast*>(reinterpret_cast(param_.base)); } @@ -491,11 +491,12 @@ class IndexedDataset final : public Dataset { } bool GetItem(uint64_t idx, std::vector* ret) override { - CHECK_GT(param_.indices.ndim(), idx) << "IndexError: " << idx - << " from total: " << param_.indices.ndim(); + CHECK_GT(param_.indices.ndim(), idx) + << "IndexError: " << idx << " from total: " << param_.indices.ndim(); auto new_idx = param_.indices[idx]; - CHECK_GT(base_data_->GetLen(), new_idx) << "IndexError: " << new_idx - << " from original dataset with size: " << base_data_->GetLen(); + CHECK_GT(base_data_->GetLen(), new_idx) + << "IndexError: " << new_idx + << " from original dataset with size: " << base_data_->GetLen(); return base_data_->GetItem(new_idx, ret); } @@ -504,14 +505,14 @@ class IndexedDataset final : public Dataset { IndexedDatasetParam param_; /*! \brief stored child dataset */ std::shared_ptr base_data_; -}; // class IndexedDataset +}; // class IndexedDataset MXNET_REGISTER_IO_DATASET(IndexedDataset) - .describe("Grouped Dataset that combine a bunch of datasets") - .add_arguments(IndexedDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new IndexedDataset(kwargs); -}); + .describe("Grouped Dataset that combine a bunch of datasets") + .add_arguments(IndexedDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new IndexedDataset(kwargs); + }); struct LazyTransformDatasetParam : public dmlc::Parameter { /*! 
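IndexedDataset above is a pure index indirection: item i of the wrapper is item indices[i] of the wrapped dataset, with bounds checks at both levels. A small sketch of that remapping on plain vectors; the data and indices are made up.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Remap access so that view[i] == base[indices[i]], checking both indices.
int IndexedGet(const std::vector<int>& base,
               const std::vector<uint64_t>& indices,
               uint64_t idx) {
  if (idx >= indices.size())
    throw std::out_of_range("IndexError: idx past the indices tuple");
  uint64_t new_idx = indices[idx];
  if (new_idx >= base.size())
    throw std::out_of_range("IndexError: remapped index past the base dataset");
  return base[new_idx];
}

int main() {
  std::vector<int> base = {10, 20, 30, 40};
  std::vector<uint64_t> indices = {3, 0, 2};  // a shuffled / filtered view of base
  for (uint64_t i = 0; i < indices.size(); ++i)
    std::cout << IndexedGet(base, indices, i) << " ";  // 40 10 30
  std::cout << "\n";
}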
\brief the source ndarray */ @@ -524,16 +525,16 @@ struct LazyTransformDatasetParam : public dmlc::Parameter scalar_outputs; // declare parameters DMLC_DECLARE_PARAMETER(LazyTransformDatasetParam) { - DMLC_DECLARE_FIELD(cached_op) - .describe("Pointer to cached transform function."); - DMLC_DECLARE_FIELD(dataset) - .describe("Pointer to internal dataset."); - DMLC_DECLARE_FIELD(transform_indices).set_default(Tuple({})) - .describe("The indices for dataset items that need to be transformed/processed. " - "If `transform_indices` is empty(default), " - "then all items will be processed."); - DMLC_DECLARE_FIELD(scalar_outputs) - .describe("Indicate whether outputs are scalars, the size must match the output size."); + DMLC_DECLARE_FIELD(cached_op).describe("Pointer to cached transform function."); + DMLC_DECLARE_FIELD(dataset).describe("Pointer to internal dataset."); + DMLC_DECLARE_FIELD(transform_indices) + .set_default(Tuple({})) + .describe( + "The indices for dataset items that need to be transformed/processed. " + "If `transform_indices` is empty(default), " + "then all items will be processed."); + DMLC_DECLARE_FIELD(scalar_outputs) + .describe("Indicate whether outputs are scalars, the size must match the output size."); } }; // struct LazyTransformDatasetParam @@ -542,30 +543,29 @@ DMLC_REGISTER_PARAMETER(LazyTransformDatasetParam); class LazyTransformDataset final : public Dataset { public: LazyTransformDataset(const LazyTransformDataset& other) { - this->param_ = other.param_; + this->param_ = other.param_; this->pass_through_indices_ = other.pass_through_indices_; - this->use_input_indices_ = other.use_input_indices_; - this->num_outputs_ = other.num_outputs_; - this->cached_op_ = std::make_shared( - other.cached_op_->sym_, other.cached_op_->flags_); + this->use_input_indices_ = other.use_input_indices_; + this->num_outputs_ = other.num_outputs_; + this->cached_op_ = + std::make_shared(other.cached_op_->sym_, other.cached_op_->flags_); this->base_data_ = other.base_data_; } - explicit LazyTransformDataset(const std::vector >& kwargs) { + explicit LazyTransformDataset(const std::vector>& kwargs) { param_.InitAllowUnknown(kwargs); - auto op = *static_cast(reinterpret_cast(param_.cached_op)); + auto op = *static_cast(reinterpret_cast(param_.cached_op)); cached_op_ = std::make_shared(op->sym_, op->flags_); base_data_ = *static_cast*>(reinterpret_cast(param_.dataset)); // use first item to calculate size info - CHECK_GT(GetLen(), 0) - << "LazyTransformDataset expect the base dataset to have at least 1 item"; + CHECK_GT(GetLen(), 0) << "LazyTransformDataset expect the base dataset to have at least 1 item"; std::vector inputs; CHECK(base_data_->GetItem(0, &inputs)); // check output size CHECK_EQ(param_.scalar_outputs.ndim(), cached_op_->num_outputs()) - << "Output scalar info size: " << param_.scalar_outputs.ndim() << " vs. output size: " - << cached_op_->num_outputs() << " mismatch!"; + << "Output scalar info size: " << param_.scalar_outputs.ndim() + << " vs. 
output size: " << cached_op_->num_outputs() << " mismatch!"; // check input size if (param_.transform_indices.ndim() == 0) { std::vector default_indices; @@ -575,22 +575,20 @@ class LazyTransformDataset final : public Dataset { } use_input_indices_ = default_indices; } else { - use_input_indices_ = std::vector(param_.transform_indices.begin(), - param_.transform_indices.end()); + use_input_indices_ = + std::vector(param_.transform_indices.begin(), param_.transform_indices.end()); } CHECK_EQ(use_input_indices_.size(), cached_op_->num_inputs()) - << "Mismatched transform indices and transform inputs: " << use_input_indices_.size() - << " vs. " << cached_op_->num_inputs(); + << "Mismatched transform indices and transform inputs: " << use_input_indices_.size() + << " vs. " << cached_op_->num_inputs(); auto num_inputs = use_input_indices_.size(); - CHECK_GE(inputs.size(), num_inputs) - << "LazyTransformDataset input size " << inputs.size() - << " smaller than transform input size: " - << num_inputs; + CHECK_GE(inputs.size(), num_inputs) << "LazyTransformDataset input size " << inputs.size() + << " smaller than transform input size: " << num_inputs; pass_through_indices_.clear(); for (size_t i = 0; i < inputs.size(); ++i) { // filling output ndarray from unaltered inputs, transformed outputs are already inserted - if (std::find(use_input_indices_.begin(), - use_input_indices_.end(), i) == use_input_indices_.end()) { + if (std::find(use_input_indices_.begin(), use_input_indices_.end(), i) == + use_input_indices_.end()) { pass_through_indices_.emplace_back(i); } } @@ -605,7 +603,8 @@ class LazyTransformDataset final : public Dataset { bool GetItem(uint64_t idx, std::vector* outputs) override { std::vector inputs; - if (!base_data_->GetItem(idx, &inputs)) return false; + if (!base_data_->GetItem(idx, &inputs)) + return false; outputs->reserve(num_outputs_); outputs->resize(cached_op_->num_outputs()); for (auto i : pass_through_indices_) { @@ -625,7 +624,7 @@ class LazyTransformDataset final : public Dataset { ndoutputs.emplace_back(&(outputs->at(i))); } - for (auto & input : inputs) { + for (auto& input : inputs) { input.WaitToRead(); } CHECK(inputs.size() > 0) << "dataset getitem requires at least one input"; @@ -644,13 +643,13 @@ class LazyTransformDataset final : public Dataset { std::vector use_input_indices_; std::vector pass_through_indices_; size_t num_outputs_; -}; // class LazyTransformDataset +}; // class LazyTransformDataset MXNET_REGISTER_IO_DATASET(LazyTransformDataset) - .describe("Dataset that apply lazy transformation to internal dataset") - .add_arguments(LazyTransformDatasetParam::__FIELDS__()) - .set_body([](const std::vector >& kwargs) { - return new LazyTransformDataset(kwargs); -}); + .describe("Dataset that apply lazy transformation to internal dataset") + .add_arguments(LazyTransformDatasetParam::__FIELDS__()) + .set_body([](const std::vector>& kwargs) { + return new LazyTransformDataset(kwargs); + }); } // namespace io } // namespace mxnet diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc index c39777ba3054..0c29e44a2754 100644 --- a/src/io/image_aug_default.cc +++ b/src/io/image_aug_default.cc @@ -101,92 +101,123 @@ struct DefaultImageAugmentParam : public dmlc::Parameter()) - .describe("Change the aspect (namely width/height) to a random value " - "in ``[min_aspect_ratio, max_aspect_ratio]``"); - DMLC_DECLARE_FIELD(max_shear_ratio).set_default(0.0f) - .describe("Apply a shear transformation (namely ``(x,y)->(x+my,y)``) " - "with ``m`` randomly chose 
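In LazyTransformDataset above, only the item slots listed in transform_indices are routed through the cached op; every other input index is collected into pass_through_indices_ and copied to the output untouched. A minimal sketch of that complement computation; the index values in main are arbitrary.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Given the total number of inputs and the indices that get transformed,
// return the indices that should be passed through unchanged.
std::vector<size_t> PassThroughIndices(size_t num_inputs,
                                       const std::vector<size_t>& transform_indices) {
  std::vector<size_t> pass_through;
  for (size_t i = 0; i < num_inputs; ++i) {
    if (std::find(transform_indices.begin(), transform_indices.end(), i) ==
        transform_indices.end()) {
      pass_through.push_back(i);
    }
  }
  return pass_through;
}

int main() {
  // e.g. item = (image, label): only slot 0 is transformed, slot 1 passes through.
  for (size_t i : PassThroughIndices(2, {0})) std::cout << i << "\n";  // prints 1
}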
from " - "``[-max_shear_ratio, max_shear_ratio]``"); - DMLC_DECLARE_FIELD(max_crop_size).set_default(-1) - .describe("Crop both width and height into a random size in " - "``[min_crop_size, max_crop_size].``" - "Ignored if ``random_resized_crop`` is True."); - DMLC_DECLARE_FIELD(min_crop_size).set_default(-1) - .describe("Crop both width and height into a random size in " - "``[min_crop_size, max_crop_size].``" - "Ignored if ``random_resized_crop`` is True."); - DMLC_DECLARE_FIELD(max_random_scale).set_default(1.0f) - .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly" - " chosen from ``[min_random_scale, max_random_scale]``. " - "Ignored if ``random_resized_crop`` is True."); - DMLC_DECLARE_FIELD(min_random_scale).set_default(1.0f) - .describe("Resize into ``[width*s, height*s]`` with ``s`` randomly" - " chosen from ``[min_random_scale, max_random_scale]``" - "Ignored if ``random_resized_crop`` is True."); - DMLC_DECLARE_FIELD(max_random_area).set_default(1.0f) - .describe("Change the area (namely width * height) to a random value " - "in ``[min_random_area, max_random_area]``. " - "Ignored if ``random_resized_crop`` is False."); - DMLC_DECLARE_FIELD(min_random_area).set_default(1.0f) - .describe("Change the area (namely width * height) to a random value " - "in ``[min_random_area, max_random_area]``. " - "Ignored if ``random_resized_crop`` is False."); - DMLC_DECLARE_FIELD(max_img_size).set_default(1e10f) - .describe("Set the maximal width and height after all resize and" - " rotate argumentation are applied"); - DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) - .describe("Set the minimal width and height after all resize and" - " rotate argumentation are applied"); - DMLC_DECLARE_FIELD(brightness).set_default(0.0f) - .describe("Add a random value in ``[-brightness, brightness]`` to " - "the brightness of image."); - DMLC_DECLARE_FIELD(contrast).set_default(0.0f) - .describe("Add a random value in ``[-contrast, contrast]`` to " - "the contrast of image."); - DMLC_DECLARE_FIELD(saturation).set_default(0.0f) - .describe("Add a random value in ``[-saturation, saturation]`` to " - "the saturation of image."); - DMLC_DECLARE_FIELD(pca_noise).set_default(0.0f) - .describe("Add PCA based noise to the image."); - DMLC_DECLARE_FIELD(random_h).set_default(0) - .describe("Add a random value in ``[-random_h, random_h]`` to " - "the H channel in HSL color space."); - DMLC_DECLARE_FIELD(random_s).set_default(0) - .describe("Add a random value in ``[-random_s, random_s]`` to " - "the S channel in HSL color space."); - DMLC_DECLARE_FIELD(random_l).set_default(0) - .describe("Add a random value in ``[-random_l, random_l]`` to " - "the L channel in HSL color space."); - DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) - .describe("Rotate by an angle. If set, it overwrites the ``max_rotate_angle`` option."); - DMLC_DECLARE_FIELD(fill_value).set_default(255) + DMLC_DECLARE_FIELD(max_aspect_ratio) + .set_default(0.0f) + .describe( + "Change the aspect (namely width/height) to a random value. 
" + "If min_aspect_ratio is None then the aspect ratio ins sampled from " + "[1 - max_aspect_ratio, 1 + max_aspect_ratio], " + "else it is in ``[min_aspect_ratio, max_aspect_ratio]``"); + DMLC_DECLARE_FIELD(min_aspect_ratio) + .set_default(dmlc::optional()) + .describe( + "Change the aspect (namely width/height) to a random value " + "in ``[min_aspect_ratio, max_aspect_ratio]``"); + DMLC_DECLARE_FIELD(max_shear_ratio) + .set_default(0.0f) + .describe( + "Apply a shear transformation (namely ``(x,y)->(x+my,y)``) " + "with ``m`` randomly chose from " + "``[-max_shear_ratio, max_shear_ratio]``"); + DMLC_DECLARE_FIELD(max_crop_size) + .set_default(-1) + .describe( + "Crop both width and height into a random size in " + "``[min_crop_size, max_crop_size].``" + "Ignored if ``random_resized_crop`` is True."); + DMLC_DECLARE_FIELD(min_crop_size) + .set_default(-1) + .describe( + "Crop both width and height into a random size in " + "``[min_crop_size, max_crop_size].``" + "Ignored if ``random_resized_crop`` is True."); + DMLC_DECLARE_FIELD(max_random_scale) + .set_default(1.0f) + .describe( + "Resize into ``[width*s, height*s]`` with ``s`` randomly" + " chosen from ``[min_random_scale, max_random_scale]``. " + "Ignored if ``random_resized_crop`` is True."); + DMLC_DECLARE_FIELD(min_random_scale) + .set_default(1.0f) + .describe( + "Resize into ``[width*s, height*s]`` with ``s`` randomly" + " chosen from ``[min_random_scale, max_random_scale]``" + "Ignored if ``random_resized_crop`` is True."); + DMLC_DECLARE_FIELD(max_random_area) + .set_default(1.0f) + .describe( + "Change the area (namely width * height) to a random value " + "in ``[min_random_area, max_random_area]``. " + "Ignored if ``random_resized_crop`` is False."); + DMLC_DECLARE_FIELD(min_random_area) + .set_default(1.0f) + .describe( + "Change the area (namely width * height) to a random value " + "in ``[min_random_area, max_random_area]``. " + "Ignored if ``random_resized_crop`` is False."); + DMLC_DECLARE_FIELD(max_img_size) + .set_default(1e10f) + .describe( + "Set the maximal width and height after all resize and" + " rotate argumentation are applied"); + DMLC_DECLARE_FIELD(min_img_size) + .set_default(0.0f) + .describe( + "Set the minimal width and height after all resize and" + " rotate argumentation are applied"); + DMLC_DECLARE_FIELD(brightness) + .set_default(0.0f) + .describe( + "Add a random value in ``[-brightness, brightness]`` to " + "the brightness of image."); + DMLC_DECLARE_FIELD(contrast).set_default(0.0f).describe( + "Add a random value in ``[-contrast, contrast]`` to " + "the contrast of image."); + DMLC_DECLARE_FIELD(saturation) + .set_default(0.0f) + .describe( + "Add a random value in ``[-saturation, saturation]`` to " + "the saturation of image."); + DMLC_DECLARE_FIELD(pca_noise).set_default(0.0f).describe("Add PCA based noise to the image."); + DMLC_DECLARE_FIELD(random_h).set_default(0).describe( + "Add a random value in ``[-random_h, random_h]`` to " + "the H channel in HSL color space."); + DMLC_DECLARE_FIELD(random_s).set_default(0).describe( + "Add a random value in ``[-random_s, random_s]`` to " + "the S channel in HSL color space."); + DMLC_DECLARE_FIELD(random_l).set_default(0).describe( + "Add a random value in ``[-random_l, random_l]`` to " + "the L channel in HSL color space."); + DMLC_DECLARE_FIELD(rotate).set_default(-1.0f).describe( + "Rotate by an angle. 
If set, it overwrites the ``max_rotate_angle`` option."); + DMLC_DECLARE_FIELD(fill_value) + .set_default(255) .describe("Set the padding pixels value to ``fill_value``."); DMLC_DECLARE_FIELD(data_shape) - .set_expect_ndim(3).enforce_nonzero() + .set_expect_ndim(3) + .enforce_nonzero() .describe("The shape of a output image."); - DMLC_DECLARE_FIELD(inter_method).set_default(1) - .describe("The interpolation method: 0-NN 1-bilinear 2-cubic 3-area " - "4-lanczos4 9-auto 10-rand."); - DMLC_DECLARE_FIELD(pad).set_default(0) - .describe("Change size from ``[width, height]`` into " - "``[pad + width + pad, pad + height + pad]`` by padding pixes"); + DMLC_DECLARE_FIELD(inter_method) + .set_default(1) + .describe( + "The interpolation method: 0-NN 1-bilinear 2-cubic 3-area " + "4-lanczos4 9-auto 10-rand."); + DMLC_DECLARE_FIELD(pad).set_default(0).describe( + "Change size from ``[width, height]`` into " + "``[pad + width + pad, pad + height + pad]`` by padding pixes"); } }; @@ -210,28 +241,33 @@ class DefaultImageAugmenter : public ImageAugmenter { std::vector > kwargs_left; kwargs_left = param_.InitAllowUnknown(kwargs); for (auto& kwarg : kwargs_left) { - if (!strcmp(kwarg.first.c_str(), "rotate_list")) { - const char* val = kwarg.second.c_str(); - const char *end = val + strlen(val); - char buf[128]; - while (val < end) { - sscanf(val, "%[^,]", buf); - val += strlen(buf) + 1; - rotate_list_.push_back(atoi(buf)); - } + if (!strcmp(kwarg.first.c_str(), "rotate_list")) { + const char* val = kwarg.second.c_str(); + const char* end = val + strlen(val); + char buf[128]; + while (val < end) { + sscanf(val, "%[^,]", buf); + val += strlen(buf) + 1; + rotate_list_.push_back(atoi(buf)); } + } } } /*! - * \brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC - * \ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND + * \brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR + * 2-CV_INTER_CUBIC \ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for + * shrink, bilinear for others) 10-RAND */ - int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, - int new_height, common::RANDOM_ENGINE *prnd) { + int GetInterMethod(int inter_method, + int old_width, + int old_height, + int new_width, + int new_height, + common::RANDOM_ENGINE* prnd) { if (inter_method == 9) { if (new_width > old_width && new_height > old_height) { return 2; // CV_INTER_CUBIC for enlarge - } else if (new_width *label, - common::RANDOM_ENGINE *prnd) override { + cv::Mat Process(const cv::Mat& src, + std::vector* label, + common::RANDOM_ENGINE* prnd) override { using mshadow::index_t; bool is_cropped = false; @@ -262,36 +299,36 @@ class DefaultImageAugmenter : public ImageAugmenter { if (param_.resize != -1) { int new_height, new_width; if (src.rows > src.cols) { - new_height = param_.resize*src.rows/src.cols; - new_width = param_.resize; + new_height = param_.resize * src.rows / src.cols; + new_width = param_.resize; } else { new_height = param_.resize; - new_width = param_.resize*src.cols/src.rows; + new_width = param_.resize * src.cols / src.rows; } CHECK((param_.inter_method >= 0 && param_.inter_method <= 4) || - (param_.inter_method >= 9 && param_.inter_method <= 10)) - << "invalid inter_method: valid value 0,1,2,3,4,9,10"; - int interpolation_method = GetInterMethod(param_.inter_method, - src.cols, src.rows, new_width, new_height, prnd); - 
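GetInterMethod above maps the inter_method parameter onto a concrete OpenCV interpolation id: 0-4 are taken literally, 9 selects automatically (cubic when enlarging, area when shrinking, bilinear otherwise) and 10 draws a random method. The shrink branch is partly cut off in the hunk above, so treat the exact constants below as an assumption; this is a standalone sketch of that dispatch with the OpenCV constants written as plain integers, as in the original.

#include <iostream>
#include <random>

// 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4; 9 = auto, 10 = random.
int ChooseInterMethod(int inter_method,
                      int old_width, int old_height,
                      int new_width, int new_height,
                      std::mt19937* prnd) {
  if (inter_method == 9) {
    if (new_width > old_width && new_height > old_height) return 2;  // cubic for enlarging
    if (new_width < old_width && new_height < old_height) return 3;  // area for shrinking
    return 1;                                                        // bilinear otherwise
  }
  if (inter_method == 10) {
    std::uniform_int_distribution<int> pick(0, 4);
    return pick(*prnd);
  }
  return inter_method;  // 0..4 are used as-is
}

int main() {
  std::mt19937 rng(0);
  std::cout << ChooseInterMethod(9, 640, 480, 224, 224, &rng) << "\n";  // 3 (area, shrinking)
  std::cout << ChooseInterMethod(9, 100, 100, 224, 224, &rng) << "\n";  // 2 (cubic, enlarging)
}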
cv::resize(src, res, cv::Size(new_width, new_height), - 0, 0, interpolation_method); + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,4,9,10"; + int interpolation_method = + GetInterMethod(param_.inter_method, src.cols, src.rows, new_width, new_height, prnd); + cv::resize(src, res, cv::Size(new_width, new_height), 0, 0, interpolation_method); } else { res = src; } // normal augmentation by affine transformation. - if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f - || param_.rotate > 0 || rotate_list_.size() > 0 - || param_.max_random_scale != 1.0f || param_.min_random_scale != 1.0 - || (!param_.random_resized_crop && (min_aspect_ratio != 1.0f || max_aspect_ratio != 1.0f)) - || param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) { + if (param_.max_rotate_angle > 0 || param_.max_shear_ratio > 0.0f || param_.rotate > 0 || + rotate_list_.size() > 0 || param_.max_random_scale != 1.0f || + param_.min_random_scale != 1.0 || + (!param_.random_resized_crop && (min_aspect_ratio != 1.0f || max_aspect_ratio != 1.0f)) || + param_.max_img_size != 1e10f || param_.min_img_size != 0.0f) { std::uniform_real_distribution rand_uniform(0, 1); // shear float s = rand_uniform(*prnd) * param_.max_shear_ratio * 2 - param_.max_shear_ratio; // rotate - int angle = std::uniform_int_distribution( - -param_.max_rotate_angle, param_.max_rotate_angle)(*prnd); - if (param_.rotate > 0) angle = param_.rotate; + int angle = std::uniform_int_distribution(-param_.max_rotate_angle, + param_.max_rotate_angle)(*prnd); + if (param_.rotate > 0) + angle = param_.rotate; if (rotate_list_.size() > 0) { angle = rotate_list_[std::uniform_int_distribution(0, rotate_list_.size() - 1)(*prnd)]; } @@ -300,37 +337,39 @@ class DefaultImageAugmenter : public ImageAugmenter { // scale float scale = 1.0f; if (!param_.random_resized_crop) { - scale = rand_uniform(*prnd) * - (param_.max_random_scale - param_.min_random_scale) + param_.min_random_scale; + scale = rand_uniform(*prnd) * (param_.max_random_scale - param_.min_random_scale) + + param_.min_random_scale; } // aspect ratio float ratio = 1.0f; if (!param_.random_resized_crop) { - ratio = rand_uniform(*prnd) * - (max_aspect_ratio - min_aspect_ratio) + min_aspect_ratio; + ratio = rand_uniform(*prnd) * (max_aspect_ratio - min_aspect_ratio) + min_aspect_ratio; } float hs = 2 * scale / (1 + ratio); float ws = ratio * hs; // new width and height - float new_width = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * res.cols)); - float new_height = std::max(param_.min_img_size, - std::min(param_.max_img_size, scale * res.rows)); + float new_width = + std::max(param_.min_img_size, std::min(param_.max_img_size, scale * res.cols)); + float new_height = + std::max(param_.min_img_size, std::min(param_.max_img_size, scale * res.rows)); cv::Mat M(2, 3, CV_32F); - M.at(0, 0) = hs * a - s * b * ws; - M.at(1, 0) = -b * ws; - M.at(0, 1) = hs * b + s * a * ws; - M.at(1, 1) = a * ws; - float ori_center_width = M.at(0, 0) * res.cols + M.at(0, 1) * res.rows; + M.at(0, 0) = hs * a - s * b * ws; + M.at(1, 0) = -b * ws; + M.at(0, 1) = hs * b + s * a * ws; + M.at(1, 1) = a * ws; + float ori_center_width = M.at(0, 0) * res.cols + M.at(0, 1) * res.rows; float ori_center_height = M.at(1, 0) * res.cols + M.at(1, 1) * res.rows; - M.at(0, 2) = (new_width - ori_center_width) / 2; - M.at(1, 2) = (new_height - ori_center_height) / 2; + M.at(0, 2) = (new_width - ori_center_width) / 2; + M.at(1, 2) = (new_height - 
ori_center_height) / 2; CHECK((param_.inter_method >= 0 && param_.inter_method <= 4) || - (param_.inter_method >= 9 && param_.inter_method <= 10)) - << "invalid inter_method: valid value 0,1,2,3,4,9,10"; - int interpolation_method = GetInterMethod(param_.inter_method, - res.cols, res.rows, new_width, new_height, prnd); - cv::warpAffine(res, temp_, M, cv::Size(new_width, new_height), + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,4,9,10"; + int interpolation_method = + GetInterMethod(param_.inter_method, res.cols, res.rows, new_width, new_height, prnd); + cv::warpAffine(res, + temp_, + M, + cv::Size(new_width, new_height), interpolation_method, cv::BORDER_CONSTANT, cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); @@ -339,66 +378,75 @@ class DefaultImageAugmenter : public ImageAugmenter { // pad logic if (param_.pad > 0) { - cv::copyMakeBorder(res, res, param_.pad, param_.pad, param_.pad, param_.pad, + cv::copyMakeBorder(res, + res, + param_.pad, + param_.pad, + param_.pad, + param_.pad, cv::BORDER_CONSTANT, cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); } if (param_.random_resized_crop) { // random resize crop - CHECK(param_.min_random_scale == 1.0f && - param_.max_random_scale == 1.0f && - param_.min_crop_size == -1 && - param_.max_crop_size == -1 && - !param_.rand_crop) << - "\nSetting random_resized_crop to true conflicts with " - "min_random_scale, max_random_scale, " - "min_crop_size, max_crop_size, " - "and rand_crop."; + CHECK(param_.min_random_scale == 1.0f && param_.max_random_scale == 1.0f && + param_.min_crop_size == -1 && param_.max_crop_size == -1 && !param_.rand_crop) + << "\nSetting random_resized_crop to true conflicts with " + "min_random_scale, max_random_scale, " + "min_crop_size, max_crop_size, " + "and rand_crop."; - if (param_.max_random_area != 1.0f || param_.min_random_area != 1.0f - || max_aspect_ratio != 1.0f || min_aspect_ratio != 1.0f) { - CHECK(min_aspect_ratio > 0.0f); - CHECK(param_.min_random_area <= param_.max_random_area); - CHECK(min_aspect_ratio <= max_aspect_ratio); - std::uniform_real_distribution rand_uniform_area(param_.min_random_area, - param_.max_random_area); - std::uniform_real_distribution rand_uniform_ratio(min_aspect_ratio, - max_aspect_ratio); - std::uniform_real_distribution rand_uniform(0, 1); - float area = res.rows * res.cols; - for (int i = 0; i < 10; ++i) { - float rand_area = rand_uniform_area(*prnd); - float ratio = rand_uniform_ratio(*prnd); - float target_area = area * rand_area; - int y_area = std::round(std::sqrt(target_area / ratio)); - int x_area = std::round(std::sqrt(target_area * ratio)); - if (rand_uniform(*prnd) > 0.5) { - float temp_y_area = y_area; - y_area = x_area; - x_area = temp_y_area; - } - if (y_area <= res.rows && x_area <= res.cols) { - index_t rand_y_area = - std::uniform_int_distribution(0, res.rows - y_area)(*prnd); - index_t rand_x_area = - std::uniform_int_distribution(0, res.cols - x_area)(*prnd); - cv::Rect roi(rand_x_area, rand_y_area, x_area, y_area); - int interpolation_method = GetInterMethod(param_.inter_method, x_area, y_area, - param_.data_shape[2], - param_.data_shape[1], prnd); - cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]), - 0, 0, interpolation_method); - is_cropped = true; - break; - } - } + if (param_.max_random_area != 1.0f || param_.min_random_area != 1.0f || + max_aspect_ratio != 1.0f || min_aspect_ratio != 1.0f) { + CHECK(min_aspect_ratio > 0.0f); + 
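The affine branch above folds the sampled scale and aspect ratio into per-axis factors hs = 2*scale/(1+ratio) and ws = ratio*hs, so that ws/hs equals the sampled ratio and the average of the two factors equals the sampled scale, before they are combined with the rotation and shear terms into the 2x3 matrix passed to warpAffine. A tiny numeric check of that relationship; the sample values are arbitrary.

#include <cassert>
#include <cmath>
#include <iostream>

int main() {
  // Sampled overall scale and aspect ratio, as drawn in the augmenter above.
  float scale = 0.8f, ratio = 1.25f;
  float hs = 2 * scale / (1 + ratio);  // vertical scale factor
  float ws = ratio * hs;               // horizontal scale factor
  assert(std::fabs(ws / hs - ratio) < 1e-6f);        // aspect ratio is preserved
  assert(std::fabs((ws + hs) / 2 - scale) < 1e-6f);  // mean of the factors == scale
  std::cout << "hs=" << hs << " ws=" << ws << "\n";
}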
CHECK(param_.min_random_area <= param_.max_random_area); + CHECK(min_aspect_ratio <= max_aspect_ratio); + std::uniform_real_distribution rand_uniform_area(param_.min_random_area, + param_.max_random_area); + std::uniform_real_distribution rand_uniform_ratio(min_aspect_ratio, + max_aspect_ratio); + std::uniform_real_distribution rand_uniform(0, 1); + float area = res.rows * res.cols; + for (int i = 0; i < 10; ++i) { + float rand_area = rand_uniform_area(*prnd); + float ratio = rand_uniform_ratio(*prnd); + float target_area = area * rand_area; + int y_area = std::round(std::sqrt(target_area / ratio)); + int x_area = std::round(std::sqrt(target_area * ratio)); + if (rand_uniform(*prnd) > 0.5) { + float temp_y_area = y_area; + y_area = x_area; + x_area = temp_y_area; + } + if (y_area <= res.rows && x_area <= res.cols) { + index_t rand_y_area = + std::uniform_int_distribution(0, res.rows - y_area)(*prnd); + index_t rand_x_area = + std::uniform_int_distribution(0, res.cols - x_area)(*prnd); + cv::Rect roi(rand_x_area, rand_y_area, x_area, y_area); + int interpolation_method = GetInterMethod(param_.inter_method, + x_area, + y_area, + param_.data_shape[2], + param_.data_shape[1], + prnd); + cv::resize(res(roi), + res, + cv::Size(param_.data_shape[2], param_.data_shape[1]), + 0, + 0, + interpolation_method); + is_cropped = true; + break; + } + } } } else if (!param_.random_resized_crop && - (param_.max_crop_size != -1 || param_.min_crop_size != -1)) { + (param_.max_crop_size != -1 || param_.min_crop_size != -1)) { // random_crop - CHECK(res.cols >= param_.max_crop_size && res.rows >= \ - param_.max_crop_size && param_.max_crop_size >= param_.min_crop_size) + CHECK(res.cols >= param_.max_crop_size && res.rows >= param_.max_crop_size && + param_.max_crop_size >= param_.min_crop_size) << "input image size smaller than max_crop_size"; index_t rand_crop_size = std::uniform_int_distribution(param_.min_crop_size, param_.max_crop_size)(*prnd); @@ -408,37 +456,47 @@ class DefaultImageAugmenter : public ImageAugmenter { y = std::uniform_int_distribution(0, y)(*prnd); x = std::uniform_int_distribution(0, x)(*prnd); } else { - y /= 2; x /= 2; + y /= 2; + x /= 2; } cv::Rect roi(x, y, rand_crop_size, rand_crop_size); - int interpolation_method = GetInterMethod(param_.inter_method, rand_crop_size, rand_crop_size, - param_.data_shape[2], param_.data_shape[1], prnd); - cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]) - , 0, 0, interpolation_method); + int interpolation_method = GetInterMethod(param_.inter_method, + rand_crop_size, + rand_crop_size, + param_.data_shape[2], + param_.data_shape[1], + prnd); + cv::resize(res(roi), + res, + cv::Size(param_.data_shape[2], param_.data_shape[1]), + 0, + 0, + interpolation_method); is_cropped = true; } if (!is_cropped) { // center crop - int interpolation_method = GetInterMethod(param_.inter_method, res.cols, res.rows, + int interpolation_method = GetInterMethod(param_.inter_method, + res.cols, + res.rows, param_.data_shape[2], - param_.data_shape[1], prnd); + param_.data_shape[1], + prnd); if (res.rows < param_.data_shape[1]) { - index_t new_cols = static_cast(static_cast(param_.data_shape[1]) / - static_cast(res.rows) * - static_cast(res.cols)); - cv::resize(res, res, cv::Size(new_cols, param_.data_shape[1]), - 0, 0, interpolation_method); + index_t new_cols = + static_cast(static_cast(param_.data_shape[1]) / + static_cast(res.rows) * static_cast(res.cols)); + cv::resize(res, res, cv::Size(new_cols, param_.data_shape[1]), 0, 0, 
interpolation_method); } if (res.cols < param_.data_shape[2]) { - index_t new_rows = static_cast(static_cast(param_.data_shape[2]) / - static_cast(res.cols) * - static_cast(res.rows)); - cv::resize(res, res, cv::Size(param_.data_shape[2], new_rows), - 0, 0, interpolation_method); + index_t new_rows = + static_cast(static_cast(param_.data_shape[2]) / + static_cast(res.cols) * static_cast(res.rows)); + cv::resize(res, res, cv::Size(param_.data_shape[2], new_rows), 0, 0, interpolation_method); } - CHECK(static_cast(res.rows) >= param_.data_shape[1] - && static_cast(res.cols) >= param_.data_shape[2]) + CHECK(static_cast(res.rows) >= param_.data_shape[1] && + static_cast(res.cols) >= param_.data_shape[2]) << "input image size smaller than input shape"; index_t y = res.rows - param_.data_shape[1]; index_t x = res.cols - param_.data_shape[2]; @@ -446,7 +504,8 @@ class DefaultImageAugmenter : public ImageAugmenter { y = std::uniform_int_distribution(0, y)(*prnd); x = std::uniform_int_distribution(0, x)(*prnd); } else { - y /= 2; x /= 2; + y /= 2; + x /= 2; } cv::Rect roi(x, y, param_.data_shape[2], param_.data_shape[1]); res = res(roi); @@ -455,12 +514,12 @@ class DefaultImageAugmenter : public ImageAugmenter { // color jitter if (param_.brightness > 0.0f || param_.contrast > 0.0f || param_.saturation > 0.0f) { std::uniform_real_distribution rand_uniform(0, 1); - float alpha_b = 1.0 + std::uniform_real_distribution(-param_.brightness, - param_.brightness)(*prnd); - float alpha_c = 1.0 + std::uniform_real_distribution(-param_.contrast, - param_.contrast)(*prnd); - float alpha_s = 1.0 + std::uniform_real_distribution(-param_.saturation, - param_.saturation)(*prnd); + float alpha_b = + 1.0 + std::uniform_real_distribution(-param_.brightness, param_.brightness)(*prnd); + float alpha_c = + 1.0 + std::uniform_real_distribution(-param_.contrast, param_.contrast)(*prnd); + float alpha_s = + 1.0 + std::uniform_real_distribution(-param_.saturation, param_.saturation)(*prnd); int rand_order[3] = {0, 1, 2}; std::shuffle(std::begin(rand_order), std::end(rand_order), *prnd); for (int i : rand_order) { @@ -488,20 +547,26 @@ class DefaultImageAugmenter : public ImageAugmenter { std::uniform_real_distribution rand_uniform(0, 1); cvtColor(res, res, CV_BGR2HLS); // use an approximation of gaussian distribution to reduce extreme value - float rh = rand_uniform(*prnd); rh += 4 * rand_uniform(*prnd); rh = rh / 5; - float rs = rand_uniform(*prnd); rs += 4 * rand_uniform(*prnd); rs = rs / 5; - float rl = rand_uniform(*prnd); rl += 4 * rand_uniform(*prnd); rl = rl / 5; - int h = rh * param_.random_h * 2 - param_.random_h; - int s = rs * param_.random_s * 2 - param_.random_s; - int l = rl * param_.random_l * 2 - param_.random_l; - int temp[3] = {h, l, s}; + float rh = rand_uniform(*prnd); + rh += 4 * rand_uniform(*prnd); + rh = rh / 5; + float rs = rand_uniform(*prnd); + rs += 4 * rand_uniform(*prnd); + rs = rs / 5; + float rl = rand_uniform(*prnd); + rl += 4 * rand_uniform(*prnd); + rl = rl / 5; + int h = rh * param_.random_h * 2 - param_.random_h; + int s = rs * param_.random_s * 2 - param_.random_s; + int l = rl * param_.random_l * 2 - param_.random_l; + int temp[3] = {h, l, s}; int limit[3] = {180, 255, 255}; for (int i = 0; i < res.rows; ++i) { for (int j = 0; j < res.cols; ++j) { for (int k = 0; k < 3; ++k) { int v = res.at(i, j)[k]; v += temp[k]; - v = std::max(0, std::min(limit[k], v)); + v = std::max(0, std::min(limit[k], v)); res.at(i, j)[k] = v; } } @@ -515,19 +580,19 @@ class DefaultImageAugmenter : public 
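The random-resized-crop branch earlier in this hunk samples a target area fraction and an aspect ratio, converts them to a candidate crop width and height, retries up to 10 times until the candidate fits inside the image, and only then crops and resizes. A stripped-down sketch of that sampling loop without the OpenCV crop itself; the image size and parameter ranges in main are made up.

#include <cmath>
#include <iostream>
#include <random>
#include <utility>

struct CropBox { int x, y, w, h; bool ok; };

// Sample a crop whose area fraction lies in [min_area, max_area] and whose
// aspect ratio lies in [min_ratio, max_ratio]; give up after 10 trials.
CropBox SampleResizedCrop(int rows, int cols,
                          float min_area, float max_area,
                          float min_ratio, float max_ratio,
                          std::mt19937* prnd) {
  std::uniform_real_distribution<float> rand_area(min_area, max_area);
  std::uniform_real_distribution<float> rand_ratio(min_ratio, max_ratio);
  std::uniform_real_distribution<float> rand_unit(0.f, 1.f);
  float area = static_cast<float>(rows) * cols;
  for (int i = 0; i < 10; ++i) {
    float target_area = area * rand_area(*prnd);
    float ratio = rand_ratio(*prnd);
    int h = static_cast<int>(std::round(std::sqrt(target_area / ratio)));
    int w = static_cast<int>(std::round(std::sqrt(target_area * ratio)));
    if (rand_unit(*prnd) > 0.5f) std::swap(w, h);  // half the time, swap the axes
    if (h <= rows && w <= cols) {
      int y = std::uniform_int_distribution<int>(0, rows - h)(*prnd);
      int x = std::uniform_int_distribution<int>(0, cols - w)(*prnd);
      return {x, y, w, h, true};
    }
  }
  return {0, 0, cols, rows, false};  // caller falls back to another crop path
}

int main() {
  std::mt19937 rng(42);
  CropBox box = SampleResizedCrop(480, 640, 0.08f, 1.0f, 3.f / 4.f, 4.f / 3.f, &rng);
  std::cout << box.x << "," << box.y << " " << box.w << "x" << box.h
            << (box.ok ? "" : " (fallback)") << "\n";
}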
ImageAugmenter { float pca_alpha_r = rand_normal(*prnd); float pca_alpha_g = rand_normal(*prnd); float pca_alpha_b = rand_normal(*prnd); - float pca_r = eigvec[0][0] * pca_alpha_r + eigvec[0][1] * pca_alpha_g + - eigvec[0][2] * pca_alpha_b; - float pca_g = eigvec[1][0] * pca_alpha_r + eigvec[1][1] * pca_alpha_g + - eigvec[1][2] * pca_alpha_b; - float pca_b = eigvec[2][0] * pca_alpha_r + eigvec[2][1] * pca_alpha_g + - eigvec[2][2] * pca_alpha_b; - float pca[3] = { pca_b, pca_g, pca_r }; + float pca_r = + eigvec[0][0] * pca_alpha_r + eigvec[0][1] * pca_alpha_g + eigvec[0][2] * pca_alpha_b; + float pca_g = + eigvec[1][0] * pca_alpha_r + eigvec[1][1] * pca_alpha_g + eigvec[1][2] * pca_alpha_b; + float pca_b = + eigvec[2][0] * pca_alpha_r + eigvec[2][1] * pca_alpha_g + eigvec[2][2] * pca_alpha_b; + float pca[3] = {pca_b, pca_g, pca_r}; for (int i = 0; i < res.rows; ++i) { for (int j = 0; j < res.cols; ++j) { for (int k = 0; k < 3; ++k) { int vp = res.at(i, j)[k]; vp += pca[k]; - vp = std::max(0, std::min(255, vp)); + vp = std::max(0, std::min(255, vp)); res.at(i, j)[k] = vp; } } @@ -536,15 +601,14 @@ class DefaultImageAugmenter : public ImageAugmenter { return res; } - private: // temporal space cv::Mat temp_; // eigval and eigvec for adding pca noise // store eigval * eigvec as eigvec - float eigvec[3][3] = { { 55.46f * -0.5675f, 4.794f * 0.7192f, 1.148f * 0.4009f }, - { 55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f }, - { 55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f } }; + float eigvec[3][3] = {{55.46f * -0.5675f, 4.794f * 0.7192f, 1.148f * 0.4009f}, + {55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f}, + {55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f}}; // parameters DefaultImageAugmentParam param_; /*! \brief list of possible rotate angle */ @@ -555,11 +619,9 @@ ImageAugmenter* ImageAugmenter::Create(const std::string& name) { return dmlc::Registry::Find(name)->body(); } -MXNET_REGISTER_IMAGE_AUGMENTER(aug_default) -.describe("default augmenter") -.set_body([]() { - return new DefaultImageAugmenter(); - }); +MXNET_REGISTER_IMAGE_AUGMENTER(aug_default).describe("default augmenter").set_body([]() { + return new DefaultImageAugmenter(); +}); #endif // MXNET_USE_OPENCV } // namespace io } // namespace mxnet diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index e8a56ba2e5b7..44c80abcbaee 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -29,9 +29,9 @@ #if MXNET_USE_OPENCV #include -#include // NOLINT(*) -#include // NOLINT(*) -#include // NOLINT(*) +#include // NOLINT(*) +#include // NOLINT(*) +#include // NOLINT(*) #include "../common/utils.h" @@ -57,8 +57,9 @@ class ImageAugmenter { * \param prnd pointer to random number generator. * \return The processed image. */ - virtual cv::Mat Process(const cv::Mat &src, std::vector *label, - common::RANDOM_ENGINE *prnd) = 0; + virtual cv::Mat Process(const cv::Mat& src, + std::vector* label, + common::RANDOM_ENGINE* prnd) = 0; // virtual destructor virtual ~ImageAugmenter() {} /*! @@ -70,14 +71,12 @@ class ImageAugmenter { }; /*! \brief typedef the factory function of data iterator */ -typedef std::function ImageAugmenterFactory; +typedef std::function ImageAugmenterFactory; /*! * \brief Registry entry for DataIterator factory functions. 
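The PCA ("lighting") noise above draws three Gaussian coefficients, projects them through a fixed eigenvalue-scaled eigenvector matrix of the RGB covariance, and adds the resulting per-channel offset to every pixel with clamping. A compact sketch of the per-image offset computation; the matrix reuses the constants shown above, and the Gaussian scale is assumed to be the pca_noise parameter.

#include <algorithm>
#include <iostream>
#include <random>
#include <utility>

int main() {
  std::mt19937 rng(0);
  float pca_noise = 0.1f;  // assumed std-dev of the sampled coefficients
  std::normal_distribution<float> rand_normal(0.f, pca_noise);

  // eigval * eigvec of the RGB covariance, copied from the augmenter above.
  float eigvec[3][3] = {{55.46f * -0.5675f, 4.794f * 0.7192f, 1.148f * 0.4009f},
                        {55.46f * -0.5808f, 4.794f * -0.0045f, 1.148f * -0.8140f},
                        {55.46f * -0.5836f, 4.794f * -0.6948f, 1.148f * 0.4203f}};

  float alpha[3] = {rand_normal(rng), rand_normal(rng), rand_normal(rng)};
  float offset[3];  // per-channel additive offset, rows of eigvec are R, G, B
  for (int c = 0; c < 3; ++c)
    offset[c] = eigvec[c][0] * alpha[0] + eigvec[c][1] * alpha[1] + eigvec[c][2] * alpha[2];
  std::swap(offset[0], offset[2]);  // reorder to B, G, R as applied to the image above

  // Applying it to one example pixel value, clamped to [0, 255].
  int pixel = 200;
  int noisy = std::max(0, std::min(255, static_cast<int>(pixel + offset[0])));
  std::cout << "offset_b=" << offset[0] << " noisy pixel=" << noisy << "\n";
}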
*/ struct ImageAugmenterReg - : public dmlc::FunctionRegEntryBase { -}; + : public dmlc::FunctionRegEntryBase {}; //-------------------------------------------------------------- // The following part are API Registration of Iterators //-------------------------------------------------------------- @@ -93,7 +92,7 @@ struct ImageAugmenterReg * }); * \endcode */ -#define MXNET_REGISTER_IMAGE_AUGMENTER(name) \ +#define MXNET_REGISTER_IMAGE_AUGMENTER(name) \ DMLC_REGISTRY_REGISTER(::mxnet::io::ImageAugmenterReg, ImageAugmenterReg, name) } // namespace io } // namespace mxnet diff --git a/src/io/image_det_aug_default.cc b/src/io/image_det_aug_default.cc index 6b3109fbce19..7f9a5e570803 100644 --- a/src/io/image_det_aug_default.cc +++ b/src/io/image_det_aug_default.cc @@ -37,9 +37,9 @@ namespace io { using mxnet::Tuple; namespace image_det_aug_default_enum { -enum ImageDetAugDefaultCropEmitMode {kCenter, kOverlap}; -enum ImageDetAugDefaultResizeMode {kForce, kShrink, kFit}; -} +enum ImageDetAugDefaultCropEmitMode { kCenter, kOverlap }; +enum ImageDetAugDefaultResizeMode { kForce, kShrink, kFit }; +} // namespace image_det_aug_default_enum /*! \brief image detection augmentation parameters*/ struct DefaultImageDetAugmentParam : public dmlc::Parameter { @@ -112,87 +112,122 @@ struct DefaultImageDetAugmentParam : public dmlc::Parameter({0.0f})) + DMLC_DECLARE_FIELD(min_crop_scales) + .set_default(Tuple({0.0f})) .describe("Augmentation Param: Min crop scales."); - DMLC_DECLARE_FIELD(max_crop_scales).set_default(Tuple({1.0f})) + DMLC_DECLARE_FIELD(max_crop_scales) + .set_default(Tuple({1.0f})) .describe("Augmentation Param: Max crop scales."); - DMLC_DECLARE_FIELD(min_crop_aspect_ratios).set_default(Tuple({1.0f})) + DMLC_DECLARE_FIELD(min_crop_aspect_ratios) + .set_default(Tuple({1.0f})) .describe("Augmentation Param: Min crop aspect ratios."); - DMLC_DECLARE_FIELD(max_crop_aspect_ratios).set_default(Tuple({1.0f})) + DMLC_DECLARE_FIELD(max_crop_aspect_ratios) + .set_default(Tuple({1.0f})) .describe("Augmentation Param: Max crop aspect ratios."); - DMLC_DECLARE_FIELD(min_crop_overlaps).set_default(Tuple({0.0f})) + DMLC_DECLARE_FIELD(min_crop_overlaps) + .set_default(Tuple({0.0f})) .describe("Augmentation Param: Minimum crop IOU between crop_box and ground-truths."); - DMLC_DECLARE_FIELD(max_crop_overlaps).set_default(Tuple({1.0f})) + DMLC_DECLARE_FIELD(max_crop_overlaps) + .set_default(Tuple({1.0f})) .describe("Augmentation Param: Maximum crop IOU between crop_box and ground-truth."); - DMLC_DECLARE_FIELD(min_crop_sample_coverages).set_default(Tuple({0.0f})) - .describe("Augmentation Param: Minimum ratio of intersect/crop_area " - "between crop box and ground-truths."); - DMLC_DECLARE_FIELD(max_crop_sample_coverages).set_default(Tuple({1.0f})) - .describe("Augmentation Param: Maximum ratio of intersect/crop_area " - "between crop box and ground-truths."); - DMLC_DECLARE_FIELD(min_crop_object_coverages).set_default(Tuple({0.0f})) - .describe("Augmentation Param: Minimum ratio of intersect/gt_area " - "between crop box and ground-truths."); - DMLC_DECLARE_FIELD(max_crop_object_coverages).set_default(Tuple({1.0f})) - .describe("Augmentation Param: Maximum ratio of intersect/gt_area " - "between crop box and ground-truths."); - DMLC_DECLARE_FIELD(num_crop_sampler).set_default(1) + DMLC_DECLARE_FIELD(min_crop_sample_coverages) + .set_default(Tuple({0.0f})) + .describe( + "Augmentation Param: Minimum ratio of intersect/crop_area " + "between crop box and ground-truths."); + 
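MXNET_REGISTER_IMAGE_AUGMENTER above expands to a dmlc registry entry mapping an augmenter name to a factory functor, which is how ImageAugmenter::Create(name) locates the right constructor at runtime. The following is a generic sketch of that name-to-factory pattern only, not dmlc-core's actual implementation; the class and function names are placeholders.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Augmenter {
  virtual ~Augmenter() = default;
  virtual const char* Name() const = 0;
};

// A simple name -> factory map standing in for the dmlc registry singleton.
using Factory = std::function<std::unique_ptr<Augmenter>()>;
std::map<std::string, Factory>& Registry() {
  static std::map<std::string, Factory> reg;
  return reg;
}

struct DefaultAug : Augmenter {
  const char* Name() const override { return "aug_default"; }
};

// Register at static-initialization time, as the macro's expansion effectively does.
const bool registered = [] {
  Registry()["aug_default"] = [] { return std::unique_ptr<Augmenter>(new DefaultAug()); };
  return true;
}();

std::unique_ptr<Augmenter> Create(const std::string& name) {
  return Registry().at(name)();  // throws if the name was never registered
}

int main() {
  std::cout << Create("aug_default")->Name() << "\n";
}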
DMLC_DECLARE_FIELD(max_crop_sample_coverages) + .set_default(Tuple({1.0f})) + .describe( + "Augmentation Param: Maximum ratio of intersect/crop_area " + "between crop box and ground-truths."); + DMLC_DECLARE_FIELD(min_crop_object_coverages) + .set_default(Tuple({0.0f})) + .describe( + "Augmentation Param: Minimum ratio of intersect/gt_area " + "between crop box and ground-truths."); + DMLC_DECLARE_FIELD(max_crop_object_coverages) + .set_default(Tuple({1.0f})) + .describe( + "Augmentation Param: Maximum ratio of intersect/gt_area " + "between crop box and ground-truths."); + DMLC_DECLARE_FIELD(num_crop_sampler) + .set_default(1) .describe("Augmentation Param: Number of crop samplers."); DMLC_DECLARE_FIELD(crop_emit_mode) .add_enum("center", image_det_aug_default_enum::kCenter) .add_enum("overlap", image_det_aug_default_enum::kOverlap) .set_default(image_det_aug_default_enum::kCenter) - .describe("Augmentation Param: Emition mode for invalid ground-truths after crop. " - "center: emit if centroid of object is out of crop region; " - "overlap: emit if overlap is less than emit_overlap_thresh. "); - DMLC_DECLARE_FIELD(emit_overlap_thresh).set_default(0.3f) + .describe( + "Augmentation Param: Emition mode for invalid ground-truths after crop. " + "center: emit if centroid of object is out of crop region; " + "overlap: emit if overlap is less than emit_overlap_thresh. "); + DMLC_DECLARE_FIELD(emit_overlap_thresh) + .set_default(0.3f) .describe("Augmentation Param: Emit overlap thresh for emit mode overlap only."); - DMLC_DECLARE_FIELD(max_crop_trials).set_default(Tuple({25})) - .describe("Augmentation Param: Skip cropping if fail crop trail count " - "exceeds this number."); - DMLC_DECLARE_FIELD(rand_pad_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(max_crop_trials) + .set_default(Tuple({25})) + .describe( + "Augmentation Param: Skip cropping if fail crop trail count " + "exceeds this number."); + DMLC_DECLARE_FIELD(rand_pad_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability for random padding."); - DMLC_DECLARE_FIELD(max_pad_scale).set_default(1.0f) + DMLC_DECLARE_FIELD(max_pad_scale) + .set_default(1.0f) .describe("Augmentation Param: Maximum padding scale."); - DMLC_DECLARE_FIELD(max_random_hue).set_default(0) + DMLC_DECLARE_FIELD(max_random_hue) + .set_default(0) .describe("Augmentation Param: Maximum random value of H channel in HSL color space."); - DMLC_DECLARE_FIELD(random_hue_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(random_hue_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability to apply random hue."); - DMLC_DECLARE_FIELD(max_random_saturation).set_default(0) + DMLC_DECLARE_FIELD(max_random_saturation) + .set_default(0) .describe("Augmentation Param: Maximum random value of S channel in HSL color space."); - DMLC_DECLARE_FIELD(random_saturation_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(random_saturation_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability to apply random saturation."); - DMLC_DECLARE_FIELD(max_random_illumination).set_default(0) + DMLC_DECLARE_FIELD(max_random_illumination) + .set_default(0) .describe("Augmentation Param: Maximum random value of L channel in HSL color space."); - DMLC_DECLARE_FIELD(random_illumination_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(random_illumination_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability to apply random illumination."); - DMLC_DECLARE_FIELD(max_random_contrast).set_default(0) + DMLC_DECLARE_FIELD(max_random_contrast) + .set_default(0) 
.describe("Augmentation Param: Maximum random value of delta contrast."); - DMLC_DECLARE_FIELD(random_contrast_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(random_contrast_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability to apply random contrast."); - DMLC_DECLARE_FIELD(rand_mirror_prob).set_default(0.0f) + DMLC_DECLARE_FIELD(rand_mirror_prob) + .set_default(0.0f) .describe("Augmentation Param: Probability to apply horizontal flip aka. mirror."); - DMLC_DECLARE_FIELD(fill_value).set_default(127) + DMLC_DECLARE_FIELD(fill_value) + .set_default(127) .describe("Augmentation Param: Filled color value while padding."); - DMLC_DECLARE_FIELD(inter_method).set_default(1) + DMLC_DECLARE_FIELD(inter_method) + .set_default(1) .describe("Augmentation Param: 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand."); DMLC_DECLARE_FIELD(data_shape) - .set_expect_ndim(3).enforce_nonzero() + .set_expect_ndim(3) + .enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); DMLC_DECLARE_FIELD(resize_mode) - .add_enum("force", image_det_aug_default_enum::kForce) - .add_enum("shrink", image_det_aug_default_enum::kShrink) - .add_enum("fit", image_det_aug_default_enum::kFit) - .set_default(image_det_aug_default_enum::kForce) - .describe("Augmentation Param: How image data fit in data_shape. " - "force: force reshape to data_shape regardless of aspect ratio; " - "shrink: ensure each side fit in data_shape, preserve aspect ratio; " - "fit: fit image to data_shape, preserve ratio, will upscale if applicable."); + .add_enum("force", image_det_aug_default_enum::kForce) + .add_enum("shrink", image_det_aug_default_enum::kShrink) + .add_enum("fit", image_det_aug_default_enum::kFit) + .set_default(image_det_aug_default_enum::kForce) + .describe( + "Augmentation Param: How image data fit in data_shape. " + "force: force reshape to data_shape regardless of aspect ratio; " + "shrink: ensure each side fit in data_shape, preserve aspect ratio; " + "fit: fit image to data_shape, preserve ratio, will upscale if applicable."); } }; @@ -227,27 +262,27 @@ class ImageDetLabel { return Rect(left, top, right - left, bottom - top); } - /*! \brief Return projected coordinates according to new region */ - ImageDetObject Project(Rect box) const { - ImageDetObject ret = *this; - ret.left = std::max(0.f, (ret.left - box.x) / box.width); - ret.top = std::max(0.f, (ret.top - box.y) / box.height); - ret.right = std::min(1.f, (ret.right - box.x) / box.width); - ret.bottom = std::min(1.f, (ret.bottom - box.y) / box.height); - return ret; - } - - /*! \brief Return Horizontally fliped coordinates */ - ImageDetObject HorizontalFlip() const { - ImageDetObject ret = *this; - ret.left = 1.f - this->right; - ret.right = 1.f - this->left; - return ret; - } + /*! \brief Return projected coordinates according to new region */ + ImageDetObject Project(Rect box) const { + ImageDetObject ret = *this; + ret.left = std::max(0.f, (ret.left - box.x) / box.width); + ret.top = std::max(0.f, (ret.top - box.y) / box.height); + ret.right = std::min(1.f, (ret.right - box.x) / box.width); + ret.bottom = std::min(1.f, (ret.bottom - box.y) / box.height); + return ret; + } + + /*! \brief Return Horizontally fliped coordinates */ + ImageDetObject HorizontalFlip() const { + ImageDetObject ret = *this; + ret.left = 1.f - this->right; + ret.right = 1.f - this->left; + return ret; + } }; // struct ImageDetObject /*! 
\brief constructor from raw array of detection labels */ - explicit ImageDetLabel(const std::vector &raw_label) { + explicit ImageDetLabel(const std::vector& raw_label) { FromArray(raw_label); } @@ -255,7 +290,7 @@ class ImageDetLabel { * header_width, object_width, (extra_headers...), * [id, xmin, ymin, xmax, ymax, (extra_object_info)] x N */ - void FromArray(const std::vector &raw_label) { + void FromArray(const std::vector& raw_label) { int label_width = static_cast(raw_label.size()); CHECK_GE(label_width, 7); // at least 2(header) + 5(1 object) int header_width = static_cast(raw_label[0]); @@ -268,11 +303,11 @@ class ImageDetLabel { objects_.reserve(num); for (int i = header_width; i < label_width; i += object_width_) { ImageDetObject obj; - auto it = raw_label.cbegin() + i; - obj.id = *(it++); - obj.left = *(it++); - obj.top = *(it++); - obj.right = *(it++); + auto it = raw_label.cbegin() + i; + obj.id = *(it++); + obj.left = *(it++); + obj.top = *(it++); + obj.right = *(it++); obj.bottom = *(it++); obj.extra.assign(it, it - 5 + object_width_); if (obj.right > obj.left && obj.bottom > obj.top) { @@ -299,7 +334,8 @@ class ImageDetLabel { /*! \brief Intersection over Union between two rects */ static float RectIOU(Rect a, Rect b) { float intersect = (a & b).area(); - if (intersect <= 0.f) return 0.f; + if (intersect <= 0.f) + return 0.f; return intersect / (a.area() + b.area() - intersect); } @@ -308,18 +344,22 @@ class ImageDetLabel { * convert all objects if success */ bool TryCrop(const Rect crop_box, - const float min_crop_overlap, const float max_crop_overlap, - const float min_crop_sample_coverage, const float max_crop_sample_coverage, - const float min_crop_object_coverage, const float max_crop_object_coverage, - const int crop_emit_mode, const float emit_overlap_thresh) { + const float min_crop_overlap, + const float max_crop_overlap, + const float min_crop_sample_coverage, + const float max_crop_sample_coverage, + const float min_crop_object_coverage, + const float max_crop_object_coverage, + const int crop_emit_mode, + const float emit_overlap_thresh) { if (objects_.size() < 1) { return true; // no object, raise error or just skip? 
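RectIOU above computes intersection-over-union via cv::Rect_<float> operators (a & b is the intersection rectangle). A self-contained version on a plain rect struct with the same zero-intersection guard; the rectangles in main are illustrative.

#include <algorithm>
#include <iostream>

struct Rect { float x, y, w, h; float area() const { return w * h; } };

// Intersection over union of two axis-aligned rectangles.
float RectIOU(const Rect& a, const Rect& b) {
  float ix = std::max(0.f, std::min(a.x + a.w, b.x + b.w) - std::max(a.x, b.x));
  float iy = std::max(0.f, std::min(a.y + a.h, b.y + b.h) - std::max(a.y, b.y));
  float intersect = ix * iy;
  if (intersect <= 0.f) return 0.f;
  return intersect / (a.area() + b.area() - intersect);
}

int main() {
  Rect a{0.f, 0.f, 0.5f, 0.5f}, b{0.25f, 0.25f, 0.5f, 0.5f};
  std::cout << RectIOU(a, b) << "\n";  // 0.0625 / (0.25 + 0.25 - 0.0625) ~= 0.143
}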
} // check if crop_box valid bool valid = false; - if (min_crop_overlap > 0.f || max_crop_overlap < 1.f || - min_crop_sample_coverage > 0.f || max_crop_sample_coverage < 1.f || - min_crop_object_coverage > 0.f || max_crop_object_coverage < 1.f) { + if (min_crop_overlap > 0.f || max_crop_overlap < 1.f || min_crop_sample_coverage > 0.f || + max_crop_sample_coverage < 1.f || min_crop_object_coverage > 0.f || + max_crop_object_coverage < 1.f) { for (auto& obj : objects_) { Rect gt_box = obj.ToRect(); if (min_crop_overlap > 0.f || max_crop_overlap < 1.f) { @@ -347,7 +387,8 @@ class ImageDetLabel { valid = true; } - if (!valid) return false; + if (!valid) + return false; // transform ground-truth labels std::vector new_objects; for (auto& object : objects_) { @@ -359,14 +400,15 @@ class ImageDetLabel { } new_objects.push_back(object.Project(crop_box)); } else if (image_det_aug_default_enum::kOverlap == crop_emit_mode) { - Rect gt_box = object.ToRect(); + Rect gt_box = object.ToRect(); float overlap = (crop_box & gt_box).area() / gt_box.area(); if (overlap > emit_overlap_thresh) { new_objects.push_back(object.Project(crop_box)); } } } - if (new_objects.size() < 1) return false; + if (new_objects.size() < 1) + return false; objects_ = new_objects; // replace the old objects return true; } @@ -411,8 +453,8 @@ class DefaultImageDetAugmenter : public ImageAugmenter { kwargs_left = param_.InitAllowUnknown(kwargs); CHECK((param_.inter_method >= 0 && param_.inter_method <= 4) || - (param_.inter_method >= 9 && param_.inter_method <= 10)) - << "invalid inter_method: valid value 0,1,2,3,9,10"; + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,9,10"; // validate crop parameters ValidateCropParameters(¶m_.min_crop_scales, param_.num_crop_sampler); @@ -439,20 +481,25 @@ class DefaultImageDetAugmenter : public ImageAugmenter { CHECK_GE(param_.emit_overlap_thresh, 0.0f); } /*! - * \brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC - * \ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND + * \brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR + * 2-CV_INTER_CUBIC \ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for + * shrink, bilinear for others) 10-RAND */ - int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, - int new_height, common::RANDOM_ENGINE *prnd) { + int GetInterMethod(int inter_method, + int old_width, + int old_height, + int new_width, + int new_height, + common::RANDOM_ENGINE* prnd) { if (inter_method == 9) { if (new_width > old_width && new_height > old_height) { return 2; // CV_INTER_CUBIC for enlarge - } else if (new_width rand_uniform_int(0, 4); return rand_uniform_int(*prnd); } else { @@ -461,8 +508,8 @@ class DefaultImageDetAugmenter : public ImageAugmenter { } /*! \brief Check number of crop samplers and given parameters */ - template - void ValidateCropParameters(mxnet::Tuple *param, const int num_sampler) { + template + void ValidateCropParameters(mxnet::Tuple* param, const int num_sampler) { if (num_sampler == 1) { CHECK_EQ(param->ndim(), 1); } else if (num_sampler > 1) { @@ -477,53 +524,55 @@ class DefaultImageDetAugmenter : public ImageAugmenter { /*! 
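Once a crop box passes the validity checks above, each ground-truth box is either kept (and re-projected into crop coordinates) or dropped according to crop_emit_mode: per the parameter description, "center" keeps a box only when its centroid falls inside the crop, and "overlap" keeps it when the intersection covers more than emit_overlap_thresh of the box's own area. A small sketch of that keep/drop decision on normalized coordinates; the boxes in main are arbitrary.

#include <algorithm>
#include <iostream>

struct Box { float left, top, right, bottom; };  // normalized [0, 1] coordinates
struct Crop { float x, y, w, h; };

bool KeepCenter(const Box& gt, const Crop& c) {
  float cx = (gt.left + gt.right) / 2, cy = (gt.top + gt.bottom) / 2;
  return cx >= c.x && cx <= c.x + c.w && cy >= c.y && cy <= c.y + c.h;
}

bool KeepOverlap(const Box& gt, const Crop& c, float thresh) {
  float iw = std::max(0.f, std::min(gt.right, c.x + c.w) - std::max(gt.left, c.x));
  float ih = std::max(0.f, std::min(gt.bottom, c.y + c.h) - std::max(gt.top, c.y));
  float gt_area = (gt.right - gt.left) * (gt.bottom - gt.top);
  return gt_area > 0.f && (iw * ih) / gt_area > thresh;
}

int main() {
  Box gt{0.1f, 0.1f, 0.4f, 0.4f};
  Crop crop{0.3f, 0.3f, 0.6f, 0.6f};
  std::cout << "center mode keeps: " << KeepCenter(gt, crop) << "\n";         // 0, centroid at (0.25, 0.25)
  std::cout << "overlap mode keeps: " << KeepOverlap(gt, crop, 0.3f) << "\n"; // 0, only ~11% covered
}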
\brief Generate crop box region given cropping parameters */ Rect GenerateCropBox(const float min_crop_scale, - const float max_crop_scale, const float min_crop_aspect_ratio, - const float max_crop_aspect_ratio, common::RANDOM_ENGINE *prnd, - const float img_aspect_ratio) { - float new_scale = std::uniform_real_distribution( - min_crop_scale, max_crop_scale)(*prnd) + 1e-12f; - float min_ratio = std::max(min_crop_aspect_ratio / img_aspect_ratio, - new_scale * new_scale); - float max_ratio = std::min(max_crop_aspect_ratio / img_aspect_ratio, - 1. / (new_scale * new_scale)); - float new_ratio = std::sqrt(std::uniform_real_distribution( - min_ratio, max_ratio)(*prnd)); + const float max_crop_scale, + const float min_crop_aspect_ratio, + const float max_crop_aspect_ratio, + common::RANDOM_ENGINE* prnd, + const float img_aspect_ratio) { + float new_scale = + std::uniform_real_distribution(min_crop_scale, max_crop_scale)(*prnd) + 1e-12f; + float min_ratio = + std::max(min_crop_aspect_ratio / img_aspect_ratio, new_scale * new_scale); + float max_ratio = + std::min(max_crop_aspect_ratio / img_aspect_ratio, 1. / (new_scale * new_scale)); + float new_ratio = std::sqrt(std::uniform_real_distribution(min_ratio, max_ratio)(*prnd)); float new_width = std::min(1.f, new_scale * new_ratio); float new_height = std::min(1.f, new_scale / new_ratio); - float x0 = std::uniform_real_distribution(0.f, 1 - new_width)(*prnd); - float y0 = std::uniform_real_distribution(0.f, 1 - new_height)(*prnd); + float x0 = std::uniform_real_distribution(0.f, 1 - new_width)(*prnd); + float y0 = std::uniform_real_distribution(0.f, 1 - new_height)(*prnd); return Rect(x0, y0, new_width, new_height); } /*! \brief Generate padding box region given padding parameters */ Rect GeneratePadBox(const float max_pad_scale, - common::RANDOM_ENGINE *prnd, const float threshold = 1.05f) { - float new_scale = std::uniform_real_distribution( - 1.f, max_pad_scale)(*prnd); - if (new_scale < threshold) return Rect(0, 0, 0, 0); - auto rand_uniform = std::uniform_real_distribution(0.f, new_scale - 1); - float x0 = rand_uniform(*prnd); - float y0 = rand_uniform(*prnd); - return Rect(-x0, -y0, new_scale, new_scale); - } + common::RANDOM_ENGINE* prnd, + const float threshold = 1.05f) { + float new_scale = std::uniform_real_distribution(1.f, max_pad_scale)(*prnd); + if (new_scale < threshold) + return Rect(0, 0, 0, 0); + auto rand_uniform = std::uniform_real_distribution(0.f, new_scale - 1); + float x0 = rand_uniform(*prnd); + float y0 = rand_uniform(*prnd); + return Rect(-x0, -y0, new_scale, new_scale); + } - cv::Mat Process(const cv::Mat &src, std::vector *label, - common::RANDOM_ENGINE *prnd) override { + cv::Mat Process(const cv::Mat& src, + std::vector* label, + common::RANDOM_ENGINE* prnd) override { using mshadow::index_t; cv::Mat res; if (param_.resize != -1) { int new_height, new_width; if (src.rows > src.cols) { - new_height = param_.resize*src.rows/src.cols; - new_width = param_.resize; + new_height = param_.resize * src.rows / src.cols; + new_width = param_.resize; } else { new_height = param_.resize; - new_width = param_.resize*src.cols/src.rows; + new_width = param_.resize * src.cols / src.rows; } - int interpolation_method = GetInterMethod(param_.inter_method, - src.cols, src.rows, new_width, new_height, prnd); - cv::resize(src, res, cv::Size(new_width, new_height), - 0, 0, interpolation_method); + int interpolation_method = + GetInterMethod(param_.inter_method, src.cols, src.rows, new_width, new_height, prnd); + cv::resize(src, res, 
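GenerateCropBox above samples a crop in normalized image coordinates: a scale drawn from [min_crop_scale, max_crop_scale], an aspect-ratio factor constrained so neither side exceeds the full image, and a random top-left corner. A rough standalone sketch of that sampling, with the RNG made explicit and the OpenCV rect replaced by a plain struct (an assumption for illustration only):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <random>

struct NormRect { float x, y, w, h; };  // all values in [0, 1], relative to the image

NormRect SampleCropBox(float min_scale, float max_scale,
                       float min_aspect, float max_aspect,
                       float img_aspect, std::mt19937* rng) {
  float scale = std::uniform_real_distribution<float>(min_scale, max_scale)(*rng) + 1e-12f;
  // The sampled value is ratio^2; bounding it by scale^2 and 1/scale^2 keeps
  // width = scale*ratio and height = scale/ratio both <= 1 (assuming the bounds stay ordered,
  // as the defaults do).
  float lo = std::max(min_aspect / img_aspect, scale * scale);
  float hi = std::min(max_aspect / img_aspect, 1.f / (scale * scale));
  float ratio = std::sqrt(std::uniform_real_distribution<float>(lo, hi)(*rng));
  float w = std::min(1.f, scale * ratio);
  float h = std::min(1.f, scale / ratio);
  float x = std::uniform_real_distribution<float>(0.f, 1.f - w)(*rng);
  float y = std::uniform_real_distribution<float>(0.f, 1.f - h)(*rng);
  return {x, y, w, h};
}

int main() {
  std::mt19937 rng(0);
  NormRect r = SampleCropBox(0.3f, 0.9f, 0.5f, 2.0f, 4.f / 3.f, &rng);
  std::cout << "crop: " << r.x << "," << r.y << " " << r.w << "x" << r.h << "\n";
}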
cv::Size(new_width, new_height), 0, 0, interpolation_method); } else { res = src; } @@ -537,16 +586,16 @@ class DefaultImageDetAugmenter : public ImageAugmenter { if (param_.random_hue_prob > 0.f || param_.random_saturation_prob > 0.f || param_.random_illumination_prob > 0.f || param_.random_contrast_prob > 0.f) { std::uniform_real_distribution uniform_range(-1.f, 1.f); - int h = uniform_range(*prnd) * param_.max_random_hue; - int s = uniform_range(*prnd) * param_.max_random_saturation; - int l = uniform_range(*prnd) * param_.max_random_illumination; + int h = uniform_range(*prnd) * param_.max_random_hue; + int s = uniform_range(*prnd) * param_.max_random_saturation; + int l = uniform_range(*prnd) * param_.max_random_illumination; float c = uniform_range(*prnd) * param_.max_random_contrast; - h = rand_uniform(*prnd) < param_.random_hue_prob ? h : 0; - s = rand_uniform(*prnd) < param_.random_saturation_prob ? s : 0; - l = rand_uniform(*prnd) < param_.random_illumination_prob ? l : 0; - c = rand_uniform(*prnd) < param_.random_contrast_prob ? c : 0; + h = rand_uniform(*prnd) < param_.random_hue_prob ? h : 0; + s = rand_uniform(*prnd) < param_.random_saturation_prob ? s : 0; + l = rand_uniform(*prnd) < param_.random_illumination_prob ? l : 0; + c = rand_uniform(*prnd) < param_.random_contrast_prob ? c : 0; if (h != 0 || s != 0 || l != 0) { - int temp[3] = {h, l, s}; + int temp[3] = {h, l, s}; int limit[3] = {180, 255, 255}; cv::cvtColor(res, res, CV_BGR2HLS); for (int i = 0; i < res.rows; ++i) { @@ -554,7 +603,7 @@ class DefaultImageDetAugmenter : public ImageAugmenter { for (int k = 0; k < 3; ++k) { int v = res.at(i, j)[k]; v += temp[k]; - v = std::max(0, std::min(limit[k], v)); + v = std::max(0, std::min(limit[k], v)); res.at(i, j)[k] = v; } } @@ -583,13 +632,19 @@ class DefaultImageDetAugmenter : public ImageAugmenter { if (pad_box.area() > 0) { if (det_label.TryPad(pad_box)) { // pad image - temp_ = res; - int left = static_cast(-pad_box.x * res.cols); - int top = static_cast(-pad_box.y * res.rows); + temp_ = res; + int left = static_cast(-pad_box.x * res.cols); + int top = static_cast(-pad_box.y * res.rows); int right = static_cast((pad_box.width + pad_box.x - 1) * res.cols); - int bot = static_cast((pad_box.height + pad_box.y - 1) * res.rows); - cv::copyMakeBorder(temp_, res, top, bot, left, right, cv::BORDER_ISOLATED, - cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); + int bot = static_cast((pad_box.height + pad_box.y - 1) * res.rows); + cv::copyMakeBorder(temp_, + res, + top, + bot, + left, + right, + cv::BORDER_ISOLATED, + cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); } } } @@ -608,24 +663,31 @@ class DefaultImageDetAugmenter : public ImageAugmenter { std::shuffle(indices.begin(), indices.end(), *prnd); int num_processed = 0; for (auto idx : indices) { - if (num_processed > 0) break; + if (num_processed > 0) + break; for (int t = 0; t < param_.max_crop_trials[idx]; ++t) { Rect crop_box = GenerateCropBox(param_.min_crop_scales[idx], - param_.max_crop_scales[idx], param_.min_crop_aspect_ratios[idx], - param_.max_crop_aspect_ratios[idx], prnd, - static_cast(res.cols) / res.rows); - if (det_label.TryCrop(crop_box, param_.min_crop_overlaps[idx], - param_.max_crop_overlaps[idx], param_.min_crop_sample_coverages[idx], - param_.max_crop_sample_coverages[idx], param_.min_crop_object_coverages[idx], - param_.max_crop_object_coverages[idx], param_.crop_emit_mode, - param_.emit_overlap_thresh)) { + param_.max_crop_scales[idx], + 
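GeneratePadBox above returns a normalized box of the form Rect(-x0, -y0, s, s), and the padding branch turns it into pixel margins for cv::copyMakeBorder. A minimal sketch of just that margin arithmetic, with OpenCV left out and the struct and field names invented for illustration:

#include <iostream>

// Normalized pad box as produced above: x, y are <= 0, width/height are the overall scale s >= 1.
struct PadBox { float x, y, width, height; };

struct Margins { int left, top, right, bottom; };

// Convert a normalized pad box into integer pixel margins around a rows x cols image,
// mirroring the arithmetic that feeds cv::copyMakeBorder in the augmenter.
Margins ToPixelMargins(const PadBox& b, int rows, int cols) {
  Margins m;
  m.left   = static_cast<int>(-b.x * cols);
  m.top    = static_cast<int>(-b.y * rows);
  m.right  = static_cast<int>((b.width  + b.x - 1) * cols);
  m.bottom = static_cast<int>((b.height + b.y - 1) * rows);
  return m;
}

int main() {
  // Pad a 200 x 400 (rows x cols) image to 1.75x its size, original shifted by (0.25, 0.5).
  PadBox box{-0.25f, -0.5f, 1.75f, 1.75f};
  Margins m = ToPixelMargins(box, 200, 400);
  // Expected: left=100, top=100, right=200, bottom=50 -> padded size 700 x 350 = 1.75x original.
  std::cout << m.left << " " << m.top << " " << m.right << " " << m.bottom << "\n";
}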
param_.min_crop_aspect_ratios[idx], + param_.max_crop_aspect_ratios[idx], + prnd, + static_cast(res.cols) / res.rows); + if (det_label.TryCrop(crop_box, + param_.min_crop_overlaps[idx], + param_.max_crop_overlaps[idx], + param_.min_crop_sample_coverages[idx], + param_.max_crop_sample_coverages[idx], + param_.min_crop_object_coverages[idx], + param_.max_crop_object_coverages[idx], + param_.crop_emit_mode, + param_.emit_overlap_thresh)) { ++num_processed; // crop image - int left = static_cast(crop_box.x * res.cols); - int top = static_cast(crop_box.y * res.rows); - int width = static_cast(crop_box.width * res.cols); + int left = static_cast(crop_box.x * res.cols); + int top = static_cast(crop_box.y * res.rows); + int width = static_cast(crop_box.width * res.cols); int height = static_cast(crop_box.height * res.rows); - res = res(cv::Rect(left, top, width, height)); + res = res(cv::Rect(left, top, width, height)); break; } } @@ -636,34 +698,31 @@ class DefaultImageDetAugmenter : public ImageAugmenter { if (image_det_aug_default_enum::kForce == param_.resize_mode) { // force resize to specified data_shape, regardless of aspect ratio int new_height = param_.data_shape[1]; - int new_width = param_.data_shape[2]; - int interpolation_method = GetInterMethod(param_.inter_method, - res.cols, res.rows, new_width, new_height, prnd); - cv::resize(res, res, cv::Size(new_width, new_height), - 0, 0, interpolation_method); + int new_width = param_.data_shape[2]; + int interpolation_method = + GetInterMethod(param_.inter_method, res.cols, res.rows, new_width, new_height, prnd); + cv::resize(res, res, cv::Size(new_width, new_height), 0, 0, interpolation_method); } else if (image_det_aug_default_enum::kShrink == param_.resize_mode) { // try to keep original size, shrink if too large float h = param_.data_shape[1]; float w = param_.data_shape[2]; if (res.rows > h || res.cols > w) { - float ratio = std::min(h / res.rows, w / res.cols); + float ratio = std::min(h / res.rows, w / res.cols); int new_height = ratio * res.rows; - int new_width = ratio * res.cols; - int interpolation_method = GetInterMethod(param_.inter_method, - res.cols, res.rows, new_width, new_height, prnd); - cv::resize(res, res, cv::Size(new_width, new_height), - 0, 0, interpolation_method); + int new_width = ratio * res.cols; + int interpolation_method = + GetInterMethod(param_.inter_method, res.cols, res.rows, new_width, new_height, prnd); + cv::resize(res, res, cv::Size(new_width, new_height), 0, 0, interpolation_method); } } else if (image_det_aug_default_enum::kFit == param_.resize_mode) { - float h = param_.data_shape[1]; - float w = param_.data_shape[2]; - float ratio = std::min(h / res.rows, w / res.cols); + float h = param_.data_shape[1]; + float w = param_.data_shape[2]; + float ratio = std::min(h / res.rows, w / res.cols); int new_height = ratio * res.rows; - int new_width = ratio * res.cols; - int interpolation_method = GetInterMethod(param_.inter_method, - res.cols, res.rows, new_width, new_height, prnd); - cv::resize(res, res, cv::Size(new_width, new_height), - 0, 0, interpolation_method); + int new_width = ratio * res.cols; + int interpolation_method = + GetInterMethod(param_.inter_method, res.cols, res.rows, new_width, new_height, prnd); + cv::resize(res, res, cv::Size(new_width, new_height), 0, 0, interpolation_method); } *label = det_label.ToArray(); // put back processed labels @@ -678,10 +737,8 @@ class DefaultImageDetAugmenter : public ImageAugmenter { }; MXNET_REGISTER_IMAGE_AUGMENTER(det_aug_default) -.describe("default 
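The three resize modes handled above (force, shrink, fit) differ only in how the target size is derived from data_shape. A compact standalone sketch of that size selection, using plain ints instead of cv::Mat; the names are illustrative, not the source's API:

#include <algorithm>
#include <iostream>

enum ResizeMode { kForceMode, kShrinkMode, kFitMode };

struct Size2D { int width, height; };

// Target-size selection analogous to the detection augmenter's final resize step.
Size2D TargetSize(ResizeMode mode, int cur_w, int cur_h, int want_w, int want_h) {
  if (mode == kForceMode) {
    return {want_w, want_h};                       // ignore aspect ratio entirely
  }
  float ratio = std::min(static_cast<float>(want_h) / cur_h,
                         static_cast<float>(want_w) / cur_w);
  if (mode == kShrinkMode && ratio >= 1.f) {
    return {cur_w, cur_h};                         // already fits: keep the original size
  }
  return {static_cast<int>(ratio * cur_w), static_cast<int>(ratio * cur_h)};  // fit, or shrink
}

int main() {
  // A 1024 x 768 image with a 512 x 512 data_shape: fit keeps aspect ratio, force does not.
  Size2D fit   = TargetSize(kFitMode,   1024, 768, 512, 512);
  Size2D force = TargetSize(kForceMode, 1024, 768, 512, 512);
  std::cout << "fit: " << fit.width << "x" << fit.height
            << "  force: " << force.width << "x" << force.height << "\n";  // 512x384 and 512x512
}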
detection augmenter") -.set_body([]() { - return new DefaultImageDetAugmenter(); - }); + .describe("default detection augmenter") + .set_body([]() { return new DefaultImageDetAugmenter(); }); #endif // MXNET_USE_OPENCV } // namespace io } // namespace mxnet diff --git a/src/io/image_io.cc b/src/io/image_io.cc index db9ac7682287..f737c2c25307 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -40,8 +40,8 @@ #include "../operator/image/resize-inl.h" #if MXNET_USE_OPENCV - #include - #include "./opencv_compatibility.h" +#include +#include "./opencv_compatibility.h" #endif // MXNET_USE_OPENCV namespace mxnet { @@ -50,32 +50,34 @@ namespace io { // http://www.64lines.com/jpeg-width-height // Gets the JPEG size from the array of data passed to the function, // file reference: http://www.obrador.com/essentialjpeg/headerinfo.htm -bool get_jpeg_size(const uint8_t* data, uint32_t data_size, int64_t *width, int64_t *height) { +bool get_jpeg_size(const uint8_t* data, uint32_t data_size, int64_t* width, int64_t* height) { // Check for valid JPEG image uint32_t i = 0; // Keeps track of the position within the file - if (data[i] == 0xFF && data[i+1] == 0xD8 && data[i+2] == 0xFF && data[i+3] == 0xE0) { + if (data[i] == 0xFF && data[i + 1] == 0xD8 && data[i + 2] == 0xFF && data[i + 3] == 0xE0) { i += 4; // Check for valid JPEG header (null terminated JFIF) - if (data[i+2] == 'J' && data[i+3] == 'F' && data[i+4] == 'I' - && data[i+5] == 'F' && data[i+6] == 0x00) { + if (data[i + 2] == 'J' && data[i + 3] == 'F' && data[i + 4] == 'I' && data[i + 5] == 'F' && + data[i + 6] == 0x00) { // Retrieve the block length of the first block since // the first block will not contain the size of file - uint16_t block_length = data[i] * 256 + data[i+1]; + uint16_t block_length = data[i] * 256 + data[i + 1]; while (i < data_size) { - i+=block_length; // Increase the file index to get to the next block - if (i >= data_size) return false; // Check to protect against segmentation faults - if (data[i] != 0xFF) return false; // Check that we are truly at the start of another block - uint8_t m = data[i+1]; + i += block_length; // Increase the file index to get to the next block + if (i >= data_size) + return false; // Check to protect against segmentation faults + if (data[i] != 0xFF) + return false; // Check that we are truly at the start of another block + uint8_t m = data[i + 1]; if (m == 0xC0 || (m >= 0xC1 && m <= 0xCF && m != 0xC4 && m != 0xC8 && m != 0xCC)) { // 0xFFC0 is the "Start of frame" marker which contains the file size // The structure of the 0xFFC0 block is quite simple // [0xFFC0][ushort length][uchar precision][ushort x][ushort y] - *height = data[i+5]*256 + data[i+6]; - *width = data[i+7]*256 + data[i+8]; + *height = data[i + 5] * 256 + data[i + 6]; + *width = data[i + 7] * 256 + data[i + 8]; return true; } else { - i+=2; // Skip the block marker - block_length = data[i] * 256 + data[i+1]; // Go to the next block + i += 2; // Skip the block marker + block_length = data[i] * 256 + data[i + 1]; // Go to the next block } } return false; // If this point is reached then no size was found @@ -87,12 +89,12 @@ bool get_jpeg_size(const uint8_t* data, uint32_t data_size, int64_t *width, int6 } } -bool get_png_size(const uint8_t* data, uint32_t data_size, int64_t *width, int64_t *height) { - if (data[0] == 0x89 && data[1] == 0x50 && data[2] ==0x4E && data[3] == 0x47) { +bool get_png_size(const uint8_t* data, uint32_t data_size, int64_t* width, int64_t* height) { + if (data[0] == 0x89 && data[1] == 0x50 && 
data[2] == 0x4E && data[3] == 0x47) { uint8_t const* p = data + 16; - *width = ((p[0]*256 + p[1])*256 + p[2])*256 + p[3]; + *width = ((p[0] * 256 + p[1]) * 256 + p[2]) * 256 + p[3]; p += 4; - *height = ((p[0]*256 + p[1])*256 + p[2])*256 + p[3]; + *height = ((p[0] * 256 + p[1]) * 256 + p[2]) * 256 + p[3]; return true; } else { return false; @@ -103,14 +105,11 @@ struct ImdecodeParam : public dmlc::Parameter { int flag; bool to_rgb; DMLC_DECLARE_PARAMETER(ImdecodeParam) { - DMLC_DECLARE_FIELD(flag) - .set_lower_bound(0) - .set_default(1) - .describe("Convert decoded image to grayscale (0) or color (1)."); - DMLC_DECLARE_FIELD(to_rgb) - .set_default(true) - .describe("Whether to convert decoded image to mxnet's default RGB format " - "(instead of opencv's default BGR)."); + DMLC_DECLARE_FIELD(flag).set_lower_bound(0).set_default(1).describe( + "Convert decoded image to grayscale (0) or color (1)."); + DMLC_DECLARE_FIELD(to_rgb).set_default(true).describe( + "Whether to convert decoded image to mxnet's default RGB format " + "(instead of opencv's default BGR)."); } }; @@ -121,25 +120,19 @@ struct ImreadParam : public dmlc::Parameter { int flag; bool to_rgb; DMLC_DECLARE_PARAMETER(ImreadParam) { - DMLC_DECLARE_FIELD(filename) - .describe("Name of the image file to be loaded."); - DMLC_DECLARE_FIELD(flag) - .set_lower_bound(0) - .set_default(1) - .describe("Convert decoded image to grayscale (0) or color (1)."); - DMLC_DECLARE_FIELD(to_rgb) - .set_default(true) - .describe("Whether to convert decoded image to mxnet's default RGB format " - "(instead of opencv's default BGR)."); + DMLC_DECLARE_FIELD(filename).describe("Name of the image file to be loaded."); + DMLC_DECLARE_FIELD(flag).set_lower_bound(0).set_default(1).describe( + "Convert decoded image to grayscale (0) or color (1)."); + DMLC_DECLARE_FIELD(to_rgb).set_default(true).describe( + "Whether to convert decoded image to mxnet's default RGB format " + "(instead of opencv's default BGR)."); } }; DMLC_REGISTER_PARAMETER(ImreadParam); - #if MXNET_USE_OPENCV -void ImdecodeImpl(int flag, bool to_rgb, void* data, size_t size, - NDArray* out) { +void ImdecodeImpl(int flag, bool to_rgb, void* data, size_t size, NDArray* out) { cv::Mat buf(1, size, CV_8U, data); cv::Mat dst; if (out->is_none()) { @@ -147,18 +140,18 @@ void ImdecodeImpl(int flag, bool to_rgb, void* data, size_t size, CHECK(!res.empty()) << "Decoding failed. Invalid image file."; *out = NDArray(mshadow::Shape3(res.rows, res.cols, flag == 0 ? 1 : 3), - Context::CPU(), false, mshadow::kUint8); - dst = cv::Mat(out->shape()[0], out->shape()[1], flag == 0 ? CV_8U : CV_8UC3, - out->data().dptr_); + Context::CPU(), + false, + mshadow::kUint8); + dst = cv::Mat(out->shape()[0], out->shape()[1], flag == 0 ? CV_8U : CV_8UC3, out->data().dptr_); res.copyTo(dst); CHECK(!dst.empty()) << "Failed copying buffer to output."; } else { - dst = cv::Mat(out->shape()[0], out->shape()[1], flag == 0 ? CV_8U : CV_8UC3, - out->data().dptr_); + dst = cv::Mat(out->shape()[0], out->shape()[1], flag == 0 ? CV_8U : CV_8UC3, out->data().dptr_); #if (CV_MAJOR_VERSION > 3 || (CV_MAJOR_VERSION == 3 && CV_MINOR_VERSION >= 3)) cv::imdecode(buf, flag | cv::IMREAD_IGNORE_ORIENTATION, &dst); CHECK(!dst.empty()) << "Decoding failed. Invalid image file."; -#elif(CV_MAJOR_VERSION > 2 || (CV_MAJOR_VERSION == 2 && CV_MINOR_VERSION >= 4)) +#elif (CV_MAJOR_VERSION > 2 || (CV_MAJOR_VERSION == 2 && CV_MINOR_VERSION >= 4)) // NOLINT cv::imdecode(buf, flag, &dst); CHECK(!dst.empty()) << "Decoding failed. 
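Both size probes above read big-endian integers straight out of the encoded header: get_png_size takes the 4-byte width and height that start 16 bytes into the file (inside the IHDR chunk), and get_jpeg_size reads 2-byte height and width fields from the start-of-frame block. A tiny self-contained sketch of the PNG case; the sample bytes are fabricated for illustration:

#include <cstdint>
#include <iostream>
#include <vector>

// Read a big-endian 32-bit integer, as the PNG IHDR stores width and height.
static uint32_t ReadBE32(const uint8_t* p) {
  return (static_cast<uint32_t>(p[0]) << 24) | (p[1] << 16) | (p[2] << 8) | p[3];
}

int main() {
  // First 24 bytes of a hypothetical PNG: 8-byte signature, 4-byte IHDR length,
  // 4-byte "IHDR" tag, then 4-byte width and 4-byte height (here 640 x 480).
  std::vector<uint8_t> png = {0x89, 'P', 'N', 'G', 0x0D, 0x0A, 0x1A, 0x0A,
                              0x00, 0x00, 0x00, 0x0D, 'I',  'H',  'D',  'R',
                              0x00, 0x00, 0x02, 0x80,   // width  = 640
                              0x00, 0x00, 0x01, 0xE0};  // height = 480
  if (png[0] == 0x89 && png[1] == 'P' && png[2] == 'N' && png[3] == 'G') {
    std::cout << ReadBE32(&png[16]) << " x " << ReadBE32(&png[20]) << "\n";  // 640 x 480
  }
}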
Invalid image file."; #else @@ -186,7 +179,7 @@ void Imdecode(const nnvm::NodeAttrs& attrs, inputs[0].WaitToRead(); uint8_t* str_img = inputs[0].data().dptr(); - size_t len = inputs[0].shape().Size(); + size_t len = inputs[0].shape().Size(); CHECK(len > 0) << "Input cannot be an empty buffer"; mxnet::TShape oshape(3, 1); @@ -200,13 +193,18 @@ void Imdecode(const nnvm::NodeAttrs& attrs, } const NDArray& ndin = inputs[0]; - NDArray& ndout = (*outputs)[0]; - ndout = NDArray(oshape, Context::CPU(), true, mshadow::kUint8); - Engine::Get()->PushSync([ndin, ndout, str_img, len, param](RunContext ctx){ - ImdecodeImpl(param.flag, param.to_rgb, str_img, len, - const_cast(&ndout)); - }, ndout.ctx(), {ndin.var()}, {ndout.var()}, - FnProperty::kNormal, 0, "Imdecode"); + NDArray& ndout = (*outputs)[0]; + ndout = NDArray(oshape, Context::CPU(), true, mshadow::kUint8); + Engine::Get()->PushSync( + [ndin, ndout, str_img, len, param](RunContext ctx) { + ImdecodeImpl(param.flag, param.to_rgb, str_img, len, const_cast(&ndout)); + }, + ndout.ctx(), + {ndin.var()}, + {ndout.var()}, + FnProperty::kNormal, + 0, + "Imdecode"); #else LOG(FATAL) << "Build with USE_OPENCV=1 for image io."; #endif // MXNET_USE_OPENCV @@ -221,13 +219,12 @@ void Imread(const nnvm::NodeAttrs& attrs, std::ifstream file(param.filename, std::ios::binary | std::ios::ate); // if file is not open we get bad alloc after tellg CHECK(file.is_open()) << "Imread: '" << param.filename - << "' couldn't open file: " << strerror(errno); + << "' couldn't open file: " << strerror(errno); size_t fsize = file.tellg(); file.seekg(0, std::ios::beg); std::shared_ptr buff(new uint8_t[fsize], std::default_delete()); file.read(reinterpret_cast(buff.get()), fsize); - CHECK(file.good()) << "Failed reading image file: '" << param.filename << "' " - << strerror(errno); + CHECK(file.good()) << "Failed reading image file: '" << param.filename << "' " << strerror(errno); mxnet::TShape oshape(3, 1); oshape[2] = param.flag == 0 ? 
1 : 3; @@ -240,41 +237,41 @@ void Imread(const nnvm::NodeAttrs& attrs, } NDArray& ndout = (*outputs)[0]; - ndout = NDArray(oshape, Context::CPU(), true, mshadow::kUint8); - Engine::Get()->PushSync([ndout, buff, fsize, param](RunContext ctx){ - ImdecodeImpl(param.flag, param.to_rgb, buff.get(), fsize, - const_cast(&ndout)); - }, ndout.ctx(), {}, {ndout.var()}, - FnProperty::kNormal, 0, "Imread"); + ndout = NDArray(oshape, Context::CPU(), true, mshadow::kUint8); + Engine::Get()->PushSync( + [ndout, buff, fsize, param](RunContext ctx) { + ImdecodeImpl(param.flag, param.to_rgb, buff.get(), fsize, const_cast(&ndout)); + }, + ndout.ctx(), + {}, + {ndout.var()}, + FnProperty::kNormal, + 0, + "Imread"); #else LOG(FATAL) << "Build with USE_OPENCV=1 for image io."; #endif // MXNET_USE_OPENCV } - struct ResizeParam : public dmlc::Parameter { int w; int h; int interp; DMLC_DECLARE_PARAMETER(ResizeParam) { - DMLC_DECLARE_FIELD(w) - .set_lower_bound(1) - .describe("Width of resized image."); - DMLC_DECLARE_FIELD(h) - .set_lower_bound(1) - .describe("Height of resized image."); - DMLC_DECLARE_FIELD(interp) - .set_default(1) - .describe("Interpolation method (default=cv2.INTER_LINEAR)."); + DMLC_DECLARE_FIELD(w).set_lower_bound(1).describe("Width of resized image."); + DMLC_DECLARE_FIELD(h).set_lower_bound(1).describe("Height of resized image."); + DMLC_DECLARE_FIELD(interp).set_default(1).describe( + "Interpolation method (default=cv2.INTER_LINEAR)."); } }; DMLC_REGISTER_PARAMETER(ResizeParam); inline bool ResizeShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *ishape, - mxnet::ShapeVector *oshape) { + mxnet::ShapeVector* ishape, + mxnet::ShapeVector* oshape) { const auto& param = nnvm::get(attrs.parsed); - if (ishape->size() != 1 || (*ishape)[0].ndim() != 3) return false; + if (ishape->size() != 1 || (*ishape)[0].ndim() != 3) + return false; oshape->clear(); oshape->push_back(mshadow::Shape3(param.h, param.w, (*ishape)[0][2])); @@ -282,65 +279,56 @@ inline bool ResizeShape(const nnvm::NodeAttrs& attrs, } inline void Imresize(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const auto& param = nnvm::get(attrs.parsed); op::image::ResizeImpl(inputs, outputs, param.h, param.w, param.interp); } - struct MakeBorderParam : public dmlc::Parameter { int top, bot, left, right; int type; double value; mxnet::Tuple values; DMLC_DECLARE_PARAMETER(MakeBorderParam) { - DMLC_DECLARE_FIELD(top) - .describe("Top margin."); - DMLC_DECLARE_FIELD(bot) - .describe("Bottom margin."); - DMLC_DECLARE_FIELD(left) - .describe("Left margin."); - DMLC_DECLARE_FIELD(right) - .describe("Right margin."); - DMLC_DECLARE_FIELD(type) - .set_default(0) - .describe("Filling type (default=cv2.BORDER_CONSTANT)."); - DMLC_DECLARE_FIELD(value) - .set_default(0.0) - .describe("(Deprecated! Use ``values`` instead.) Fill with single value."); - DMLC_DECLARE_FIELD(values) - .set_default({}) - .describe("Fill with value(RGB[A] or gray), up to 4 channels."); + DMLC_DECLARE_FIELD(top).describe("Top margin."); + DMLC_DECLARE_FIELD(bot).describe("Bottom margin."); + DMLC_DECLARE_FIELD(left).describe("Left margin."); + DMLC_DECLARE_FIELD(right).describe("Right margin."); + DMLC_DECLARE_FIELD(type).set_default(0).describe("Filling type (default=cv2.BORDER_CONSTANT)."); + DMLC_DECLARE_FIELD(value).set_default(0.0).describe( + "(Deprecated! 
Use ``values`` instead.) Fill with single value."); + DMLC_DECLARE_FIELD(values).set_default({}).describe( + "Fill with value(RGB[A] or gray), up to 4 channels."); } }; DMLC_REGISTER_PARAMETER(MakeBorderParam); inline bool MakeBorderShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *ishape, - mxnet::ShapeVector *oshape) { + mxnet::ShapeVector* ishape, + mxnet::ShapeVector* oshape) { const auto& param = nnvm::get(attrs.parsed); - if (ishape->size() != 1 || (*ishape)[0].ndim() != 3) return false; + if (ishape->size() != 1 || (*ishape)[0].ndim() != 3) + return false; oshape->clear(); - oshape->push_back( - mshadow::Shape3((*ishape)[0][0]+param.top+param.bot, - (*ishape)[0][1]+param.left+param.right, - (*ishape)[0][2])); + oshape->push_back(mshadow::Shape3((*ishape)[0][0] + param.top + param.bot, + (*ishape)[0][1] + param.left + param.right, + (*ishape)[0][2])); return true; } inline void copyMakeBorder(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { #if MXNET_USE_OPENCV CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "imresize doesn't support fp16"; const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S}; - int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]); + int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]); const auto& param = nnvm::get(attrs.parsed); cv::Mat buf(inputs[0].shape_[0], inputs[0].shape_[1], cv_type, inputs[0].dptr_); cv::Mat dst(outputs[0].shape_[0], outputs[0].shape_[1], cv_type, outputs[0].dptr_); @@ -357,50 +345,52 @@ inline void copyMakeBorder(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_cvimdecode) -.add_alias("_npi_cvimdecode") -.describe("Decode image with OpenCV. \n" - "Note: return image in RGB by default, " - "instead of OpenCV's default BGR.") -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(op::ParamParser) -.set_attr("FNDArrayFunction", Imdecode) -.add_argument("buf", "NDArray", "Buffer containing binary encoded image") -.add_arguments(ImdecodeParam::__FIELDS__()); + .add_alias("_npi_cvimdecode") + .describe( + "Decode image with OpenCV. \n" + "Note: return image in RGB by default, " + "instead of OpenCV's default BGR.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr_parser(op::ParamParser) + .set_attr("FNDArrayFunction", Imdecode) + .add_argument("buf", "NDArray", "Buffer containing binary encoded image") + .add_arguments(ImdecodeParam::__FIELDS__()); NNVM_REGISTER_OP(_cvimread) -.add_alias("_npi_cvimread") -.describe("Read and decode image with OpenCV. \n" - "Note: return image in RGB by default, " - "instead of OpenCV's default BGR.") -.set_num_inputs(0) -.set_num_outputs(1) -.set_attr_parser(op::ParamParser) -.set_attr("FNDArrayFunction", Imread) -.add_arguments(ImreadParam::__FIELDS__()); + .add_alias("_npi_cvimread") + .describe( + "Read and decode image with OpenCV. \n" + "Note: return image in RGB by default, " + "instead of OpenCV's default BGR.") + .set_num_inputs(0) + .set_num_outputs(1) + .set_attr_parser(op::ParamParser) + .set_attr("FNDArrayFunction", Imread) + .add_arguments(ImreadParam::__FIELDS__()); NNVM_REGISTER_OP(_cvimresize) -.add_alias("_npi_cvimresize") -.describe("Resize image with OpenCV. 
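MakeBorderShape above simply grows the input height and width by the requested margins while keeping the channel count. Stated as a tiny sketch with the shape reduced to a plain array (names are illustrative):

#include <array>
#include <iostream>

// (height, width, channels) of the padded image given margins, as inferred by MakeBorderShape.
std::array<int, 3> PaddedShape(const std::array<int, 3>& hwc,
                               int top, int bot, int left, int right) {
  return {hwc[0] + top + bot, hwc[1] + left + right, hwc[2]};
}

int main() {
  auto out = PaddedShape({480, 640, 3}, /*top=*/10, /*bot=*/10, /*left=*/20, /*right=*/20);
  std::cout << out[0] << " x " << out[1] << " x " << out[2] << "\n";  // 500 x 680 x 3
}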
\n") -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(op::ParamParser) -.set_attr("FInferShape", ResizeShape) -.set_attr("FInferType", op::ElemwiseType<1, 1>) -.set_attr("FCompute", Imresize) -.add_argument("src", "NDArray", "source image") -.add_arguments(ResizeParam::__FIELDS__()); + .add_alias("_npi_cvimresize") + .describe("Resize image with OpenCV. \n") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr_parser(op::ParamParser) + .set_attr("FInferShape", ResizeShape) + .set_attr("FInferType", op::ElemwiseType<1, 1>) + .set_attr("FCompute", Imresize) + .add_argument("src", "NDArray", "source image") + .add_arguments(ResizeParam::__FIELDS__()); NNVM_REGISTER_OP(_cvcopyMakeBorder) -.describe("Pad image border with OpenCV. \n") -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(op::ParamParser) -.set_attr("FInferShape", MakeBorderShape) -.set_attr("FInferType", op::ElemwiseType<1, 1>) -.set_attr("FCompute", copyMakeBorder) -.add_argument("src", "NDArray", "source image") -.add_arguments(MakeBorderParam::__FIELDS__()); + .describe("Pad image border with OpenCV. \n") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr_parser(op::ParamParser) + .set_attr("FInferShape", MakeBorderShape) + .set_attr("FInferType", op::ElemwiseType<1, 1>) + .set_attr("FCompute", copyMakeBorder) + .add_argument("src", "NDArray", "source image") + .add_arguments(MakeBorderParam::__FIELDS__()); } // namespace io } // namespace mxnet diff --git a/src/io/image_iter_common.h b/src/io/image_iter_common.h index edda3d9dfe97..2b6276aa7779 100644 --- a/src/io/image_iter_common.h +++ b/src/io/image_iter_common.h @@ -41,28 +41,28 @@ class ImageLabelMap { * \param path_imglist path to the image list * \param label_width predefined label_width */ - explicit ImageLabelMap(const char *path_imglist, - index_t label_width, - bool silent) { + explicit ImageLabelMap(const char* path_imglist, index_t label_width, bool silent) { this->label_width = label_width; image_index_.clear(); label_.clear(); idx2label_.clear(); - dmlc::InputSplit *fi = dmlc::InputSplit::Create - (path_imglist, 0, 1, "text"); + dmlc::InputSplit* fi = dmlc::InputSplit::Create(path_imglist, 0, 1, "text"); dmlc::InputSplit::Blob rec; while (fi->NextRecord(&rec)) { // quick manual parsing - char *p = reinterpret_cast(rec.dptr); - char *end = p + rec.size; + char* p = reinterpret_cast(rec.dptr); + char* end = p + rec.size; // skip space - while (isspace(*p) && p != end) ++p; + while (isspace(*p) && p != end) + ++p; image_index_.push_back(static_cast(atol(p))); for (index_t i = 0; i < label_width; ++i) { // skip till space - while (!isspace(*p) && p != end) ++p; + while (!isspace(*p) && p != end) + ++p; // skip space - while (isspace(*p) && p != end) ++p; + while (isspace(*p) && p != end) + ++p; CHECK(p != end) << "Bad ImageList format"; label_.push_back(static_cast(atof(p))); } @@ -74,23 +74,21 @@ class ImageLabelMap { idx2label_[image_index_[i]] = dmlc::BeginPtr(label_) + i * label_width; } if (!silent) { - LOG(INFO) << "Loaded ImageList from " << path_imglist << ' ' - << image_index_.size() << " Image records"; + LOG(INFO) << "Loaded ImageList from " << path_imglist << ' ' << image_index_.size() + << " Image records"; } } /*! 
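The image-list records parsed manually by ImageLabelMap above are whitespace separated: an integer image index, label_width float labels, then the image path. A sketch of the same parse using iostreams instead of the pointer walk; ParseListLine is a hypothetical helper and error handling is reduced to a basic check:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Parse one .lst record: "<index> <label_0> ... <label_{w-1}> <path>" (tab or space separated).
bool ParseListLine(const std::string& line, std::size_t label_width,
                   std::size_t* index, std::vector<float>* labels, std::string* path) {
  std::istringstream in(line);
  if (!(in >> *index)) return false;
  labels->clear();
  for (std::size_t i = 0; i < label_width; ++i) {
    float v;
    if (!(in >> v)) return false;  // corresponds to the "Bad ImageList format" check above
    labels->push_back(v);
  }
  in >> *path;
  return true;
}

int main() {
  std::size_t idx;
  std::vector<float> labels;
  std::string path;
  if (ParseListLine("42\t1\t0.5\timg/000042.jpg", 2, &idx, &labels, &path)) {
    std::cout << "index " << idx << ", " << labels.size() << " labels, path " << path << "\n";
  }
}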
\brief find a label for corresponding index */ inline mshadow::Tensor Find(size_t imid) const { - std::unordered_map::const_iterator it - = idx2label_.find(imid); + std::unordered_map::const_iterator it = idx2label_.find(imid); CHECK(it != idx2label_.end()) << "fail to find imagelabel for id " << imid; return mshadow::Tensor(it->second, mshadow::Shape1(label_width)); } /*! \brief find a label for corresponding index, return vector as copy */ inline std::vector FindCopy(size_t imid) const { - std::unordered_map::const_iterator it - = idx2label_.find(imid); + std::unordered_map::const_iterator it = idx2label_.find(imid); CHECK(it != idx2label_.end()) << "fail to find imagelabel for id " << imid; - const real_t *ptr = it->second; + const real_t* ptr = it->second; return std::vector(ptr, ptr + label_width); } @@ -138,43 +136,57 @@ struct ImageRecParserParam : public dmlc::Parameter { // declare parameters DMLC_DECLARE_PARAMETER(ImageRecParserParam) { - DMLC_DECLARE_FIELD(path_imglist).set_default("") - .describe("Path to the image list (.lst) file. Generally created with tools/im2rec.py. "\ - "Format (Tab separated): "\ - "\t\t."); - DMLC_DECLARE_FIELD(path_imgrec).set_default("") - .describe("Path to the image RecordIO (.rec) file or a directory path. "\ - "Created with tools/im2rec.py."); - DMLC_DECLARE_FIELD(path_imgidx).set_default("") - .describe("Path to the image RecordIO index (.idx) file. "\ - "Created with tools/im2rec.py."); - DMLC_DECLARE_FIELD(aug_seq).set_default("aug_default") - .describe("The augmenter names to represent"\ - " sequence of augmenters to be applied, seperated by comma." \ - " Additional keyword parameters will be seen by these augmenters."); - DMLC_DECLARE_FIELD(label_width).set_lower_bound(1).set_default(1) + DMLC_DECLARE_FIELD(path_imglist) + .set_default("") + .describe( + "Path to the image list (.lst) file. Generally created with tools/im2rec.py. " + "Format (Tab separated): " + "\t\t."); + DMLC_DECLARE_FIELD(path_imgrec) + .set_default("") + .describe( + "Path to the image RecordIO (.rec) file or a directory path. " + "Created with tools/im2rec.py."); + DMLC_DECLARE_FIELD(path_imgidx) + .set_default("") + .describe( + "Path to the image RecordIO index (.idx) file. " + "Created with tools/im2rec.py."); + DMLC_DECLARE_FIELD(aug_seq) + .set_default("aug_default") + .describe( + "The augmenter names to represent" + " sequence of augmenters to be applied, seperated by comma." 
+ " Additional keyword parameters will be seen by these augmenters."); + DMLC_DECLARE_FIELD(label_width) + .set_lower_bound(1) + .set_default(1) .describe("The number of labels per image."); DMLC_DECLARE_FIELD(data_shape) - .set_expect_ndim(3).enforce_nonzero() + .set_expect_ndim(3) + .enforce_nonzero() .describe("The shape of one output image in (channels, height, width) format."); - DMLC_DECLARE_FIELD(preprocess_threads).set_lower_bound(1).set_default(4) + DMLC_DECLARE_FIELD(preprocess_threads) + .set_lower_bound(1) + .set_default(4) .describe("The number of threads to do preprocessing."); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("If or not output verbose information."); - DMLC_DECLARE_FIELD(num_parts).set_default(1) - .describe("Virtually partition the data into these many parts."); - DMLC_DECLARE_FIELD(part_index).set_default(0) + DMLC_DECLARE_FIELD(verbose).set_default(true).describe("If or not output verbose information."); + DMLC_DECLARE_FIELD(num_parts).set_default(1).describe( + "Virtually partition the data into these many parts."); + DMLC_DECLARE_FIELD(part_index) + .set_default(0) .describe("The *i*-th virtual partition to be read."); - DMLC_DECLARE_FIELD(device_id).set_default(0) - .describe("The device id used to create context for internal NDArray. "\ - "Setting device_id to -1 will create Context::CPU(0). Setting " - "device_id to valid positive device id will create " - "Context::CPUPinned(device_id). Default is 0."); - DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0) + DMLC_DECLARE_FIELD(device_id).set_default(0).describe( + "The device id used to create context for internal NDArray. " + "Setting device_id to -1 will create Context::CPU(0). Setting " + "device_id to valid positive device id will create " + "Context::CPUPinned(device_id). Default is 0."); + DMLC_DECLARE_FIELD(shuffle_chunk_size) + .set_default(0) .describe("The data shuffle buffer size in MB. Only valid if shuffle is true."); - DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0) - .describe("The random seed for shuffling"); - DMLC_DECLARE_FIELD(seed_aug).set_default(dmlc::optional()) + DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0).describe("The random seed for shuffling"); + DMLC_DECLARE_FIELD(seed_aug) + .set_default(dmlc::optional()) .describe("Random seed for augmentations."); } }; @@ -187,9 +199,9 @@ struct BatchParam : public dmlc::Parameter { bool round_batch; // declare parameters DMLC_DECLARE_PARAMETER(BatchParam) { - DMLC_DECLARE_FIELD(batch_size) - .describe("Batch size."); - DMLC_DECLARE_FIELD(round_batch).set_default(true) + DMLC_DECLARE_FIELD(batch_size).describe("Batch size."); + DMLC_DECLARE_FIELD(round_batch) + .set_default(true) .describe("Whether to use round robin to handle overflow batch or not."); } }; @@ -211,26 +223,27 @@ struct BatchSamplerParam : public dmlc::Parameter { int last_batch; // declare parameters DMLC_DECLARE_PARAMETER(BatchSamplerParam) { - DMLC_DECLARE_FIELD(batch_size) - .describe("Batch size."); - DMLC_DECLARE_FIELD(last_batch).set_default(kKeep) + DMLC_DECLARE_FIELD(batch_size).describe("Batch size."); + DMLC_DECLARE_FIELD(last_batch) + .set_default(kKeep) .add_enum("keep", kKeep) .add_enum("rollover", kRollOver) .add_enum("discard", kDiscard) - .describe("Specifies how the last batch is handled if batch_size does not evenly " - "divide sequence length. " - "If 'keep', the last batch will be returned directly, but will contain " - "less element than `batch_size` requires. " - "If 'discard', the last batch will be discarded. 
" - "If 'rollover', the remaining elements will be rolled over to the next " - "iteration. Note: legacy batch param with round_batch will always round data " - "in order to always provide full batchs. Rollover behavior will instead result " - "in different iteration sizes for each epoch."); + .describe( + "Specifies how the last batch is handled if batch_size does not evenly " + "divide sequence length. " + "If 'keep', the last batch will be returned directly, but will contain " + "less element than `batch_size` requires. " + "If 'discard', the last batch will be discarded. " + "If 'rollover', the remaining elements will be rolled over to the next " + "iteration. Note: legacy batch param with round_batch will always round data " + "in order to always provide full batchs. Rollover behavior will instead result " + "in different iteration sizes for each epoch."); } }; // Define image record parameters -struct ImageRecordParam: public dmlc::Parameter { +struct ImageRecordParam : public dmlc::Parameter { /*! \brief whether to do shuffle */ bool shuffle; /*! \brief random seed */ @@ -239,17 +252,16 @@ struct ImageRecordParam: public dmlc::Parameter { bool verbose; // declare parameters DMLC_DECLARE_PARAMETER(ImageRecordParam) { - DMLC_DECLARE_FIELD(shuffle).set_default(false) - .describe("Whether to shuffle data randomly or not."); - DMLC_DECLARE_FIELD(seed).set_default(0) - .describe("The random seed."); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("Whether to output verbose information or not."); + DMLC_DECLARE_FIELD(shuffle).set_default(false).describe( + "Whether to shuffle data randomly or not."); + DMLC_DECLARE_FIELD(seed).set_default(0).describe("The random seed."); + DMLC_DECLARE_FIELD(verbose).set_default(true).describe( + "Whether to output verbose information or not."); } }; // normalize parameters -struct ImageNormalizeParam : public dmlc::Parameter { +struct ImageNormalizeParam : public dmlc::Parameter { /*! \brief random seed */ int seed; /*! \brief whether to mirror the image */ @@ -284,48 +296,50 @@ struct ImageNormalizeParam : public dmlc::Parameter { bool verbose; // declare parameters DMLC_DECLARE_PARAMETER(ImageNormalizeParam) { - DMLC_DECLARE_FIELD(seed).set_default(0) - .describe("The random seed."); - DMLC_DECLARE_FIELD(mirror).set_default(false) - .describe("Whether to mirror the image or not. If true, images are "\ - "flipped along the horizontal axis."); - DMLC_DECLARE_FIELD(rand_mirror).set_default(false) - .describe("Whether to randomly mirror images or not. 
If true, 50% of "\ - "the images will be randomly mirrored (flipped along the "\ - "horizontal axis)"); - DMLC_DECLARE_FIELD(mean_img).set_default("") - .describe("Filename of the mean image."); - DMLC_DECLARE_FIELD(mean_r).set_default(0.0f) - .describe("The mean value to be subtracted on the R channel"); - DMLC_DECLARE_FIELD(mean_g).set_default(0.0f) - .describe("The mean value to be subtracted on the G channel"); - DMLC_DECLARE_FIELD(mean_b).set_default(0.0f) - .describe("The mean value to be subtracted on the B channel"); - DMLC_DECLARE_FIELD(mean_a).set_default(0.0f) - .describe("The mean value to be subtracted on the alpha channel"); - DMLC_DECLARE_FIELD(std_r).set_default(1.0f) - .describe("Augmentation Param: Standard deviation on R channel."); - DMLC_DECLARE_FIELD(std_g).set_default(1.0f) - .describe("Augmentation Param: Standard deviation on G channel."); - DMLC_DECLARE_FIELD(std_b).set_default(1.0f) - .describe("Augmentation Param: Standard deviation on B channel."); - DMLC_DECLARE_FIELD(std_a).set_default(1.0f) - .describe("Augmentation Param: Standard deviation on Alpha channel."); - DMLC_DECLARE_FIELD(scale).set_default(1.0f) - .describe("Multiply the image with a scale value."); - DMLC_DECLARE_FIELD(max_random_contrast).set_default(0.0f) - .describe("Change the contrast with a value randomly chosen from " - "``[-max_random_contrast, max_random_contrast]``"); - DMLC_DECLARE_FIELD(max_random_illumination).set_default(0.0f) - .describe("Change the illumination with a value randomly chosen from " - "``[-max_random_illumination, max_random_illumination]``"); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("If or not output verbose information."); + DMLC_DECLARE_FIELD(seed).set_default(0).describe("The random seed."); + DMLC_DECLARE_FIELD(mirror).set_default(false).describe( + "Whether to mirror the image or not. If true, images are " + "flipped along the horizontal axis."); + DMLC_DECLARE_FIELD(rand_mirror) + .set_default(false) + .describe( + "Whether to randomly mirror images or not. 
If true, 50% of " + "the images will be randomly mirrored (flipped along the " + "horizontal axis)"); + DMLC_DECLARE_FIELD(mean_img).set_default("").describe("Filename of the mean image."); + DMLC_DECLARE_FIELD(mean_r).set_default(0.0f).describe( + "The mean value to be subtracted on the R channel"); + DMLC_DECLARE_FIELD(mean_g).set_default(0.0f).describe( + "The mean value to be subtracted on the G channel"); + DMLC_DECLARE_FIELD(mean_b).set_default(0.0f).describe( + "The mean value to be subtracted on the B channel"); + DMLC_DECLARE_FIELD(mean_a).set_default(0.0f).describe( + "The mean value to be subtracted on the alpha channel"); + DMLC_DECLARE_FIELD(std_r).set_default(1.0f).describe( + "Augmentation Param: Standard deviation on R channel."); + DMLC_DECLARE_FIELD(std_g).set_default(1.0f).describe( + "Augmentation Param: Standard deviation on G channel."); + DMLC_DECLARE_FIELD(std_b).set_default(1.0f).describe( + "Augmentation Param: Standard deviation on B channel."); + DMLC_DECLARE_FIELD(std_a).set_default(1.0f).describe( + "Augmentation Param: Standard deviation on Alpha channel."); + DMLC_DECLARE_FIELD(scale).set_default(1.0f).describe("Multiply the image with a scale value."); + DMLC_DECLARE_FIELD(max_random_contrast) + .set_default(0.0f) + .describe( + "Change the contrast with a value randomly chosen from " + "``[-max_random_contrast, max_random_contrast]``"); + DMLC_DECLARE_FIELD(max_random_illumination) + .set_default(0.0f) + .describe( + "Change the illumination with a value randomly chosen from " + "``[-max_random_illumination, max_random_illumination]``"); + DMLC_DECLARE_FIELD(verbose).set_default(true).describe("If or not output verbose information."); } }; // normalize det parameters -struct ImageDetNormalizeParam : public dmlc::Parameter { +struct ImageDetNormalizeParam : public dmlc::Parameter { /*! \brief random seed */ int seed; /*! 
\brief mean file string */ @@ -352,36 +366,35 @@ struct ImageDetNormalizeParam : public dmlc::Parameter bool verbose; // declare parameters DMLC_DECLARE_PARAMETER(ImageDetNormalizeParam) { - DMLC_DECLARE_FIELD(seed).set_default(0) - .describe("Augmentation Param: Random Seed."); - DMLC_DECLARE_FIELD(mean_img).set_default("") - .describe("Augmentation Param: Mean Image to be subtracted."); - DMLC_DECLARE_FIELD(mean_r).set_default(0.0f) - .describe("Augmentation Param: Mean value on R channel."); - DMLC_DECLARE_FIELD(mean_g).set_default(0.0f) - .describe("Augmentation Param: Mean value on G channel."); - DMLC_DECLARE_FIELD(mean_b).set_default(0.0f) - .describe("Augmentation Param: Mean value on B channel."); - DMLC_DECLARE_FIELD(mean_a).set_default(0.0f) - .describe("Augmentation Param: Mean value on Alpha channel."); - DMLC_DECLARE_FIELD(std_r).set_default(0.0f) - .describe("Augmentation Param: Standard deviation on R channel."); - DMLC_DECLARE_FIELD(std_g).set_default(0.0f) - .describe("Augmentation Param: Standard deviation on G channel."); - DMLC_DECLARE_FIELD(std_b).set_default(0.0f) - .describe("Augmentation Param: Standard deviation on B channel."); - DMLC_DECLARE_FIELD(std_a).set_default(0.0f) - .describe("Augmentation Param: Standard deviation on Alpha channel."); - DMLC_DECLARE_FIELD(scale).set_default(1.0f) - .describe("Augmentation Param: Scale in color space."); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("Augmentation Param: Whether to print augmentor info."); + DMLC_DECLARE_FIELD(seed).set_default(0).describe("Augmentation Param: Random Seed."); + DMLC_DECLARE_FIELD(mean_img).set_default("").describe( + "Augmentation Param: Mean Image to be subtracted."); + DMLC_DECLARE_FIELD(mean_r).set_default(0.0f).describe( + "Augmentation Param: Mean value on R channel."); + DMLC_DECLARE_FIELD(mean_g).set_default(0.0f).describe( + "Augmentation Param: Mean value on G channel."); + DMLC_DECLARE_FIELD(mean_b).set_default(0.0f).describe( + "Augmentation Param: Mean value on B channel."); + DMLC_DECLARE_FIELD(mean_a).set_default(0.0f).describe( + "Augmentation Param: Mean value on Alpha channel."); + DMLC_DECLARE_FIELD(std_r).set_default(0.0f).describe( + "Augmentation Param: Standard deviation on R channel."); + DMLC_DECLARE_FIELD(std_g).set_default(0.0f).describe( + "Augmentation Param: Standard deviation on G channel."); + DMLC_DECLARE_FIELD(std_b).set_default(0.0f).describe( + "Augmentation Param: Standard deviation on B channel."); + DMLC_DECLARE_FIELD(std_a).set_default(0.0f).describe( + "Augmentation Param: Standard deviation on Alpha channel."); + DMLC_DECLARE_FIELD(scale).set_default(1.0f).describe( + "Augmentation Param: Scale in color space."); + DMLC_DECLARE_FIELD(verbose).set_default(true).describe( + "Augmentation Param: Whether to print augmentor info."); } }; // Define prefetcher parameters struct PrefetcherParam : public dmlc::Parameter { - enum CtxType { kGPU = 0, kCPU, kCPUPinned, kCPUShared}; + enum CtxType { kGPU = 0, kCPU, kCPUPinned, kCPUShared }; /*! 
\brief number of prefetched batches */ size_t prefetch_buffer; @@ -393,30 +406,33 @@ struct PrefetcherParam : public dmlc::Parameter { // declare parameters DMLC_DECLARE_PARAMETER(PrefetcherParam) { - DMLC_DECLARE_FIELD(prefetch_buffer).set_default(4) + DMLC_DECLARE_FIELD(prefetch_buffer) + .set_default(4) .describe("Maximum number of batches to prefetch."); - DMLC_DECLARE_FIELD(ctx).set_default(kGPU) + DMLC_DECLARE_FIELD(ctx) + .set_default(kGPU) .add_enum("cpu", kCPU) .add_enum("gpu", kGPU) .add_enum("cpu_pinned", kCPUPinned) - .describe("Context data loader optimized for. " - "Note that it only indicates the optimization strategy for devices, " - "by no means the prefetcher will load data to GPUs. " - "If ctx is 'cpu_pinned' and device_id is not -1, " - "it will use cpu_pinned(device_id) as ctx"); - DMLC_DECLARE_FIELD(device_id).set_default(-1) - .describe("The default device id for context. -1 indicate it's on default device"); + .describe( + "Context data loader optimized for. " + "Note that it only indicates the optimization strategy for devices, " + "by no means the prefetcher will load data to GPUs. " + "If ctx is 'cpu_pinned' and device_id is not -1, " + "it will use cpu_pinned(device_id) as ctx"); + DMLC_DECLARE_FIELD(device_id).set_default(-1).describe( + "The default device id for context. -1 indicate it's on default device"); DMLC_DECLARE_FIELD(dtype) - .add_enum("float32", mshadow::kFloat32) - .add_enum("float64", mshadow::kFloat64) - .add_enum("float16", mshadow::kFloat16) - .add_enum("bfloat16", mshadow::kBfloat16) - .add_enum("int64", mshadow::kInt64) - .add_enum("int32", mshadow::kInt32) - .add_enum("uint8", mshadow::kUint8) - .add_enum("int8", mshadow::kInt8) - .set_default(dmlc::optional()) - .describe("Output data type. ``None`` means no change."); + .add_enum("float32", mshadow::kFloat32) + .add_enum("float64", mshadow::kFloat64) + .add_enum("float16", mshadow::kFloat16) + .add_enum("bfloat16", mshadow::kBfloat16) + .add_enum("int64", mshadow::kInt64) + .add_enum("int32", mshadow::kInt32) + .add_enum("uint8", mshadow::kUint8) + .add_enum("int8", mshadow::kInt8) + .set_default(dmlc::optional()) + .describe("Output data type. ``None`` means no change."); } }; diff --git a/src/io/image_recordio.h b/src/io/image_recordio.h index 131bfda905f9..d2ee44779ea2 100644 --- a/src/io/image_recordio.h +++ b/src/io/image_recordio.h @@ -60,16 +60,15 @@ struct ImageRecordIO { /*! \brief header of image recordio */ Header header; /*! \brief point to label */ - float *label; + float* label; /*! \brief number of float labels */ int num_label; /*! \brief pointer to data content */ - uint8_t *content; + uint8_t* content; /*! \brief size of the content */ size_t content_size; /*! \brief constructor */ - ImageRecordIO(void) - : label(nullptr), num_label(0), content(nullptr), content_size(0) { + ImageRecordIO(void) : label(nullptr), num_label(0), content(nullptr), content_size(0) { memset(&header, 0, sizeof(header)); } /*! 
\brief get image id from record */ @@ -81,26 +80,26 @@ struct ImageRecordIO { * \param buf the head of record * \param size the size of the entire record */ - inline void Load(void *buf, size_t size) { + inline void Load(void* buf, size_t size) { CHECK(size >= sizeof(header)); std::memcpy(&header, buf, sizeof(header)); - content = reinterpret_cast(buf) + sizeof(header); + content = reinterpret_cast(buf) + sizeof(header); content_size = size - sizeof(header); if (header.flag > 0) { - CHECK(content_size >= sizeof(float)*header.flag); - label = reinterpret_cast(content); + CHECK(content_size >= sizeof(float) * header.flag); + label = reinterpret_cast(content); num_label = header.flag; - content = reinterpret_cast(label + header.flag); - content_size -= sizeof(float)*header.flag; + content = reinterpret_cast(label + header.flag); + content_size -= sizeof(float) * header.flag; } else { - label = nullptr; + label = nullptr; num_label = 0; } } /*! * \brief save the record header */ - inline void SaveHeader(std::string *blob) const { + inline void SaveHeader(std::string* blob) const { blob->resize(sizeof(header)); std::memcpy(dmlc::BeginPtr(*blob), &header, sizeof(header)); } diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h index 78630f3959f3..21d2fbed6a80 100644 --- a/src/io/inst_vector.h +++ b/src/io/inst_vector.h @@ -42,19 +42,19 @@ namespace io { * * data are stored in memory continuously */ -template +template class TensorVector { public: TensorVector(void) { this->Clear(); } /*! \brief get the buffer to the i-th tensor */ - inline mshadow::Tensor - operator[](size_t i) const { + inline mshadow::Tensor operator[](size_t i) const { CHECK_LT(i + 1, offset_.size()); CHECK_EQ(shape_[i].Size(), offset_[i + 1] - offset_[i]); - return mshadow::Tensor - ((DType*)dmlc::BeginPtr(content_) + offset_[i], shape_[i]); // NOLINT(*) + return mshadow::Tensor( + (DType*)dmlc::BeginPtr(content_) + offset_[i], // NOLINT(*) + shape_[i]); // NOLINT(*) } inline mshadow::Tensor Back() const { return (*this)[Size() - 1]; @@ -87,7 +87,7 @@ class TensorVector { /*! * \brief a list of (label, example) pairs, examples can have various shape */ -template +template class InstVector { public: /*! \brief return the number of (label, example) pairs */ @@ -124,9 +124,7 @@ class InstVector { * \brief push a (label, example) pair * only reserved the space, while the data is not copied */ - inline void Push(unsigned index, - mshadow::Shape<3> dshape, - mshadow::Shape<1> lshape) { + inline void Push(unsigned index, mshadow::Shape<3> dshape, mshadow::Shape<1> lshape) { index_.push_back(index); data_.Push(dshape); label_.Push(lshape); @@ -157,11 +155,12 @@ class InstVector { struct TBlobBatch { public: /*! \brief unique id for instance, can be NULL, sometimes is useful */ - unsigned *inst_index; + unsigned* inst_index; /*! \brief number of instance */ mshadow::index_t batch_size; /*! \brief number of padding elements in this batch, - this is used to indicate the last elements in the batch are only padded up to match the batch, and should be discarded */ + this is used to indicate the last elements in the batch are only padded up to match the + batch, and should be discarded */ mshadow::index_t num_batch_padd; /*! \brief content of dense data */ std::vector data; @@ -169,8 +168,9 @@ struct TBlobBatch { std::string extra_data; /*! \brief constructor */ TBlobBatch(void) { - inst_index = nullptr; - batch_size = 0; num_batch_padd = 0; + inst_index = nullptr; + batch_size = 0; + num_batch_padd = 0; } /*! 
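ImageRecordIO::Load above interprets a record as a fixed header, then header.flag float labels when flag > 0, then the encoded image bytes. A standalone sketch of that unpacking over a raw byte buffer; the MiniHeader struct is a simplified stand-in, not the real on-disk header layout:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Simplified stand-in for the record header; only 'flag' (the number of float labels
// that follow the header) matters for this sketch.
struct MiniHeader { uint32_t flag; };

struct MiniRecord {
  MiniHeader header;
  const float* label;      // points into the buffer, like the real Load()
  size_t num_label;
  const uint8_t* content;  // encoded image bytes
  size_t content_size;
};

bool LoadRecord(const uint8_t* buf, size_t size, MiniRecord* out) {
  if (size < sizeof(MiniHeader)) return false;
  std::memcpy(&out->header, buf, sizeof(MiniHeader));
  const uint8_t* p = buf + sizeof(MiniHeader);
  size_t remain = size - sizeof(MiniHeader);
  if (out->header.flag > 0) {                          // flag > 0: that many float labels come first
    if (remain < sizeof(float) * out->header.flag) return false;
    out->label = reinterpret_cast<const float*>(p);
    out->num_label = out->header.flag;
    p += sizeof(float) * out->header.flag;
    remain -= sizeof(float) * out->header.flag;
  } else {
    out->label = nullptr;
    out->num_label = 0;
  }
  out->content = p;
  out->content_size = remain;
  return true;
}

int main() {
  // Build a record with 2 labels followed by 3 bytes of (fake) image data.
  MiniHeader h{2};
  float labels[2] = {1.f, 0.25f};
  uint8_t img[3] = {0xFF, 0xD8, 0xFF};
  std::vector<uint8_t> buf(sizeof(h) + sizeof(labels) + sizeof(img));
  std::memcpy(buf.data(), &h, sizeof(h));
  std::memcpy(buf.data() + sizeof(h), labels, sizeof(labels));
  std::memcpy(buf.data() + sizeof(h) + sizeof(labels), img, sizeof(img));

  MiniRecord rec;
  if (LoadRecord(buf.data(), buf.size(), &rec)) {
    std::cout << rec.num_label << " labels, first = " << rec.label[0]
              << ", " << rec.content_size << " content bytes\n";  // 2 labels, first = 1, 3 bytes
  }
}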
\brief destructor */ ~TBlobBatch() { @@ -180,21 +180,20 @@ struct TBlobBatch { class TBlobContainer : public TBlob { public: - TBlobContainer(void) - : TBlob(), tensor_container_(nullptr) {} + TBlobContainer(void) : TBlob(), tensor_container_(nullptr) {} ~TBlobContainer() { if (tensor_container_) { release(); } } - void resize(const mxnet::TShape &shape, int type_flag) { + void resize(const mxnet::TShape& shape, int type_flag) { if (tensor_container_) { CHECK_EQ(this->type_flag_, type_flag); this->shape_ = shape; resize(); } else { this->type_flag_ = type_flag; - this->shape_ = shape; + this->shape_ = shape; create(); } } @@ -204,24 +203,22 @@ class TBlobContainer : public TBlob { CHECK(tensor_container_ == nullptr); CHECK_EQ(this->dev_mask(), mshadow::cpu::kDevMask); MSHADOW_TYPE_SWITCH(this->type_flag_, DType, { - auto tensor_container = new mshadow::TensorContainer(false); - tensor_container->Resize(mshadow::Shape1(shape_.Size())); - dptr_ = tensor_container->dptr_; - tensor_container_ = tensor_container; + auto tensor_container = new mshadow::TensorContainer(false); + tensor_container->Resize(mshadow::Shape1(shape_.Size())); + dptr_ = tensor_container->dptr_; + tensor_container_ = tensor_container; }); } void resize() { MSHADOW_TYPE_SWITCH(this->type_flag_, DType, { - auto tensor_container = - (mshadow::TensorContainer*) tensor_container_; - tensor_container->Resize(mshadow::Shape1(shape_.Size())); + auto tensor_container = (mshadow::TensorContainer*)tensor_container_; + tensor_container->Resize(mshadow::Shape1(shape_.Size())); }); } void release() { MSHADOW_TYPE_SWITCH(this->type_flag_, DType, { - auto tensor_container = - (mshadow::TensorContainer*) tensor_container_; - delete tensor_container; + auto tensor_container = (mshadow::TensorContainer*)tensor_container_; + delete tensor_container; }); } diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 7532033fefdd..00b0642cbfe1 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -41,9 +41,7 @@ namespace io { /*! 
\brief create a batch iterator from single instance iterator */ class BatchLoader : public IIterator { public: - explicit BatchLoader(IIterator *base): - head_(1), num_overflow_(0), base_(base) { - } + explicit BatchLoader(IIterator* base) : head_(1), num_overflow_(0), base_(base) {} virtual ~BatchLoader(void) { delete base_; @@ -73,15 +71,16 @@ class BatchLoader : public IIterator { virtual bool Next(void) { out_.num_batch_padd = 0; - out_.batch_size = param_.batch_size; - this->head_ = 0; + out_.batch_size = param_.batch_size; + this->head_ = 0; // if overflow from previous round, directly return false, until before first is called - if (num_overflow_ != 0) return false; + if (num_overflow_ != 0) + return false; size_t top = 0; while (base_->Next()) { - const DataInst& d = base_->Value(); + const DataInst& d = base_->Value(); out_.inst_index[top] = d.index; if (data_.size() == 0) { this->InitData(d); @@ -89,11 +88,10 @@ class BatchLoader : public IIterator { for (size_t i = 0; i < d.data.size(); ++i) { CHECK_EQ(unit_size_[i], d.data[i].Size()); MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { - mshadow::Copy( - data_[i].get().Slice(top * unit_size_[i], - (top + 1) * unit_size_[i]), + mshadow::Copy( + data_[i].get().Slice(top * unit_size_[i], (top + 1) * unit_size_[i]), d.data[i].get_with_shape(mshadow::Shape1(unit_size_[i]))); - }); + }); } if (++top >= param_.batch_size) { return true; @@ -105,17 +103,17 @@ class BatchLoader : public IIterator { base_->BeforeFirst(); for (; top < param_.batch_size; ++top, ++num_overflow_) { CHECK(base_->Next()) << "number of input must be bigger than batch size"; - const DataInst& d = base_->Value(); + const DataInst& d = base_->Value(); out_.inst_index[top] = d.index; // copy data for (size_t i = 0; i < d.data.size(); ++i) { CHECK_EQ(unit_size_[i], d.data[i].Size()); MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { - mshadow::Copy( + mshadow::Copy( data_[i].get().Slice(top * unit_size_[i], (top + 1) * unit_size_[i]), d.data[i].get_with_shape(mshadow::Shape1(unit_size_[i]))); - }); + }); } } out_.num_batch_padd = num_overflow_; @@ -126,7 +124,7 @@ class BatchLoader : public IIterator { } return false; } - virtual const TBlobBatch &Value(void) const { + virtual const TBlobBatch& Value(void) const { return out_; } @@ -144,7 +142,7 @@ class BatchLoader : public IIterator { private: /*! \brief base iterator */ - IIterator *base_; + IIterator* base_; /*! \brief data shape */ mxnet::ShapeVector shape_; /*! 
\brief unit size */ @@ -156,7 +154,7 @@ class BatchLoader : public IIterator { unit_size_.resize(first_batch.data.size()); for (size_t i = 0; i < first_batch.data.size(); ++i) { mxnet::TShape src_shape = first_batch.data[i].shape_; - int src_type_flag = first_batch.data[i].type_flag_; + int src_type_flag = first_batch.data[i].type_flag_; // init object attributes std::vector shape_vec; shape_vec.push_back(param_.batch_size); @@ -177,9 +175,7 @@ class BatchLoader : public IIterator { */ class BatchSampler : public IIterator { public: - explicit BatchSampler(IIterator *base): - num_overflow_(0), base_(base) { - } + explicit BatchSampler(IIterator* base) : num_overflow_(0), base_(base) {} virtual ~BatchSampler(void) { delete base_; @@ -214,7 +210,7 @@ class BatchSampler : public IIterator { return (base_hint + num_overflow_) / param_.batch_size; } else { LOG(FATAL) << "last_batch must be one of 'keep', 'discard', or 'rollover'" - << " but got: " << param_.last_batch; + << " but got: " << param_.last_batch; } return -1; } @@ -233,11 +229,10 @@ class BatchSampler : public IIterator { for (size_t i = 0; i < d.data.size(); ++i) { CHECK_EQ(unit_size_[i], d.data[i].Size()); MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { - mshadow::Copy( - data_[i].get().Slice(top * unit_size_[i], - (top + 1) * unit_size_[i]), + mshadow::Copy( + data_[i].get().Slice(top * unit_size_[i], (top + 1) * unit_size_[i]), d.data[i].get_with_shape(mshadow::Shape1(unit_size_[i]))); - }); + }); } if (++top >= param_.batch_size) { num_overflow_ = 0; @@ -251,7 +246,7 @@ class BatchSampler : public IIterator { return false; } else if (param_.last_batch == param_.kKeep) { out_.num_batch_padd = param_.batch_size - top; - num_overflow_ = 0; + num_overflow_ = 0; return true; } else if (param_.last_batch == param_.kRollOver) { if (num_overflow_ > 0) { @@ -268,7 +263,7 @@ class BatchSampler : public IIterator { } return false; } - virtual const DataBatch &Value(void) const { + virtual const DataBatch& Value(void) const { return out_; } @@ -284,7 +279,7 @@ class BatchSampler : public IIterator { private: /*! \brief base iterator */ - IIterator *base_; + IIterator* base_; /*! \brief data shape */ mxnet::ShapeVector shape_; /*! 
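The keep/discard/rollover handling in BatchSampler::Next above only matters when batch_size does not divide the epoch length. A small sketch of the resulting batch counts for each choice, matching the semantics described for last_batch (keep returns a short final batch, discard drops it, rollover carries the remainder into the next epoch); the function is hypothetical and only illustrates the arithmetic:

#include <iostream>
#include <string>

// How many batches one epoch yields under each last_batch policy, plus the leftover
// examples that a "rollover" policy would carry into the next epoch.
int BatchesPerEpoch(int num_examples, int batch_size, const std::string& last_batch,
                    int* leftover) {
  int full = num_examples / batch_size;
  int rem  = num_examples % batch_size;
  *leftover = 0;
  if (rem == 0)                return full;
  if (last_batch == "keep")    return full + 1;  // final batch is returned with fewer elements
  if (last_batch == "discard") return full;      // short final batch dropped
  *leftover = rem;                               // "rollover": remainder starts the next epoch
  return full;
}

int main() {
  int leftover = 0;
  // 10 examples, batch_size 3 -> keep: 4 batches, discard: 3, rollover: 3 (+1 example carried).
  std::cout << BatchesPerEpoch(10, 3, "keep", &leftover) << " "
            << BatchesPerEpoch(10, 3, "discard", &leftover) << " "
            << BatchesPerEpoch(10, 3, "rollover", &leftover) << " (+" << leftover << ")\n";
}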
\brief unit size */ @@ -296,7 +291,7 @@ class BatchSampler : public IIterator { unit_size_.resize(first_batch.data.size()); for (size_t i = 0; i < first_batch.data.size(); ++i) { mxnet::TShape src_shape = first_batch.data[i].shape_; - int src_type_flag = first_batch.data[i].type_flag_; + int src_type_flag = first_batch.data[i].type_flag_; // init object attributes std::vector shape_vec; shape_vec.push_back(param_.batch_size); @@ -307,8 +302,8 @@ class BatchSampler : public IIterator { shape_[i] = dst_shape; data_[i].resize(mshadow::Shape1(dst_shape.Size()), src_type_flag); unit_size_[i] = src_shape.Size(); - out_.data.push_back(NDArray(TBlob( - data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag, 0), 0)); + out_.data.push_back( + NDArray(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag, 0), 0)); } } }; // class BatchSampler diff --git a/src/io/iter_csv.cc b/src/io/iter_csv.cc index 87f295df544f..43560ffafc0c 100644 --- a/src/io/iter_csv.cc +++ b/src/io/iter_csv.cc @@ -44,20 +44,19 @@ struct CSVIterParam : public dmlc::Parameter { mxnet::TShape label_shape; // declare parameters DMLC_DECLARE_PARAMETER(CSVIterParam) { - DMLC_DECLARE_FIELD(data_csv) - .describe("The input CSV file or a directory path."); - DMLC_DECLARE_FIELD(data_shape) - .describe("The shape of one example."); - DMLC_DECLARE_FIELD(label_csv).set_default("NULL") - .describe("The input CSV file or a directory path. " - "If NULL, all labels will be returned as 0."); + DMLC_DECLARE_FIELD(data_csv).describe("The input CSV file or a directory path."); + DMLC_DECLARE_FIELD(data_shape).describe("The shape of one example."); + DMLC_DECLARE_FIELD(label_csv).set_default("NULL").describe( + "The input CSV file or a directory path. " + "If NULL, all labels will be returned as 0."); index_t shape1[] = {1}; - DMLC_DECLARE_FIELD(label_shape).set_default(mxnet::TShape(shape1, shape1 + 1)) + DMLC_DECLARE_FIELD(label_shape) + .set_default(mxnet::TShape(shape1, shape1 + 1)) .describe("The shape of one label."); } }; -class CSVIterBase: public IIterator { +class CSVIterBase : public IIterator { public: CSVIterBase() { out_.data.resize(2); @@ -71,7 +70,7 @@ class CSVIterBase: public IIterator { /*! \brief move to next item */ bool Next() override = 0; /*! 
\brief get current data */ - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } @@ -91,7 +90,7 @@ class CSVIterBase: public IIterator { }; template -class CSVIterTyped: public CSVIterBase { +class CSVIterTyped : public CSVIterBase { public: ~CSVIterTyped() override = default; // intialize iterator loads data in @@ -100,7 +99,7 @@ class CSVIterTyped: public CSVIterBase { data_parser_.reset(dmlc::Parser::Create(param_.data_csv.c_str(), 0, 1, "csv")); if (param_.label_csv != "NULL") { label_parser_.reset( - dmlc::Parser::Create(param_.label_csv.c_str(), 0, 1, "csv")); + dmlc::Parser::Create(param_.label_csv.c_str(), 0, 1, "csv")); } else { dummy_label.set_pad(false); dummy_label.Resize(mshadow::Shape1(1)); @@ -115,17 +114,19 @@ class CSVIterTyped: public CSVIterBase { } data_ptr_ = label_ptr_ = 0; data_size_ = label_size_ = 0; - inst_counter_ = 0; - end_ = false; + inst_counter_ = 0; + end_ = false; } bool Next() override { - if (end_) return false; + if (end_) + return false; while (data_ptr_ >= data_size_) { if (!data_parser_->Next()) { - end_ = true; return false; + end_ = true; + return false; } - data_ptr_ = 0; + data_ptr_ = 0; data_size_ = data_parser_->Value().size; } out_.index = inst_counter_++; @@ -136,7 +137,7 @@ class CSVIterTyped: public CSVIterBase { while (label_ptr_ >= label_size_) { CHECK(label_parser_->Next()) << "Data CSV's row is smaller than the number of rows in label_csv"; - label_ptr_ = 0; + label_ptr_ = 0; label_size_ = label_parser_->Value().size; } CHECK_LT(label_ptr_, label_size_); @@ -161,16 +162,16 @@ class CSVIterTyped: public CSVIterBase { std::unique_ptr > data_parser_; }; -class CSVIter: public IIterator { +class CSVIter : public IIterator { public: - CSVIter() = default; + CSVIter() = default; ~CSVIter() override = default; // intialize iterator loads data in void Init(const std::vector >& kwargs) override { param_.InitAllowUnknown(kwargs); bool dtype_has_value = false; - int target_dtype = -1; + int target_dtype = -1; for (const auto& arg : kwargs) { if (arg.first == "dtype") { dtype_has_value = true; @@ -203,7 +204,7 @@ class CSVIter: public IIterator { return iterator_->Next(); } - const DataInst &Value() const override { + const DataInst& Value() const override { return iterator_->Value(); } @@ -212,11 +213,10 @@ class CSVIter: public IIterator { std::unique_ptr iterator_; }; - DMLC_REGISTER_PARAMETER(CSVIterParam); MXNET_REGISTER_IO_ITER(CSVIter) -.describe(R"code(Returns the CSV file iterator. + .describe(R"code(Returns the CSV file iterator. In this function, the `data_shape` parameter is used to set the shape of each line of the input data. 
If a row in an input file is `1,2,3,4,5,6`` and `data_shape` is (3,2), that row @@ -306,14 +306,10 @@ Examples:: [3 4 5]] )code" ADD_FILELINE) -.add_arguments(CSVIterParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.set_body([]() { - return new PrefetcherIter( - new BatchLoader( - new CSVIter())); - }); + .add_arguments(CSVIterParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .set_body([]() { return new PrefetcherIter(new BatchLoader(new CSVIter())); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_image_det_recordio.cc b/src/io/iter_image_det_recordio.cc index 3fe0ec7f3e17..a6705cc09631 100644 --- a/src/io/iter_image_det_recordio.cc +++ b/src/io/iter_image_det_recordio.cc @@ -56,69 +56,75 @@ class ImageDetLabelMap { * \param path_imglist path to the image list * \param label_width predefined label_width, -1 for arbitrary width */ - explicit ImageDetLabelMap(const char *path_imglist, - int label_width, - bool silent) { + explicit ImageDetLabelMap(const char* path_imglist, int label_width, bool silent) { image_index_.clear(); label_.clear(); idx2label_.clear(); - dmlc::InputSplit *fi = dmlc::InputSplit::Create - (path_imglist, 0, 1, "text"); + dmlc::InputSplit* fi = dmlc::InputSplit::Create(path_imglist, 0, 1, "text"); dmlc::InputSplit::Blob rec; while (fi->NextRecord(&rec)) { // quick manual parsing - char *p = reinterpret_cast(rec.dptr); - char *end = p + rec.size; + char* p = reinterpret_cast(rec.dptr); + char* end = p + rec.size; // skip space - while (isspace(*p) && p != end) ++p; + while (isspace(*p) && p != end) + ++p; image_index_.push_back(static_cast(atol(p))); size_t start_pos = label_.size(); if (label_width > 0) { // provided label_width > 0, require width check for (int i = 0; i < label_width; ++i) { // skip till space - while (!isspace(*p) && p != end) ++p; + while (!isspace(*p) && p != end) + ++p; // skip space - while (isspace(*p) && p != end) ++p; + while (isspace(*p) && p != end) + ++p; CHECK(p != end) << "Bad ImageList format"; label_.push_back(static_cast(atof(p))); } CHECK_EQ(label_.size() - start_pos, label_width); } else { // arbitrary label width for each sample - while (!isspace(*p) && p != end) ++p; - while (isspace(*p) && p != end) ++p; - char *curr = p; + while (!isspace(*p) && p != end) + ++p; + while (isspace(*p) && p != end) + ++p; + char* curr = p; CHECK(curr != end) << "Bad ImageList format"; - while (!isspace(*p) && p != end) ++p; - while (isspace(*p) && p != end) ++p; - char *next = p; + while (!isspace(*p) && p != end) + ++p; + while (isspace(*p) && p != end) + ++p; + char* next = p; while (next != end) { label_.push_back(static_cast(atof(curr))); curr = next; - while (!isspace(*next) && next != end) ++next; - while (isspace(*next) && next != end) ++next; + while (!isspace(*next) && next != end) + ++next; + while (isspace(*next) && next != end) + ++next; } // skip the last one which should be the image_path CHECK_GT(label_.size(), start_pos) << "Bad ImageList format: empty label"; } // record label start_pos and width in map - idx2label_[image_index_.back()] = std::pair( - start_pos, label_.size() - start_pos); + idx2label_[image_index_.back()] = + std::pair(start_pos, label_.size() - start_pos); } delete fi; if (!silent) { - LOG(INFO) << "Loaded ImageList from " << path_imglist << ' ' - << image_index_.size() << " Image records"; + LOG(INFO) << "Loaded ImageList from " << path_imglist << ' ' << 
image_index_.size() + << " Image records"; } } /*! \brief find a label for corresponding index, return vector as copy */ inline std::vector FindCopy(size_t imid) const { - std::unordered_map >::const_iterator it - = idx2label_.find(imid); + std::unordered_map>::const_iterator it = + idx2label_.find(imid); CHECK(it != idx2label_.end()) << "fail to find imagelabel for id " << imid; - const real_t *ptr = dmlc::BeginPtr(label_) + it->second.first; + const real_t* ptr = dmlc::BeginPtr(label_) + it->second.first; return std::vector(ptr, ptr + it->second.second); } @@ -127,7 +133,8 @@ class ImageDetLabelMap { size_t max_width = 0; for (auto i : idx2label_) { size_t width = i.second.second; - if (width > max_width) max_width = width; + if (width > max_width) + max_width = width; } return max_width; } @@ -138,7 +145,7 @@ class ImageDetLabelMap { /*! \brief vectors storing raw labels in 1D */ std::vector label_; /*! \brief map storing image index to pair */ - std::unordered_map > idx2label_; + std::unordered_map> idx2label_; }; // class ImageDetLabelMap // Define image record parser parameters @@ -172,46 +179,53 @@ struct ImageDetRecParserParam : public dmlc::Parameter { // declare parameters DMLC_DECLARE_PARAMETER(ImageDetRecParserParam) { - DMLC_DECLARE_FIELD(path_imglist).set_default("") - .describe("Dataset Param: Path to image list."); - DMLC_DECLARE_FIELD(path_imgrec).set_default("./data/imgrec.rec") + DMLC_DECLARE_FIELD(path_imglist).set_default("").describe("Dataset Param: Path to image list."); + DMLC_DECLARE_FIELD(path_imgrec) + .set_default("./data/imgrec.rec") .describe("Dataset Param: Path to image record file."); - DMLC_DECLARE_FIELD(aug_seq).set_default("det_aug_default") - .describe("Augmentation Param: the augmenter names to represent"\ - " sequence of augmenters to be applied, seperated by comma." \ - " Additional keyword parameters will be seen by these augmenters." - " Make sure you don't use normal augmenters for detection tasks."); - DMLC_DECLARE_FIELD(label_width).set_default(-1) + DMLC_DECLARE_FIELD(aug_seq) + .set_default("det_aug_default") + .describe( + "Augmentation Param: the augmenter names to represent" + " sequence of augmenters to be applied, seperated by comma." + " Additional keyword parameters will be seen by these augmenters." 
+ " Make sure you don't use normal augmenters for detection tasks."); + DMLC_DECLARE_FIELD(label_width) + .set_default(-1) .describe("Dataset Param: How many labels for an image, -1 for variable label size."); DMLC_DECLARE_FIELD(data_shape) - .set_expect_ndim(3).enforce_nonzero() + .set_expect_ndim(3) + .enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); - DMLC_DECLARE_FIELD(preprocess_threads).set_lower_bound(1).set_default(4) + DMLC_DECLARE_FIELD(preprocess_threads) + .set_lower_bound(1) + .set_default(4) .describe("Backend Param: Number of thread to do preprocessing."); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("Auxiliary Param: Whether to output parser information."); - DMLC_DECLARE_FIELD(num_parts).set_default(1) - .describe("partition the data into multiple parts"); - DMLC_DECLARE_FIELD(part_index).set_default(0) - .describe("the index of the part will read"); - DMLC_DECLARE_FIELD(shuffle_chunk_size).set_default(0) - .describe("the size(MB) of the shuffle chunk, used with shuffle=True,"\ - " it can enable global shuffling"); - DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0) - .describe("the seed for chunk shuffling"); - DMLC_DECLARE_FIELD(label_pad_width).set_default(0) + DMLC_DECLARE_FIELD(verbose).set_default(true).describe( + "Auxiliary Param: Whether to output parser information."); + DMLC_DECLARE_FIELD(num_parts).set_default(1).describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0).describe("the index of the part will read"); + DMLC_DECLARE_FIELD(shuffle_chunk_size) + .set_default(0) + .describe( + "the size(MB) of the shuffle chunk, used with shuffle=True," + " it can enable global shuffling"); + DMLC_DECLARE_FIELD(shuffle_chunk_seed).set_default(0).describe("the seed for chunk shuffling"); + DMLC_DECLARE_FIELD(label_pad_width) + .set_default(0) .describe("pad output label width if set larger than 0, -1 for auto estimate"); - DMLC_DECLARE_FIELD(label_pad_value).set_default(-1.f) + DMLC_DECLARE_FIELD(label_pad_value) + .set_default(-1.f) .describe("label padding value if enabled"); } }; // parser to parse image recordio -template +template class ImageDetRecordIOParser { public: // initialize the parser - inline void Init(const std::vector >& kwargs); + inline void Init(const std::vector>& kwargs); // set record to the head inline void BeforeFirst() { @@ -219,19 +233,19 @@ class ImageDetRecordIOParser { } // parse next set of records, return an array of // instance vector to the user - virtual inline bool ParseNext(std::vector> *out); + virtual inline bool ParseNext(std::vector>* out); protected: // magic number to see prng static const int kRandMagic = 233; /*! \brief parameters */ ImageDetRecParserParam param_; - #if MXNET_USE_OPENCV +#if MXNET_USE_OPENCV /*! \brief augmenters */ - std::vector > > augmenters_; - #endif + std::vector>> augmenters_; +#endif /*! \brief random samplers */ - std::vector > prnds_; + std::vector> prnds_; /*! \brief data source */ std::unique_ptr source_; /*! 
\brief label information, if any */ @@ -244,22 +258,20 @@ class ImageDetRecordIOParser { template inline void ImageDetRecordIOParser::Init( - const std::vector >& kwargs) { + const std::vector>& kwargs) { #if MXNET_USE_OPENCV // initialize parameter // init image rec param param_.InitAllowUnknown(kwargs); int maxthread, threadget; - #pragma omp parallel +#pragma omp parallel { // be conservative, set number of real cores - 1 maxthread = std::max(omp_get_num_procs() - 1, 1); } param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads); - #pragma omp parallel num_threads(param_.preprocess_threads) - { - threadget = omp_get_num_threads(); - } +#pragma omp parallel num_threads(param_.preprocess_threads) + { threadget = omp_get_num_threads(); } param_.preprocess_threads = threadget; std::vector aug_names = dmlc::Split(param_.aug_seq, ','); @@ -274,20 +286,17 @@ inline void ImageDetRecordIOParser::Init( prnds_.emplace_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic)); } if (param_.path_imglist.length() != 0) { - label_map_ = std::make_unique(param_.path_imglist.c_str(), - param_.label_width, !param_.verbose); + label_map_ = std::make_unique( + param_.path_imglist.c_str(), param_.label_width, !param_.verbose); } - CHECK(param_.path_imgrec.length() != 0) - << "ImageDetRecordIOIterator: must specify image_rec"; + CHECK(param_.path_imgrec.length() != 0) << "ImageDetRecordIOIterator: must specify image_rec"; if (param_.verbose) { - LOG(INFO) << "ImageDetRecordIOParser: " << param_.path_imgrec - << ", use " << threadget << " threads for decoding.."; + LOG(INFO) << "ImageDetRecordIOParser: " << param_.path_imgrec << ", use " << threadget + << " threads for decoding.."; } source_.reset(dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), - param_.part_index, param_.num_parts, - "recordio")); + param_.path_imgrec.c_str(), param_.part_index, param_.num_parts, "recordio")); // estimate padding width for labels int max_label_width = 0; @@ -297,23 +306,23 @@ inline void ImageDetRecordIOParser::Init( // iterate through recordio dmlc::InputSplit::Blob chunk; while (source_->NextChunk(&chunk)) { - #pragma omp parallel num_threads(param_.preprocess_threads) +#pragma omp parallel num_threads(param_.preprocess_threads) { omp_exc_.Run([&] { CHECK(omp_get_num_threads() == param_.preprocess_threads); int max_width = 0; - int tid = omp_get_thread_num(); - dmlc::RecordIOChunkReader reader(chunk, tid, - param_.preprocess_threads); + int tid = omp_get_thread_num(); + dmlc::RecordIOChunkReader reader(chunk, tid, param_.preprocess_threads); ImageRecordIO rec; dmlc::InputSplit::Blob blob; while (reader.NextRecord(&blob)) { rec.Load(blob.dptr, blob.size); if (rec.label != nullptr) { if (param_.label_width > 0) { - CHECK_EQ(param_.label_width, rec.num_label) - << "rec file provide " << rec.num_label << "-dimensional label " - "but label_width is set to " << param_.label_width; + CHECK_EQ(param_.label_width, rec.num_label) << "rec file provide " << rec.num_label + << "-dimensional label " + "but label_width is set to " + << param_.label_width; } // update max value max_width = std::max(max_width, rec.num_label); @@ -321,10 +330,8 @@ inline void ImageDetRecordIOParser::Init( LOG(FATAL) << "Not enough label packed in img_list or rec file."; } } - #pragma omp critical - { - max_label_width = std::max(max_label_width, max_width); - } +#pragma omp critical + { max_label_width = std::max(max_label_width, max_width); } }); } omp_exc_.Rethrow(); @@ -332,8 +339,8 @@ inline void ImageDetRecordIOParser::Init( } if 
(max_label_width > param_.label_pad_width) { if (param_.label_pad_width > 0) { - LOG(FATAL) << "ImageDetRecordIOParser: label_pad_width: " - << param_.label_pad_width << " smaller than estimated width: " << max_label_width; + LOG(FATAL) << "ImageDetRecordIOParser: label_pad_width: " << param_.label_pad_width + << " smaller than estimated width: " << max_label_width; } param_.label_pad_width = max_label_width; } @@ -343,9 +350,7 @@ inline void ImageDetRecordIOParser::Init( } source_.reset(dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), - param_.part_index, param_.num_parts, - "recordio")); + param_.path_imgrec.c_str(), param_.part_index, param_.num_parts, "recordio")); if (param_.shuffle_chunk_size > 0) { if (param_.shuffle_chunk_size > 4096) { @@ -359,14 +364,16 @@ inline void ImageDetRecordIOParser::Init( "larger chunk size"; } // 1.1 ratio is for a bit more shuffle parts to avoid boundary issue - unsigned num_shuffle_parts = - std::ceil(source_->GetTotalSize() * 1.1 / - (param_.num_parts * (param_.shuffle_chunk_size << 20UL))); + unsigned num_shuffle_parts = std::ceil( + source_->GetTotalSize() * 1.1 / (param_.num_parts * (param_.shuffle_chunk_size << 20UL))); if (num_shuffle_parts > 1) { - source_.reset(dmlc::InputSplitShuffle::Create( - param_.path_imgrec.c_str(), param_.part_index, - param_.num_parts, "recordio", num_shuffle_parts, param_.shuffle_chunk_seed)); + source_.reset(dmlc::InputSplitShuffle::Create(param_.path_imgrec.c_str(), + param_.part_index, + param_.num_parts, + "recordio", + num_shuffle_parts, + param_.shuffle_chunk_seed)); } source_->HintChunkSize(param_.shuffle_chunk_size << 17UL); } else { @@ -378,16 +385,16 @@ inline void ImageDetRecordIOParser::Init( #endif } -template -inline bool ImageDetRecordIOParser:: -ParseNext(std::vector> *out_vec) { +template +inline bool ImageDetRecordIOParser::ParseNext(std::vector>* out_vec) { CHECK(source_ != nullptr); dmlc::InputSplit::Blob chunk; - if (!source_->NextChunk(&chunk)) return false; + if (!source_->NextChunk(&chunk)) + return false; #if MXNET_USE_OPENCV // save opencv out out_vec->resize(param_.preprocess_threads); - #pragma omp parallel num_threads(param_.preprocess_threads) +#pragma omp parallel num_threads(param_.preprocess_threads) { omp_exc_.Run([&] { CHECK(omp_get_num_threads() == param_.preprocess_threads); @@ -396,7 +403,7 @@ ParseNext(std::vector> *out_vec) { ImageRecordIO rec; dmlc::InputSplit::Blob blob; // image data - InstVector &out = (*out_vec)[tid]; + InstVector& out = (*out_vec)[tid]; out.Clear(); while (reader.NextRecord(&blob)) { // Opencv decode and augments @@ -414,9 +421,8 @@ ParseNext(std::vector> *out_vec) { // -1 to keep the number of channel of the encoded image, and not // force gray or color. res = cv::imdecode(buf, -1); - CHECK_EQ(res.channels(), 4) - << "Invalid image with index " << rec.image_index() - << ". Expected 4 channels, got " << res.channels(); + CHECK_EQ(res.channels(), 4) << "Invalid image with index " << rec.image_index() + << ". 
Expected 4 channels, got " << res.channels(); break; default: LOG(FATAL) << "Invalid output shape " << param_.data_shape; @@ -428,22 +434,20 @@ ParseNext(std::vector> *out_vec) { label_buf = label_map_->FindCopy(rec.image_index()); } else if (rec.label != nullptr) { if (param_.label_width > 0) { - CHECK_EQ(param_.label_width, rec.num_label) - << "rec file provide " << rec.num_label - << "-dimensional label " - "but label_width is set to " - << param_.label_width; + CHECK_EQ(param_.label_width, rec.num_label) << "rec file provide " << rec.num_label + << "-dimensional label " + "but label_width is set to " + << param_.label_width; } label_buf.assign(rec.label, rec.label + rec.num_label); } else { LOG(FATAL) << "Not enough label packed in img_list or rec file."; } - for (auto &aug : this->augmenters_[tid]) { + for (auto& aug : this->augmenters_[tid]) { res = aug->Process(res, &label_buf, this->prnds_[tid].get()); } out.Push(static_cast(rec.image_index()), - mshadow::Shape3(n_channels, param_.data_shape[1], - param_.data_shape[2]), + mshadow::Shape3(n_channels, param_.data_shape[1], param_.data_shape[2]), mshadow::Shape1(param_.label_pad_width + 4)); mshadow::Tensor data = out.data().Back(); @@ -451,12 +455,15 @@ ParseNext(std::vector> *out_vec) { // For RGB or RGBA data, swap the B and R channel: // OpenCV store as BGR (or BGRA) and we want RGB (or RGBA) std::vector swap_indices; - if (n_channels == 1) swap_indices = {0}; - if (n_channels == 3) swap_indices = {2, 1, 0}; - if (n_channels == 4) swap_indices = {2, 1, 0, 3}; + if (n_channels == 1) + swap_indices = {0}; + if (n_channels == 3) + swap_indices = {2, 1, 0}; + if (n_channels == 4) + swap_indices = {2, 1, 0, 3}; for (int i = 0; i < res.rows; ++i) { - uchar *im_data = res.ptr(i); + uchar* im_data = res.ptr(i); for (int j = 0; j < res.cols; ++j) { for (int k = 0; k < n_channels; ++k) { data[k][i][j] = im_data[swap_indices[k]]; @@ -465,7 +472,7 @@ ParseNext(std::vector> *out_vec) { } } mshadow::Tensor label = out.label().Back(); - label = param_.label_pad_value; + label = param_.label_pad_value; // store info for real data_shape and label_width label[0] = res.channels(); label[1] = res.rows; @@ -473,21 +480,20 @@ ParseNext(std::vector> *out_vec) { label[3] = label_buf.size(); mshadow::Copy( label.Slice(4, 4 + label_buf.size()), - mshadow::Tensor(dmlc::BeginPtr(label_buf), - mshadow::Shape1(label_buf.size()))); + mshadow::Tensor(dmlc::BeginPtr(label_buf), mshadow::Shape1(label_buf.size()))); res.release(); } }); } #else - LOG(FATAL) << "Opencv is needed for image decoding and augmenting."; + LOG(FATAL) << "Opencv is needed for image decoding and augmenting."; #endif omp_exc_.Rethrow(); return true; } // Define image record parameters -struct ImageDetRecordParam: public dmlc::Parameter { +struct ImageDetRecordParam : public dmlc::Parameter { /*! \brief whether to do shuffle */ bool shuffle; /*! 
\brief random seed */ @@ -496,40 +502,40 @@ struct ImageDetRecordParam: public dmlc::Parameter { bool verbose; // declare parameters DMLC_DECLARE_PARAMETER(ImageDetRecordParam) { - DMLC_DECLARE_FIELD(shuffle).set_default(false) - .describe("Augmentation Param: Whether to shuffle data."); - DMLC_DECLARE_FIELD(seed).set_default(0) - .describe("Augmentation Param: Random Seed."); - DMLC_DECLARE_FIELD(verbose).set_default(true) - .describe("Auxiliary Param: Whether to output information."); + DMLC_DECLARE_FIELD(shuffle).set_default(false).describe( + "Augmentation Param: Whether to shuffle data."); + DMLC_DECLARE_FIELD(seed).set_default(0).describe("Augmentation Param: Random Seed."); + DMLC_DECLARE_FIELD(verbose).set_default(true).describe( + "Auxiliary Param: Whether to output information."); } }; // iterator on image recordio -template +template class ImageDetRecordIter : public IIterator { public: - ImageDetRecordIter() : data_(nullptr) { } + ImageDetRecordIter() : data_(nullptr) {} // destructor ~ImageDetRecordIter() override { iter_.Destroy(); delete data_; } // constructor - void Init(const std::vector >& kwargs) override { + void Init(const std::vector>& kwargs) override { param_.InitAllowUnknown(kwargs); // use the kwarg to init parser parser_.Init(kwargs); // prefetch at most 4 minbatches iter_.set_max_capacity(4); // init thread iter - iter_.Init([this](std::vector> **dptr) { - if (*dptr == nullptr) { - *dptr = new std::vector>(); - } - return parser_.ParseNext(*dptr); - }, - [this]() { parser_.BeforeFirst(); }); + iter_.Init( + [this](std::vector>** dptr) { + if (*dptr == nullptr) { + *dptr = new std::vector>(); + } + return parser_.ParseNext(*dptr); + }, + [this]() { parser_.BeforeFirst(); }); inst_ptr_ = 0; rnd_.seed(kRandMagic + param_.seed); } @@ -544,12 +550,14 @@ class ImageDetRecordIter : public IIterator { while (true) { if (inst_ptr_ < inst_order_.size()) { std::pair p = inst_order_[inst_ptr_]; - out_ = (*data_)[p.first][p.second]; + out_ = (*data_)[p.first][p.second]; ++inst_ptr_; return true; } else { - if (data_ != nullptr) iter_.Recycle(&data_); - if (!iter_.Next(&data_)) return false; + if (data_ != nullptr) + iter_.Recycle(&data_); + if (!iter_.Next(&data_)) + return false; inst_order_.clear(); for (unsigned i = 0; i < data_->size(); ++i) { const InstVector& tmp = (*data_)[i]; @@ -567,7 +575,7 @@ class ImageDetRecordIter : public IIterator { return false; } - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } @@ -579,13 +587,13 @@ class ImageDetRecordIter : public IIterator { // data ptr size_t inst_ptr_; // internal instance order - std::vector > inst_order_; + std::vector> inst_order_; // data - std::vector> *data_; + std::vector>* data_; // internal parser ImageDetRecordIOParser parser_; // backend thread - dmlc::ThreadedIter> > iter_; + dmlc::ThreadedIter>> iter_; // parameters ImageDetRecordParam param_; // random number generator @@ -596,18 +604,16 @@ DMLC_REGISTER_PARAMETER(ImageDetRecParserParam); DMLC_REGISTER_PARAMETER(ImageDetRecordParam); MXNET_REGISTER_IO_ITER(ImageDetRecordIter) -.describe("Create iterator for image detection dataset packed in recordio.") -.add_arguments(ImageDetRecParserParam::__FIELDS__()) -.add_arguments(ImageDetRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultDetAugParams()) -.add_arguments(ImageDetNormalizeParam::__FIELDS__()) -.set_body([]() { - return new PrefetcherIter( - new BatchLoader( 
- new ImageDetNormalizeIter( - new ImageDetRecordIter()))); -}); + .describe("Create iterator for image detection dataset packed in recordio.") + .add_arguments(ImageDetRecParserParam::__FIELDS__()) + .add_arguments(ImageDetRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultDetAugParams()) + .add_arguments(ImageDetNormalizeParam::__FIELDS__()) + .set_body([]() { + return new PrefetcherIter( + new BatchLoader(new ImageDetNormalizeIter(new ImageDetRecordIter()))); + }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index 23008050ec28..41955b2a34be 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -47,11 +47,11 @@ namespace mxnet { namespace io { // parser to parse image recordio -template +template class ImageRecordIOParser { public: // initialize the parser - inline void Init(const std::vector >& kwargs); + inline void Init(const std::vector>& kwargs); // set record to the head inline void BeforeFirst() { @@ -59,19 +59,19 @@ class ImageRecordIOParser { } // parse next set of records, return an array of // instance vector to the user - inline bool ParseNext(std::vector> *out); + inline bool ParseNext(std::vector>* out); private: // magic number to see prng static const int kRandMagic = 111; /*! \brief parameters */ ImageRecParserParam param_; - #if MXNET_USE_OPENCV +#if MXNET_USE_OPENCV /*! \brief augmenters */ - std::vector > > augmenters_; - #endif + std::vector>> augmenters_; +#endif /*! \brief random samplers */ - std::vector > prnds_; + std::vector> prnds_; /*! \brief data source */ std::unique_ptr source_; /*! \brief label information, if any */ @@ -80,24 +80,22 @@ class ImageRecordIOParser { mshadow::TensorContainer img_; }; -template +template inline void ImageRecordIOParser::Init( - const std::vector >& kwargs) { + const std::vector>& kwargs) { #if MXNET_USE_OPENCV // initialize parameter // init image rec param param_.InitAllowUnknown(kwargs); int maxthread, threadget; - #pragma omp parallel +#pragma omp parallel { // be conservative, set number of real cores maxthread = std::max(omp_get_num_procs() / 2 - 1, 1); } param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads); - #pragma omp parallel num_threads(param_.preprocess_threads) - { - threadget = omp_get_num_threads(); - } +#pragma omp parallel num_threads(param_.preprocess_threads) + { threadget = omp_get_num_threads(); } param_.preprocess_threads = threadget; std::vector aug_names = dmlc::Split(param_.aug_seq, ','); @@ -112,39 +110,39 @@ inline void ImageRecordIOParser::Init( prnds_.emplace_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic)); } if (param_.path_imglist.length() != 0) { - label_map_ = std::make_unique(param_.path_imglist.c_str(), - param_.label_width, !param_.verbose); + label_map_ = std::make_unique( + param_.path_imglist.c_str(), param_.label_width, !param_.verbose); } - CHECK(param_.path_imgrec.length() != 0) - << "ImageRecordIOIterator: must specify image_rec"; + CHECK(param_.path_imgrec.length() != 0) << "ImageRecordIOIterator: must specify image_rec"; if (param_.verbose) { - LOG(INFO) << "ImageRecordIOParser: " << param_.path_imgrec - << ", use " << threadget << " threads for decoding.."; + LOG(INFO) << "ImageRecordIOParser: " << param_.path_imgrec << ", use " << threadget + << " threads for decoding.."; } source_.reset(dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), param_.part_index, - 
param_.num_parts, "recordio")); + param_.path_imgrec.c_str(), param_.part_index, param_.num_parts, "recordio")); if (param_.shuffle_chunk_size > 0) { if (param_.shuffle_chunk_size > 4096) { LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size - << " MB which is larger than 4096 MB, please set " - "smaller chunk size"; + << " MB which is larger than 4096 MB, please set " + "smaller chunk size"; } if (param_.shuffle_chunk_size < 4) { LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size - << " MB which is less than 4 MB, please set " - "larger chunk size"; + << " MB which is less than 4 MB, please set " + "larger chunk size"; } // 1.1 ratio is for a bit more shuffle parts to avoid boundary issue - unsigned num_shuffle_parts = - std::ceil(source_->GetTotalSize() * 1.1 / - (param_.num_parts * (param_.shuffle_chunk_size << 20UL))); + unsigned num_shuffle_parts = std::ceil( + source_->GetTotalSize() * 1.1 / (param_.num_parts * (param_.shuffle_chunk_size << 20UL))); if (num_shuffle_parts > 1) { - source_.reset(dmlc::InputSplitShuffle::Create( - param_.path_imgrec.c_str(), param_.part_index, - param_.num_parts, "recordio", num_shuffle_parts, param_.shuffle_chunk_seed)); + source_.reset(dmlc::InputSplitShuffle::Create(param_.path_imgrec.c_str(), + param_.part_index, + param_.num_parts, + "recordio", + num_shuffle_parts, + param_.shuffle_chunk_seed)); } source_->HintChunkSize(param_.shuffle_chunk_size << 17UL); } else { @@ -156,16 +154,16 @@ inline void ImageRecordIOParser::Init( #endif } -template -inline bool ImageRecordIOParser:: -ParseNext(std::vector> *out_vec) { +template +inline bool ImageRecordIOParser::ParseNext(std::vector>* out_vec) { CHECK(source_ != nullptr); dmlc::InputSplit::Blob chunk; - if (!source_->NextChunk(&chunk)) return false; + if (!source_->NextChunk(&chunk)) + return false; #if MXNET_USE_OPENCV // save opencv out out_vec->resize(param_.preprocess_threads); - #pragma omp parallel num_threads(param_.preprocess_threads) +#pragma omp parallel num_threads(param_.preprocess_threads) { CHECK(omp_get_num_threads() == param_.preprocess_threads); int tid = omp_get_thread_num(); @@ -173,7 +171,7 @@ ParseNext(std::vector> *out_vec) { ImageRecordIO rec; dmlc::InputSplit::Blob blob; // image data - InstVector &out = (*out_vec)[tid]; + InstVector& out = (*out_vec)[tid]; out.Clear(); while (reader.NextRecord(&blob)) { // Opencv decode and augments @@ -181,21 +179,20 @@ ParseNext(std::vector> *out_vec) { rec.Load(blob.dptr, blob.size); cv::Mat buf(1, rec.content_size, CV_8U, rec.content); switch (param_.data_shape[0]) { - case 1: - res = cv::imdecode(buf, 0); - break; - case 3: - res = cv::imdecode(buf, 1); - break; - case 4: - // -1 to keep the number of channel of the encoded image, and not force gray or color. - res = cv::imdecode(buf, -1); - CHECK_EQ(res.channels(), 4) - << "Invalid image with index " << rec.image_index() - << ". Expected 4 channels, got " << res.channels(); - break; - default: - LOG(FATAL) << "Invalid output shape " << param_.data_shape; + case 1: + res = cv::imdecode(buf, 0); + break; + case 3: + res = cv::imdecode(buf, 1); + break; + case 4: + // -1 to keep the number of channel of the encoded image, and not force gray or color. + res = cv::imdecode(buf, -1); + CHECK_EQ(res.channels(), 4) << "Invalid image with index " << rec.image_index() + << ". 
Expected 4 channels, got " << res.channels(); + break; + default: + LOG(FATAL) << "Invalid output shape " << param_.data_shape; } const int n_channels = res.channels(); for (auto& aug : augmenters_[tid]) { @@ -210,15 +207,18 @@ ParseNext(std::vector> *out_vec) { // For RGB or RGBA data, swap the B and R channel: // OpenCV store as BGR (or BGRA) and we want RGB (or RGBA) std::vector swap_indices; - if (n_channels == 1) swap_indices = {0}; - if (n_channels == 3) swap_indices = {2, 1, 0}; - if (n_channels == 4) swap_indices = {2, 1, 0, 3}; + if (n_channels == 1) + swap_indices = {0}; + if (n_channels == 3) + swap_indices = {2, 1, 0}; + if (n_channels == 4) + swap_indices = {2, 1, 0, 3}; for (int i = 0; i < res.rows; ++i) { uchar* im_data = res.ptr(i); for (int j = 0; j < res.cols; ++j) { for (int k = 0; k < n_channels; ++k) { - data[k][i][j] = im_data[swap_indices[k]]; + data[k][i][j] = im_data[swap_indices[k]]; } im_data += n_channels; } @@ -228,15 +228,14 @@ ParseNext(std::vector> *out_vec) { if (label_map_ != nullptr) { mshadow::Copy(label, label_map_->Find(rec.image_index())); } else if (rec.label != nullptr) { - CHECK_EQ(param_.label_width, rec.num_label) - << "rec file provide " << rec.num_label << "-dimensional label " - "but label_width is set to " << param_.label_width; - mshadow::Copy(label, mshadow::Tensor(rec.label, - mshadow::Shape1(rec.num_label))); + CHECK_EQ(param_.label_width, rec.num_label) << "rec file provide " << rec.num_label + << "-dimensional label " + "but label_width is set to " + << param_.label_width; + mshadow::Copy(label, mshadow::Tensor(rec.label, mshadow::Shape1(rec.num_label))); } else { - CHECK_EQ(param_.label_width, 1) - << "label_width must be 1 unless an imglist is provided " - "or the rec file is packed with multi dimensional label"; + CHECK_EQ(param_.label_width, 1) << "label_width must be 1 unless an imglist is provided " + "or the rec file is packed with multi dimensional label"; label[0] = rec.header.label; } res.release(); @@ -249,30 +248,31 @@ ParseNext(std::vector> *out_vec) { } // iterator on image recordio -template +template class ImageRecordIter : public IIterator { public: - ImageRecordIter() : data_(nullptr) { } + ImageRecordIter() : data_(nullptr) {} // destructor ~ImageRecordIter() override { iter_.Destroy(); delete data_; } // constructor - void Init(const std::vector >& kwargs) override { + void Init(const std::vector>& kwargs) override { param_.InitAllowUnknown(kwargs); // use the kwarg to init parser parser_.Init(kwargs); // prefetch at most 4 minbatches iter_.set_max_capacity(4); // init thread iter - iter_.Init([this](std::vector> **dptr) { - if (*dptr == nullptr) { - *dptr = new std::vector>(); - } - return parser_.ParseNext(*dptr); - }, - [this]() { parser_.BeforeFirst(); }); + iter_.Init( + [this](std::vector>** dptr) { + if (*dptr == nullptr) { + *dptr = new std::vector>(); + } + return parser_.ParseNext(*dptr); + }, + [this]() { parser_.BeforeFirst(); }); inst_ptr_ = 0; rnd_.seed(kRandMagic + param_.seed); } @@ -287,12 +287,14 @@ class ImageRecordIter : public IIterator { while (true) { if (inst_ptr_ < inst_order_.size()) { std::pair p = inst_order_[inst_ptr_]; - out_ = (*data_)[p.first][p.second]; + out_ = (*data_)[p.first][p.second]; ++inst_ptr_; return true; } else { - if (data_ != nullptr) iter_.Recycle(&data_); - if (!iter_.Next(&data_)) return false; + if (data_ != nullptr) + iter_.Recycle(&data_); + if (!iter_.Next(&data_)) + return false; inst_order_.clear(); for (unsigned i = 0; i < data_->size(); ++i) { const 
InstVector& tmp = (*data_)[i]; @@ -310,7 +312,7 @@ class ImageRecordIter : public IIterator { return false; } - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } @@ -322,13 +324,13 @@ class ImageRecordIter : public IIterator { // data ptr size_t inst_ptr_; // internal instance order - std::vector > inst_order_; + std::vector> inst_order_; // data - std::vector> *data_; + std::vector>* data_; // internal parser ImageRecordIOParser parser_; // backend thread - dmlc::ThreadedIter> > iter_; + dmlc::ThreadedIter>> iter_; // parameters ImageRecordParam param_; // random number generator @@ -337,7 +339,7 @@ class ImageRecordIter : public IIterator { // OLD VERSION - DEPRECATED MXNET_REGISTER_IO_ITER(ImageRecordIter_v1) -.describe(R"code(Iterating on image RecordIO files + .describe(R"code(Iterating on image RecordIO files .. note:: @@ -351,22 +353,20 @@ One can use ``tools/im2rec.py`` to pack individual image files into RecordIO files. )code" ADD_FILELINE) -.add_arguments(ImageRecParserParam::__FIELDS__()) -.add_arguments(ImageRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultAugParams()) -.add_arguments(ImageNormalizeParam::__FIELDS__()) -.set_body([]() { - return new PrefetcherIter( - new BatchLoader( - new ImageNormalizeIter( - new ImageRecordIter()))); - }); + .add_arguments(ImageRecParserParam::__FIELDS__()) + .add_arguments(ImageRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultAugParams()) + .add_arguments(ImageNormalizeParam::__FIELDS__()) + .set_body([]() { + return new PrefetcherIter( + new BatchLoader(new ImageNormalizeIter(new ImageRecordIter()))); + }); // OLD VERSION - DEPRECATED MXNET_REGISTER_IO_ITER(ImageRecordUInt8Iter_v1) -.describe(R"code(Iterating on image RecordIO files + .describe(R"code(Iterating on image RecordIO files .. note:: @@ -376,15 +376,11 @@ This iterator is identical to ``ImageRecordIter`` except for using ``uint8`` as the data type instead of ``float``. 
)code" ADD_FILELINE) -.add_arguments(ImageRecParserParam::__FIELDS__()) -.add_arguments(ImageRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultAugParams()) -.set_body([]() { - return new PrefetcherIter( - new BatchLoader( - new ImageRecordIter())); - }); + .add_arguments(ImageRecParserParam::__FIELDS__()) + .add_arguments(ImageRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultAugParams()) + .set_body([]() { return new PrefetcherIter(new BatchLoader(new ImageRecordIter())); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc index f4f88b76d65c..375fdaabe7f2 100644 --- a/src/io/iter_image_recordio_2.cc +++ b/src/io/iter_image_recordio_2.cc @@ -49,11 +49,11 @@ namespace mxnet { namespace io { // parser to parse image recordio -template +template class ImageRecordIOParser2 { public: // initialize the parser - inline void Init(const std::vector >& kwargs); + inline void Init(const std::vector>& kwargs); // set record to the head inline void BeforeFirst() { @@ -66,24 +66,28 @@ class ImageRecordIOParser2 { } // parse next set of records, return an array of // instance vector to the user - inline bool ParseNext(DataBatch *out); + inline bool ParseNext(DataBatch* out); private: #if MXNET_USE_OPENCV - template + template void ProcessImage(const cv::Mat& res, - mshadow::Tensor* data_ptr, const bool is_mirrored, const float contrast_scaled, - const float illumination_scaled); + mshadow::Tensor* data_ptr, + const bool is_mirrored, + const float contrast_scaled, + const float illumination_scaled); #if MXNET_USE_LIBJPEG_TURBO cv::Mat TJimdecode(cv::Mat buf, int color); #endif #endif - inline size_t ParseChunk(DType* data_dptr, real_t* label_dptr, const size_t current_size, - dmlc::InputSplit::Blob * chunk); + inline size_t ParseChunk(DType* data_dptr, + real_t* label_dptr, + const size_t current_size, + dmlc::InputSplit::Blob* chunk); inline void CreateMeanImg(); // magic number to seed prng - static const int kRandMagic = 111; + static const int kRandMagic = 111; static const int kRandMagicNormalize = 0; /*! \brief parameters */ ImageRecParserParam param_; @@ -91,12 +95,12 @@ class ImageRecordIOParser2 { BatchParam batch_param_; ImageNormalizeParam normalize_param_; - #if MXNET_USE_OPENCV +#if MXNET_USE_OPENCV /*! \brief augmenters */ - std::vector > > augmenters_; - #endif + std::vector>> augmenters_; +#endif /*! \brief random samplers */ - std::vector > prnds_; + std::vector> prnds_; common::RANDOM_ENGINE rnd_; /*! \brief data source */ std::unique_ptr source_; @@ -107,7 +111,7 @@ class ImageRecordIOParser2 { /*! \brief temp space */ mshadow::TensorContainer img_; /*! \brief internal instance order */ - std::vector > inst_order_; + std::vector> inst_order_; size_t inst_index_; /*! 
\brief internal counter tracking number of already parsed entries */ size_t n_parsed_; @@ -126,9 +130,9 @@ class ImageRecordIOParser2 { dmlc::OMPException omp_exc_; }; -template +template inline void ImageRecordIOParser2::Init( - const std::vector >& kwargs) { + const std::vector>& kwargs) { #if MXNET_USE_OPENCV // initialize parameter // init image rec param @@ -139,22 +143,20 @@ inline void ImageRecordIOParser2::Init( PrefetcherParam prefetch_param; prefetch_param.InitAllowUnknown(kwargs); n_parsed_ = 0; - overflow = false; + overflow = false; rnd_.seed(kRandMagic + record_param_.seed); int maxthread, threadget; if (prefetch_param.ctx == PrefetcherParam::CtxType::kCPU) { threadget = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); } else { - #pragma omp parallel +#pragma omp parallel { // be conservative, set number of real cores maxthread = std::max(omp_get_num_procs() / 2, 1); } param_.preprocess_threads = std::min(maxthread, param_.preprocess_threads); - #pragma omp parallel num_threads(param_.preprocess_threads) - { - threadget = omp_get_num_threads(); - } +#pragma omp parallel num_threads(param_.preprocess_threads) + { threadget = omp_get_num_threads(); } } param_.preprocess_threads = threadget; @@ -170,52 +172,52 @@ inline void ImageRecordIOParser2::Init( prnds_.emplace_back(new common::RANDOM_ENGINE((i + 1) * kRandMagic)); } if (param_.path_imglist.length() != 0) { - label_map_ = std::make_unique(param_.path_imglist.c_str(), - param_.label_width, !param_.verbose); + label_map_ = std::make_unique( + param_.path_imglist.c_str(), param_.label_width, !param_.verbose); } - CHECK(param_.path_imgrec.length() != 0) - << "ImageRecordIter2: must specify image_rec"; + CHECK(param_.path_imgrec.length() != 0) << "ImageRecordIter2: must specify image_rec"; if (param_.verbose) { - LOG(INFO) << "ImageRecordIOParser2: " << param_.path_imgrec - << ", use " << threadget << " threads for decoding.."; + LOG(INFO) << "ImageRecordIOParser2: " << param_.path_imgrec << ", use " << threadget + << " threads for decoding.."; } legacy_shuffle_ = false; if (param_.path_imgidx.length() != 0) { - source_.reset(dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), - param_.path_imgidx.c_str(), - param_.part_index, - param_.num_parts, "indexed_recordio", - record_param_.shuffle, - record_param_.seed, - batch_param_.batch_size)); + source_.reset(dmlc::InputSplit::Create(param_.path_imgrec.c_str(), + param_.path_imgidx.c_str(), + param_.part_index, + param_.num_parts, + "indexed_recordio", + record_param_.shuffle, + record_param_.seed, + batch_param_.batch_size)); } else { source_.reset(dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), param_.part_index, - param_.num_parts, "recordio")); + param_.path_imgrec.c_str(), param_.part_index, param_.num_parts, "recordio")); if (record_param_.shuffle) legacy_shuffle_ = true; if (param_.shuffle_chunk_size > 0) { if (param_.shuffle_chunk_size > 4096) { LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size - << " MB which is larger than 4096 MB, please set " - "smaller chunk size"; + << " MB which is larger than 4096 MB, please set " + "smaller chunk size"; } if (param_.shuffle_chunk_size < 4) { LOG(INFO) << "Chunk size: " << param_.shuffle_chunk_size - << " MB which is less than 4 MB, please set " - "larger chunk size"; + << " MB which is less than 4 MB, please set " + "larger chunk size"; } // 1.1 ratio is for a bit more shuffle parts to avoid boundary issue - size_t num_shuffle_parts = - std::ceil(source_->GetTotalSize() * 1.1 / - (param_.num_parts * 
(param_.shuffle_chunk_size << 20UL))); + size_t num_shuffle_parts = std::ceil( + source_->GetTotalSize() * 1.1 / (param_.num_parts * (param_.shuffle_chunk_size << 20UL))); if (num_shuffle_parts > 1) { - source_.reset(dmlc::InputSplitShuffle::Create( - param_.path_imgrec.c_str(), param_.part_index, - param_.num_parts, "recordio", num_shuffle_parts, param_.shuffle_chunk_seed)); + source_.reset(dmlc::InputSplitShuffle::Create(param_.path_imgrec.c_str(), + param_.part_index, + param_.num_parts, + "recordio", + num_shuffle_parts, + param_.shuffle_chunk_seed)); } source_->HintChunkSize(param_.shuffle_chunk_size << 17UL); } else { @@ -241,12 +243,11 @@ inline void ImageRecordIOParser2::Init( std::vector data; std::vector keys; { - std::unique_ptr fi(dmlc::Stream::Create(normalize_param_.mean_img.c_str(), - "r")); + std::unique_ptr fi( + dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "r")); NDArray::Load(fi.get(), &data, &keys); } - CHECK_EQ(data.size(), 1) - << "Invalid mean image file format"; + CHECK_EQ(data.size(), 1) << "Invalid mean image file format"; data[0].WaitToRead(); mshadow::Tensor src = data[0].data().get(); meanimg_.Resize(src.shape_); @@ -263,8 +264,8 @@ inline void ImageRecordIOParser2::Init( #endif } -template -inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { +template +inline bool ImageRecordIOParser2::ParseNext(DataBatch* out) { if (overflow) { return false; } @@ -293,7 +294,7 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { shape_vec.push_back(param_.label_width); mxnet::TShape label_shape(shape_vec.begin(), shape_vec.end()); - auto ctx = Context::CPU(0); + auto ctx = Context::CPU(0); auto dev_id = param_.device_id; if (dev_id != -1) { ctx = Context::CPUPinned(dev_id); @@ -302,11 +303,9 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { const std::string profiler_scope = profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "image_io:"; - out->data.at(0) = NDArray(data_shape, ctx, false, - mshadow::DataType::kFlag); + out->data.at(0) = NDArray(data_shape, ctx, false, mshadow::DataType::kFlag); out->data.at(0).AssignStorageInfo(profiler_scope, "data"); - out->data.at(1) = NDArray(label_shape, ctx, false, - mshadow::DataType::kFlag); + out->data.at(1) = NDArray(label_shape, ctx, false, mshadow::DataType::kFlag); out->data.at(1).AssignStorageInfo(profiler_scope, "label"); unit_size_[0] = param_.data_shape.Size(); unit_size_[1] = param_.label_width; @@ -318,8 +317,8 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { if (n_parsed_ == 0) { if (source_->NextBatch(&chunk, batch_param_.batch_size)) { inst_order_.clear(); - inst_index_ = 0; - DType* data_dptr = static_cast(out->data[0].data().dptr_); + inst_index_ = 0; + DType* data_dptr = static_cast(out->data[0].data().dptr_); real_t* label_dptr = static_cast(out->data[1].data().dptr_); if (!legacy_shuffle_) { n_to_out = ParseChunk(data_dptr, label_dptr, current_size, &chunk); @@ -344,27 +343,27 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { current_size = batch_param_.batch_size; } out->num_batch_padd = batch_param_.batch_size - current_size; - n_to_out = 0; + n_to_out = 0; } } else { - size_t n_to_copy = std::min(n_parsed_, - static_cast(batch_param_.batch_size) - current_size); + size_t n_to_copy = + std::min(n_parsed_, static_cast(batch_param_.batch_size) - current_size); n_parsed_ -= n_to_copy; - // Copy - #pragma omp parallel for num_threads(param_.preprocess_threads) +// Copy +#pragma omp parallel for 
num_threads(param_.preprocess_threads) for (int i = 0; i < static_cast(n_to_copy); ++i) { omp_exc_.Run([&] { - std::pair place = inst_order_[inst_index_ + i]; - const DataInst& batch = temp_[place.first][place.second]; - for (size_t j = 0; j < batch.data.size(); ++j) { - CHECK_EQ(unit_size_[j], batch.data[j].Size()); - MSHADOW_TYPE_SWITCH(out->data[j].data().type_flag_, dtype, { - mshadow::Copy( - out->data[j].data().FlatTo1D().Slice((current_size + i) * unit_size_[j], - (current_size + i + 1) * unit_size_[j]), - batch.data[j].get_with_shape(mshadow::Shape1(unit_size_[j]))); - }); - } + std::pair place = inst_order_[inst_index_ + i]; + const DataInst& batch = temp_[place.first][place.second]; + for (size_t j = 0; j < batch.data.size(); ++j) { + CHECK_EQ(unit_size_[j], batch.data[j].Size()); + MSHADOW_TYPE_SWITCH(out->data[j].data().type_flag_, dtype, { + mshadow::Copy( + out->data[j].data().FlatTo1D().Slice( + (current_size + i) * unit_size_[j], (current_size + i + 1) * unit_size_[j]), + batch.data[j].get_with_shape(mshadow::Shape1(unit_size_[j]))); + }); + } }); } omp_exc_.Rethrow(); @@ -378,15 +377,17 @@ inline bool ImageRecordIOParser2::ParseNext(DataBatch *out) { } #if MXNET_USE_OPENCV -template -template +template +template void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, - mshadow::Tensor* data_ptr, const bool is_mirrored, const float contrast_scaled, - const float illumination_scaled) { - float RGBA_MULT[4] = { 0 }; - float RGBA_BIAS[4] = { 0 }; - float RGBA_MEAN[4] = { 0 }; - int16_t RGBA_MEAN_INT[4] = {0}; + mshadow::Tensor* data_ptr, + const bool is_mirrored, + const float contrast_scaled, + const float illumination_scaled) { + float RGBA_MULT[4] = {0}; + float RGBA_BIAS[4] = {0}; + float RGBA_MEAN[4] = {0}; + int16_t RGBA_MEAN_INT[4] = {0}; mshadow::Tensor& data = (*data_ptr); if (!std::is_same::value) { RGBA_MULT[0] = contrast_scaled / normalize_param_.std_r; @@ -398,10 +399,10 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, RGBA_BIAS[2] = illumination_scaled / normalize_param_.std_b; RGBA_BIAS[3] = illumination_scaled / normalize_param_.std_a; if (!meanfile_ready_) { - RGBA_MEAN[0] = normalize_param_.mean_r; - RGBA_MEAN[1] = normalize_param_.mean_g; - RGBA_MEAN[2] = normalize_param_.mean_b; - RGBA_MEAN[3] = normalize_param_.mean_a; + RGBA_MEAN[0] = normalize_param_.mean_r; + RGBA_MEAN[1] = normalize_param_.mean_g; + RGBA_MEAN[2] = normalize_param_.mean_b; + RGBA_MEAN[3] = normalize_param_.mean_a; RGBA_MEAN_INT[0] = std::round(normalize_param_.mean_r); RGBA_MEAN_INT[1] = std::round(normalize_param_.mean_g); RGBA_MEAN_INT[2] = std::round(normalize_param_.mean_b); @@ -409,7 +410,7 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, } } - int swap_indices[n_channels]; // NOLINT(*) + int swap_indices[n_channels]; // NOLINT(*) if (n_channels == 1) { swap_indices[0] = 0; } else if (n_channels == 3) { @@ -430,8 +431,8 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, if (std::is_same::value) { if (meanfile_ready_) { for (int k = 0; k < n_channels; ++k) { - RGBA[k] = cv::saturate_cast(im_data[swap_indices[k]] - - static_cast(std::round(meanimg_[k][i][j]))); + RGBA[k] = cv::saturate_cast( + im_data[swap_indices[k]] - static_cast(std::round(meanimg_[k][i][j]))); } } else { for (int k = 0; k < n_channels; ++k) { @@ -470,7 +471,7 @@ void ImageRecordIOParser2::ProcessImage(const cv::Mat& res, #if MXNET_USE_LIBJPEG_TURBO -bool is_jpeg(unsigned char * file) { +bool is_jpeg(unsigned char* file) { if ((file[0] == 255) && (file[1] == 216)) { 
return true; } else { @@ -478,10 +479,10 @@ bool is_jpeg(unsigned char * file) { } } -template +template cv::Mat ImageRecordIOParser2::TJimdecode(cv::Mat image, int color) { unsigned char* jpeg = image.ptr(); - size_t jpeg_size = image.rows * image.cols; + size_t jpeg_size = image.rows * image.cols; if (!is_jpeg(jpeg)) { // If it is not JPEG then fall back to OpenCV @@ -490,24 +491,13 @@ cv::Mat ImageRecordIOParser2::TJimdecode(cv::Mat image, int color) { tjhandle handle = tjInitDecompress(); int h, w, subsamp; - int err = tjDecompressHeader2(handle, - jpeg, - jpeg_size, - &w, &h, &subsamp); + int err = tjDecompressHeader2(handle, jpeg, jpeg_size, &w, &h, &subsamp); if (err != 0) { // If it is a malformed JPEG then fall back to OpenCV return cv::imdecode(image, color); } cv::Mat ret = cv::Mat(h, w, color ? CV_8UC3 : CV_8UC1); - err = tjDecompress2(handle, - jpeg, - jpeg_size, - ret.ptr(), - w, - 0, - h, - color ? TJPF_BGR : TJPF_GRAY, - 0); + err = tjDecompress2(handle, jpeg, jpeg_size, ret.ptr(), w, 0, h, color ? TJPF_BGR : TJPF_GRAY, 0); if (err != 0) { // If it is a malformed JPEG then fall back to OpenCV return cv::imdecode(image, color); @@ -519,142 +509,148 @@ cv::Mat ImageRecordIOParser2::TJimdecode(cv::Mat image, int color) { #endif // Returns the number of images that are put into output -template -inline size_t ImageRecordIOParser2::ParseChunk(DType* data_dptr, real_t* label_dptr, - const size_t current_size, dmlc::InputSplit::Blob * chunk) { +template +inline size_t ImageRecordIOParser2::ParseChunk(DType* data_dptr, + real_t* label_dptr, + const size_t current_size, + dmlc::InputSplit::Blob* chunk) { temp_.resize(param_.preprocess_threads); #if MXNET_USE_OPENCV // save opencv out dmlc::RecordIOChunkReader reader(*chunk, 0, 1); size_t gl_idx = current_size; - #pragma omp parallel num_threads(param_.preprocess_threads) +#pragma omp parallel num_threads(param_.preprocess_threads) { omp_exc_.Run([&] { - CHECK(omp_get_num_threads() == param_.preprocess_threads); - int tid = omp_get_thread_num(); - // dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads); - ImageRecordIO rec; - dmlc::InputSplit::Blob blob; - // image data - InstVector &out_tmp = temp_[tid]; - out_tmp.Clear(); - while (true) { - bool reader_has_data; - size_t idx; - #pragma omp critical - { - reader_has_data = reader.NextRecord(&blob); - if (reader_has_data) { - idx = gl_idx++; - if (idx >= batch_param_.batch_size) { - inst_order_.push_back(std::make_pair(tid, out_tmp.Size())); + CHECK(omp_get_num_threads() == param_.preprocess_threads); + int tid = omp_get_thread_num(); + // dmlc::RecordIOChunkReader reader(*chunk, tid, param_.preprocess_threads); + ImageRecordIO rec; + dmlc::InputSplit::Blob blob; + // image data + InstVector& out_tmp = temp_[tid]; + out_tmp.Clear(); + while (true) { + bool reader_has_data; + size_t idx; +#pragma omp critical + { + reader_has_data = reader.NextRecord(&blob); + if (reader_has_data) { + idx = gl_idx++; + if (idx >= batch_param_.batch_size) { + inst_order_.push_back(std::make_pair(tid, out_tmp.Size())); + } } } - } - if (!reader_has_data) break; - // Opencv decode and augments - cv::Mat res; - rec.Load(blob.dptr, blob.size); - cv::Mat buf(1, rec.content_size, CV_8U, rec.content); - - // If augmentation seed is supplied - // Re-seed RNG to guarantee reproducible results - if (param_.seed_aug.has_value()) { - prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic); - } + if (!reader_has_data) + break; + // Opencv decode and augments + cv::Mat res; + 
rec.Load(blob.dptr, blob.size); + cv::Mat buf(1, rec.content_size, CV_8U, rec.content); + + // If augmentation seed is supplied + // Re-seed RNG to guarantee reproducible results + if (param_.seed_aug.has_value()) { + prnds_[tid]->seed(idx + param_.seed_aug.value() + kRandMagic); + } - switch (param_.data_shape[0]) { - case 1: + switch (param_.data_shape[0]) { + case 1: #if MXNET_USE_LIBJPEG_TURBO - res = TJimdecode(buf, 0); + res = TJimdecode(buf, 0); #else - res = cv::imdecode(buf, 0); + res = cv::imdecode(buf, 0); #endif - break; - case 3: + break; + case 3: #if MXNET_USE_LIBJPEG_TURBO - res = TJimdecode(buf, 1); + res = TJimdecode(buf, 1); #else - res = cv::imdecode(buf, 1); + res = cv::imdecode(buf, 1); #endif - break; - case 4: - // -1 to keep the number of channel of the encoded image, and not force gray or color. - res = cv::imdecode(buf, -1); - CHECK_EQ(res.channels(), 4) - << "Invalid image with index " << rec.image_index() - << ". Expected 4 channels, got " << res.channels(); - break; - default: - LOG(FATAL) << "Invalid output shape " << param_.data_shape; - } - const int n_channels = res.channels(); - // load label before augmentations - std::vector label_buf; - if (label_map_ != nullptr) { - label_buf = label_map_->FindCopy(rec.image_index()); - } else if (rec.label != nullptr) { - CHECK_EQ(param_.label_width, rec.num_label) - << "rec file provide " << rec.num_label << "-dimensional label " - "but label_width is set to " << param_.label_width; - label_buf.assign(rec.label, rec.label + rec.num_label); - } else { - CHECK_EQ(param_.label_width, 1) - << "label_width must be 1 unless an imglist is provided " - "or the rec file is packed with multi dimensional label"; - label_buf.assign(&rec.header.label, &rec.header.label + 1); - } - for (auto& aug : augmenters_[tid]) { - res = aug->Process(res, &label_buf, prnds_[tid].get()); - } - mshadow::Tensor data; - if (idx < batch_param_.batch_size) { - data = mshadow::Tensor(data_dptr + idx*unit_size_[0], - mshadow::Shape3(n_channels, res.rows, res.cols)); - } else { - out_tmp.Push(static_cast(rec.image_index()), - mshadow::Shape3(n_channels, res.rows, res.cols), - mshadow::Shape1(param_.label_width)); - data = out_tmp.data().Back(); - } + break; + case 4: + // -1 to keep the number of channel of the encoded image, and not force gray or color. + res = cv::imdecode(buf, -1); + CHECK_EQ(res.channels(), 4) << "Invalid image with index " << rec.image_index() + << ". 
Expected 4 channels, got " << res.channels(); + break; + default: + LOG(FATAL) << "Invalid output shape " << param_.data_shape; + } + const int n_channels = res.channels(); + // load label before augmentations + std::vector label_buf; + if (label_map_ != nullptr) { + label_buf = label_map_->FindCopy(rec.image_index()); + } else if (rec.label != nullptr) { + CHECK_EQ(param_.label_width, rec.num_label) << "rec file provide " << rec.num_label + << "-dimensional label " + "but label_width is set to " + << param_.label_width; + label_buf.assign(rec.label, rec.label + rec.num_label); + } else { + CHECK_EQ(param_.label_width, 1) + << "label_width must be 1 unless an imglist is provided " + "or the rec file is packed with multi dimensional label"; + label_buf.assign(&rec.header.label, &rec.header.label + 1); + } + for (auto& aug : augmenters_[tid]) { + res = aug->Process(res, &label_buf, prnds_[tid].get()); + } + mshadow::Tensor data; + if (idx < batch_param_.batch_size) { + data = mshadow::Tensor(data_dptr + idx * unit_size_[0], + mshadow::Shape3(n_channels, res.rows, res.cols)); + } else { + out_tmp.Push(static_cast(rec.image_index()), + mshadow::Shape3(n_channels, res.rows, res.cols), + mshadow::Shape1(param_.label_width)); + data = out_tmp.data().Back(); + } - std::uniform_real_distribution rand_uniform(0, 1); - std::bernoulli_distribution coin_flip(0.5); - bool is_mirrored = (normalize_param_.rand_mirror && coin_flip(*(prnds_[tid]))) - || normalize_param_.mirror; - float contrast_scaled = 1; - float illumination_scaled = 0; - if (!std::is_same::value) { - contrast_scaled = - (rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_contrast * 2 - - normalize_param_.max_random_contrast + 1)*normalize_param_.scale; - illumination_scaled = - (rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_illumination * 2 - - normalize_param_.max_random_illumination) * normalize_param_.scale; - } - // For RGB or RGBA data, swap the B and R channel: - // OpenCV store as BGR (or BGRA) and we want RGB (or RGBA) - if (n_channels == 1) { - ProcessImage<1>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); - } else if (n_channels == 3) { - ProcessImage<3>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); - } else if (n_channels == 4) { - ProcessImage<4>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); - } + std::uniform_real_distribution rand_uniform(0, 1); + std::bernoulli_distribution coin_flip(0.5); + bool is_mirrored = + (normalize_param_.rand_mirror && coin_flip(*(prnds_[tid]))) || normalize_param_.mirror; + float contrast_scaled = 1; + float illumination_scaled = 0; + if (!std::is_same::value) { + contrast_scaled = + (rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_contrast * 2 - + normalize_param_.max_random_contrast + 1) * + normalize_param_.scale; + illumination_scaled = + (rand_uniform(*(prnds_[tid])) * normalize_param_.max_random_illumination * 2 - + normalize_param_.max_random_illumination) * + normalize_param_.scale; + } + // For RGB or RGBA data, swap the B and R channel: + // OpenCV store as BGR (or BGRA) and we want RGB (or RGBA) + if (n_channels == 1) { + ProcessImage<1>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); + } else if (n_channels == 3) { + ProcessImage<3>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); + } else if (n_channels == 4) { + ProcessImage<4>(res, &data, is_mirrored, contrast_scaled, illumination_scaled); + } - mshadow::Tensor label; - if (idx < batch_param_.batch_size) { - label = 
mshadow::Tensor(label_dptr + idx*unit_size_[1], - mshadow::Shape1(param_.label_width)); - } else { - label = out_tmp.label().Back(); - } + mshadow::Tensor label; + if (idx < batch_param_.batch_size) { + label = mshadow::Tensor(label_dptr + idx * unit_size_[1], + mshadow::Shape1(param_.label_width)); + } else { + label = out_tmp.label().Back(); + } - mshadow::Copy(label, mshadow::Tensor(dmlc::BeginPtr(label_buf), - mshadow::Shape1(label_buf.size()))); - res.release(); - } - }); + mshadow::Copy( + label, + mshadow::Tensor(dmlc::BeginPtr(label_buf), mshadow::Shape1(label_buf.size()))); + res.release(); + } + }); } omp_exc_.Rethrow(); return (std::min(static_cast(batch_param_.batch_size), gl_idx) - current_size); @@ -665,118 +661,117 @@ inline size_t ImageRecordIOParser2::ParseChunk(DType* data_dptr, real_t* } // create mean image. -template +template inline void ImageRecordIOParser2::CreateMeanImg() { - if (param_.verbose) { - LOG(INFO) << "Cannot find " << normalize_param_.mean_img - << ": create mean image, this will take some time..."; - } - double start = dmlc::GetTime(); - dmlc::InputSplit::Blob chunk; - size_t imcnt = 0; // NOLINT(*) - while (source_->NextChunk(&chunk)) { - inst_order_.clear(); - // Parse chunk w/o putting anything in out - ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk); - for (auto place : inst_order_) { - mshadow::Tensor outimg = + if (param_.verbose) { + LOG(INFO) << "Cannot find " << normalize_param_.mean_img + << ": create mean image, this will take some time..."; + } + double start = dmlc::GetTime(); + dmlc::InputSplit::Blob chunk; + size_t imcnt = 0; // NOLINT(*) + while (source_->NextChunk(&chunk)) { + inst_order_.clear(); + // Parse chunk w/o putting anything in out + ParseChunk(nullptr, nullptr, batch_param_.batch_size, &chunk); + for (auto place : inst_order_) { + mshadow::Tensor outimg = temp_[place.first][place.second].data[0].template get(); - if (imcnt == 0) { - meanimg_.Resize(outimg.shape_); - mshadow::Copy(meanimg_, outimg); - } else { - meanimg_ += outimg; - } - imcnt += 1; - double elapsed = dmlc::GetTime() - start; - if (imcnt % 10000L == 0 && param_.verbose) { - LOG(INFO) << imcnt << " images processed, " << elapsed << " sec elapsed"; - } + if (imcnt == 0) { + meanimg_.Resize(outimg.shape_); + mshadow::Copy(meanimg_, outimg); + } else { + meanimg_ += outimg; + } + imcnt += 1; + double elapsed = dmlc::GetTime() - start; + if (imcnt % 10000L == 0 && param_.verbose) { + LOG(INFO) << imcnt << " images processed, " << elapsed << " sec elapsed"; } } - meanimg_ *= (1.0f / imcnt); - // save as mxnet python compatible format. - TBlob tmp = meanimg_; - { - std::unique_ptr fo( - dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "w")); - NDArray::Save(fo.get(), - {NDArray(tmp, 0)}, - {"mean_img"}); - } - if (param_.verbose) { - LOG(INFO) << "Save mean image to " << normalize_param_.mean_img << ".."; - } - meanfile_ready_ = true; - this->BeforeFirst(); + } + meanimg_ *= (1.0f / imcnt); + // save as mxnet python compatible format. 
+ TBlob tmp = meanimg_; + { + std::unique_ptr fo(dmlc::Stream::Create(normalize_param_.mean_img.c_str(), "w")); + NDArray::Save(fo.get(), {NDArray(tmp, 0)}, {"mean_img"}); + } + if (param_.verbose) { + LOG(INFO) << "Save mean image to " << normalize_param_.mean_img << ".."; + } + meanfile_ready_ = true; + this->BeforeFirst(); } -template +template class ImageRecordIter2 : public IIterator { public: - ImageRecordIter2() = default; + ImageRecordIter2() = default; - ~ImageRecordIter2() override { - iter_.Destroy(); - } + ~ImageRecordIter2() override { + iter_.Destroy(); + } - void Init(const std::vector >& kwargs) override { - prefetch_param_.InitAllowUnknown(kwargs); - parser_.Init(kwargs); - // maximum prefetch threaded iter internal size - const int kMaxPrefetchBuffer = 16; - // init thread iter - iter_.set_max_capacity(kMaxPrefetchBuffer); - // init thread iter - iter_.Init([this](DataBatch **dptr) { + void Init(const std::vector>& kwargs) override { + prefetch_param_.InitAllowUnknown(kwargs); + parser_.Init(kwargs); + // maximum prefetch threaded iter internal size + const int kMaxPrefetchBuffer = 16; + // init thread iter + iter_.set_max_capacity(kMaxPrefetchBuffer); + // init thread iter + iter_.Init( + [this](DataBatch** dptr) { if (*dptr == nullptr) { *dptr = new DataBatch(); } return parser_.ParseNext(*dptr); - }, - [this]() { parser_.BeforeFirst(); }); - } + }, + [this]() { parser_.BeforeFirst(); }); + } - void BeforeFirst() override { - iter_.BeforeFirst(); - } + void BeforeFirst() override { + iter_.BeforeFirst(); + } - // From iter_prefetcher.h - bool Next() override { - if (out_ != nullptr) { - recycle_queue_.push(out_); out_ = nullptr; - } - // do recycle - if (recycle_queue_.size() == prefetch_param_.prefetch_buffer) { - DataBatch *old_batch = recycle_queue_.front(); - // can be more efficient on engine - for (NDArray& arr : old_batch->data) { - arr.WaitToWrite(); - } - recycle_queue_.pop(); - iter_.Recycle(&old_batch); + // From iter_prefetcher.h + bool Next() override { + if (out_ != nullptr) { + recycle_queue_.push(out_); + out_ = nullptr; + } + // do recycle + if (recycle_queue_.size() == prefetch_param_.prefetch_buffer) { + DataBatch* old_batch = recycle_queue_.front(); + // can be more efficient on engine + for (NDArray& arr : old_batch->data) { + arr.WaitToWrite(); } - return iter_.Next(&out_); + recycle_queue_.pop(); + iter_.Recycle(&old_batch); } + return iter_.Next(&out_); + } - const DataBatch &Value() const override { - return *out_; - } + const DataBatch& Value() const override { + return *out_; + } private: - /*! \brief Backend thread */ - dmlc::ThreadedIter iter_; - /*! \brief Parameters */ - PrefetcherParam prefetch_param_; - /*! \brief output data */ - DataBatch *out_{nullptr}; - /*! \brief queue to be recycled */ - std::queue recycle_queue_; - /* \brief parser */ - ImageRecordIOParser2 parser_; + /*! \brief Backend thread */ + dmlc::ThreadedIter iter_; + /*! \brief Parameters */ + PrefetcherParam prefetch_param_; + /*! \brief output data */ + DataBatch* out_{nullptr}; + /*! 
\brief queue to be recycled */ + std::queue recycle_queue_; + /* \brief parser */ + ImageRecordIOParser2 parser_; }; -template +template class ImageRecordIter2CPU : public IIterator { public: ImageRecordIter2CPU() { @@ -793,22 +788,28 @@ class ImageRecordIter2CPU : public IIterator { parser_.Init(kwargs); } - void BeforeFirst() override { parser_.BeforeFirst(); } + void BeforeFirst() override { + parser_.BeforeFirst(); + } // From iter_prefetcher.h bool Next() override { - bool result = false; + bool result = false; const auto engine = Engine::Get(); - engine->PushSync( - [this, &result](RunContext ctx) { - result = this->parser_.ParseNext(out_); - }, - Context::CPU(), {}, {var_}, FnProperty::kNormal, 0, "DataLoader"); + engine->PushSync([this, &result](RunContext ctx) { result = this->parser_.ParseNext(out_); }, + Context::CPU(), + {}, + {var_}, + FnProperty::kNormal, + 0, + "DataLoader"); engine->WaitForVar(var_); return result; } - const DataBatch& Value() const override { return *out_; } + const DataBatch& Value() const override { + return *out_; + } private: /*! \brief Backend thread */ @@ -825,7 +826,8 @@ class ImageRecordIter2CPU : public IIterator { class ImageRecordIter2Wrapper : public IIterator { public: ~ImageRecordIter2Wrapper() override { - if (record_iter_) delete record_iter_; + if (record_iter_) + delete record_iter_; } void Init(const std::vector>& kwargs) override { PrefetcherParam prefetch_param; @@ -867,25 +869,27 @@ class ImageRecordIter2Wrapper : public IIterator { } } record_iter_->Init(kwargs); - } + } - void BeforeFirst() override { - record_iter_->BeforeFirst(); - } + void BeforeFirst() override { + record_iter_->BeforeFirst(); + } - // From iter_prefetcher.h - bool Next() override { return record_iter_->Next(); } + // From iter_prefetcher.h + bool Next() override { + return record_iter_->Next(); + } - const DataBatch &Value() const override { - return record_iter_->Value(); - } + const DataBatch& Value() const override { + return record_iter_->Value(); + } private: IIterator* record_iter_ = nullptr; }; MXNET_REGISTER_IO_ITER(ImageRecordIter) -.describe(R"code(Iterates on image RecordIO files + .describe(R"code(Iterates on image RecordIO files Reads batches of images from .rec RecordIO files. One can use ``im2rec.py`` tool (in tools/) to pack raw image files into RecordIO files. This iterator is less @@ -909,18 +913,16 @@ Example:: data_iter.reset() # To restart the iterator from the beginning. )code" ADD_FILELINE) -.add_arguments(ImageRecParserParam::__FIELDS__()) -.add_arguments(ImageRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultAugParams()) -.add_arguments(ImageNormalizeParam::__FIELDS__()) -.set_body([]() { - return new ImageRecordIter2Wrapper(); - }); + .add_arguments(ImageRecParserParam::__FIELDS__()) + .add_arguments(ImageRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultAugParams()) + .add_arguments(ImageNormalizeParam::__FIELDS__()) + .set_body([]() { return new ImageRecordIter2Wrapper(); }); MXNET_REGISTER_IO_ITER(ImageRecordUInt8Iter) -.describe(R"code(Iterating on image RecordIO files + .describe(R"code(Iterating on image RecordIO files .. note:: ImageRecordUInt8Iter is deprecated. Use ImageRecordIter(dtype='uint8') instead. @@ -928,17 +930,15 @@ This iterator is identical to ``ImageRecordIter`` except for using ``uint8`` as the data type instead of ``float``. 
)code" ADD_FILELINE) -.add_arguments(ImageRecParserParam::__FIELDS__()) -.add_arguments(ImageRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultAugParams()) -.set_body([]() { - return new ImageRecordIter2(); - }); + .add_arguments(ImageRecParserParam::__FIELDS__()) + .add_arguments(ImageRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultAugParams()) + .set_body([]() { return new ImageRecordIter2(); }); MXNET_REGISTER_IO_ITER(ImageRecordInt8Iter) -.describe(R"code(Iterating on image RecordIO files + .describe(R"code(Iterating on image RecordIO files .. note:: ``ImageRecordInt8Iter`` is deprecated. Use ImageRecordIter(dtype='int8') instead. @@ -946,14 +946,12 @@ This iterator is identical to ``ImageRecordIter`` except for using ``int8`` as the data type instead of ``float``. )code" ADD_FILELINE) -.add_arguments(ImageRecParserParam::__FIELDS__()) -.add_arguments(ImageRecordParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.add_arguments(ListDefaultAugParams()) -.set_body([]() { - return new ImageRecordIter2(); - }); + .add_arguments(ImageRecParserParam::__FIELDS__()) + .add_arguments(ImageRecordParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .add_arguments(ListDefaultAugParams()) + .set_body([]() { return new ImageRecordIter2(); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc index 0965bfc5192e..3e0f079c9138 100644 --- a/src/io/iter_libsvm.cc +++ b/src/io/iter_libsvm.cc @@ -49,24 +49,24 @@ struct LibSVMIterParam : public dmlc::Parameter { DMLC_DECLARE_PARAMETER(LibSVMIterParam) { DMLC_DECLARE_FIELD(data_libsvm) .describe("The input zero-base indexed LibSVM data file or a directory path."); - DMLC_DECLARE_FIELD(data_shape) - .describe("The shape of one example."); - DMLC_DECLARE_FIELD(label_libsvm).set_default("NULL") - .describe("The input LibSVM label file or a directory path. " - "If NULL, all labels will be read from ``data_libsvm``."); + DMLC_DECLARE_FIELD(data_shape).describe("The shape of one example."); + DMLC_DECLARE_FIELD(label_libsvm) + .set_default("NULL") + .describe( + "The input LibSVM label file or a directory path. 
" + "If NULL, all labels will be read from ``data_libsvm``."); index_t shape1[] = {1}; - DMLC_DECLARE_FIELD(label_shape).set_default(mxnet::TShape(shape1, shape1 + 1)) + DMLC_DECLARE_FIELD(label_shape) + .set_default(mxnet::TShape(shape1, shape1 + 1)) .describe("The shape of one label."); - DMLC_DECLARE_FIELD(num_parts).set_default(1) - .describe("partition the data into multiple parts"); - DMLC_DECLARE_FIELD(part_index).set_default(0) - .describe("the index of the part will read"); + DMLC_DECLARE_FIELD(num_parts).set_default(1).describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0).describe("the index of the part will read"); } }; -class LibSVMIter: public SparseIIterator { +class LibSVMIter : public SparseIIterator { public: - LibSVMIter() = default; + LibSVMIter() = default; ~LibSVMIter() override = default; // intialize iterator loads data in @@ -75,18 +75,16 @@ class LibSVMIter: public SparseIIterator { CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1"; CHECK_GT(param_.num_parts, 0) << "number of parts should be positive"; CHECK_GE(param_.part_index, 0) << "part index should be non-negative"; - data_parser_.reset(dmlc::Parser::Create(param_.data_libsvm.c_str(), - param_.part_index, - param_.num_parts, "libsvm")); + data_parser_.reset(dmlc::Parser::Create( + param_.data_libsvm.c_str(), param_.part_index, param_.num_parts, "libsvm")); if (param_.label_libsvm != "NULL") { - label_parser_.reset(dmlc::Parser::Create(param_.label_libsvm.c_str(), - param_.part_index, - param_.num_parts, "libsvm")); + label_parser_.reset(dmlc::Parser::Create( + param_.label_libsvm.c_str(), param_.part_index, param_.num_parts, "libsvm")); CHECK_GT(param_.label_shape.Size(), 1) - << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; + << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; } else { CHECK_EQ(param_.label_shape.Size(), 1) - << "label_shape is expected to be (1,) when param_.label_libsvm is NULL"; + << "label_shape is expected to be (1,) when param_.label_libsvm is NULL"; } // both data and label are of CSRStorage in libsvm format if (param_.label_shape.Size() > 1) { @@ -104,17 +102,19 @@ class LibSVMIter: public SparseIIterator { } data_ptr_ = label_ptr_ = 0; data_size_ = label_size_ = 0; - inst_counter_ = 0; - end_ = false; + inst_counter_ = 0; + end_ = false; } bool Next() override { - if (end_) return false; + if (end_) + return false; while (data_ptr_ >= data_size_) { if (!data_parser_->Next()) { - end_ = true; return false; + end_ = true; + return false; } - data_ptr_ = 0; + data_ptr_ = 0; data_size_ = data_parser_->Value().size; } out_.index = inst_counter_++; @@ -129,7 +129,7 @@ class LibSVMIter: public SparseIIterator { while (label_ptr_ >= label_size_) { CHECK(label_parser_->Next()) << "Data LibSVM's row is smaller than the number of rows in label_libsvm"; - label_ptr_ = 0; + label_ptr_ = 0; label_size_ = label_parser_->Value().size; } CHECK_LT(label_ptr_, label_size_); @@ -144,17 +144,19 @@ class LibSVMIter: public SparseIIterator { return true; } - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } const NDArrayStorageType GetStorageType(bool is_data) const override { - if (is_data) return kCSRStorage; + if (is_data) + return kCSRStorage; return param_.label_shape.Size() > 1 ? 
kCSRStorage : kDefaultStorage; } const mxnet::TShape GetShape(bool is_data) const override { - if (is_data) return param_.data_shape; + if (is_data) + return param_.data_shape; return param_.label_shape; } @@ -162,13 +164,13 @@ class LibSVMIter: public SparseIIterator { inline TBlob AsDataBlob(const dmlc::Row& row) { const real_t* ptr = row.value; mxnet::TShape shape(mshadow::Shape1(row.length)); - return TBlob((real_t*) ptr, shape, cpu::kDevMask); // NOLINT(*) + return TBlob((real_t*)ptr, shape, cpu::kDevMask); // NOLINT(*) } inline TBlob AsIdxBlob(const dmlc::Row& row) { const uint64_t* ptr = row.index; mxnet::TShape shape(mshadow::Shape1(row.length)); - return TBlob((int64_t*) ptr, shape, cpu::kDevMask, mshadow::kInt64); // NOLINT(*) + return TBlob((int64_t*)ptr, shape, cpu::kDevMask, mshadow::kInt64); // NOLINT(*) } inline TBlob AsIndPtrPlaceholder(const dmlc::Row& row) { @@ -177,7 +179,7 @@ class LibSVMIter: public SparseIIterator { inline TBlob AsScalarLabelBlob(const dmlc::Row& row) { const real_t* ptr = row.label; - return TBlob((real_t*) ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*) + return TBlob((real_t*)ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*) } LibSVMIterParam param_; @@ -194,11 +196,10 @@ class LibSVMIter: public SparseIIterator { std::unique_ptr > data_parser_; }; - DMLC_REGISTER_PARAMETER(LibSVMIterParam); MXNET_REGISTER_IO_ITER(LibSVMIter) -.describe(R"code(Returns the LibSVM iterator which returns data with `csr` + .describe(R"code(Returns the LibSVM iterator which returns data with `csr` storage type. This iterator is experimental and should be used with care. The input data is stored in a format similar to LibSVM file format, except that the **indices @@ -296,14 +297,10 @@ Example:: [ 0. 0. 1.2 ]] )code" ADD_FILELINE) -.add_arguments(LibSVMIterParam::__FIELDS__()) -.add_arguments(BatchParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.set_body([]() { - return new SparsePrefetcherIter( - new SparseBatchLoader( - new LibSVMIter())); - }); + .add_arguments(LibSVMIterParam::__FIELDS__()) + .add_arguments(BatchParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .set_body([]() { return new SparsePrefetcherIter(new SparseBatchLoader(new LibSVMIter())); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_mnist.cc b/src/io/iter_mnist.cc index bf590170ca9f..20c28fbd9021 100644 --- a/src/io/iter_mnist.cc +++ b/src/io/iter_mnist.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file iter_mnist.cc * \brief register mnist iterator -*/ + */ #include #include #include @@ -56,35 +56,36 @@ struct MNISTParam : public dmlc::Parameter { int part_index; // declare parameters DMLC_DECLARE_PARAMETER(MNISTParam) { - DMLC_DECLARE_FIELD(image).set_default("./train-images-idx3-ubyte") + DMLC_DECLARE_FIELD(image) + .set_default("./train-images-idx3-ubyte") .describe("Dataset Param: Mnist image path."); - DMLC_DECLARE_FIELD(label).set_default("./train-labels-idx1-ubyte") + DMLC_DECLARE_FIELD(label) + .set_default("./train-labels-idx1-ubyte") .describe("Dataset Param: Mnist label path."); - DMLC_DECLARE_FIELD(batch_size).set_lower_bound(1).set_default(128) + DMLC_DECLARE_FIELD(batch_size) + .set_lower_bound(1) + .set_default(128) .describe("Batch Param: Batch Size."); - DMLC_DECLARE_FIELD(shuffle).set_default(true) - .describe("Augmentation Param: Whether to shuffle data."); - DMLC_DECLARE_FIELD(flat).set_default(false) - .describe("Augmentation Param: Whether to flat the data into 1D."); - 
DMLC_DECLARE_FIELD(seed).set_default(0) - .describe("Augmentation Param: Random Seed."); - DMLC_DECLARE_FIELD(silent).set_default(false) - .describe("Auxiliary Param: Whether to print out data info."); - DMLC_DECLARE_FIELD(num_parts).set_default(1) - .describe("partition the data into multiple parts"); - DMLC_DECLARE_FIELD(part_index).set_default(0) - .describe("the index of the part will read"); + DMLC_DECLARE_FIELD(shuffle).set_default(true).describe( + "Augmentation Param: Whether to shuffle data."); + DMLC_DECLARE_FIELD(flat).set_default(false).describe( + "Augmentation Param: Whether to flatten the data into 1D."); + DMLC_DECLARE_FIELD(seed).set_default(0).describe("Augmentation Param: Random Seed."); + DMLC_DECLARE_FIELD(silent).set_default(false).describe( + "Auxiliary Param: Whether to print out data info."); + DMLC_DECLARE_FIELD(num_parts).set_default(1).describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0).describe("the index of the part to read"); } }; -class MNISTIter: public IIterator { +class MNISTIter : public IIterator { public: - MNISTIter() { + MNISTIter() { img_.dptr_ = nullptr; out_.data.resize(2); } ~MNISTIter() override { - delete []img_.dptr_; + delete[] img_.dptr_; } // intialize iterator loads data in void Init(const std::vector >& kwargs) override { @@ -98,20 +99,21 @@ class MNISTIter: public IIterator { batch_data_.shape_ = mshadow::Shape4(param_.batch_size, 1, img_.size(1), img_.size(2)); } out_.data.clear(); - batch_label_.shape_ = mshadow::Shape2(param_.batch_size, 1); + batch_label_.shape_ = mshadow::Shape2(param_.batch_size, 1); batch_label_.stride_ = 1; - batch_data_.stride_ = batch_data_.size(3); - out_.batch_size = param_.batch_size; - if (param_.shuffle) this->Shuffle(); + batch_data_.stride_ = batch_data_.size(3); + out_.batch_size = param_.batch_size; + if (param_.shuffle) + this->Shuffle(); if (param_.silent == 0) { mxnet::TShape s; s = batch_data_.shape_; if (param_.flat) { - LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) << " images, shuffle=" - << param_.shuffle << ", shape=" << s.FlatTo2D(); + LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) + << " images, shuffle=" << param_.shuffle << ", shape=" << s.FlatTo2D(); } else { - LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) << " images, shuffle=" - << param_.shuffle << ", shape=" << s; + LOG(INFO) << "MNISTIter: load " << (unsigned)img_.size(0) + << " images, shuffle=" << param_.shuffle << ", shape=" << s; } } } @@ -120,13 +122,13 @@ class MNISTIter: public IIterator { } bool Next() override { if (loc_ + param_.batch_size <= img_.size(0)) { - batch_data_.dptr_ = img_[loc_].dptr_; + batch_data_.dptr_ = img_[loc_].dptr_; batch_label_.dptr_ = &labels_[loc_]; out_.data.clear(); if (param_.flat) { - out_.data.emplace_back(batch_data_.FlatTo2D()); + out_.data.emplace_back(batch_data_.FlatTo2D()); } else { - out_.data.emplace_back(batch_data_); + out_.data.emplace_back(batch_data_); } out_.data.emplace_back(batch_label_); loc_ += param_.batch_size; @@ -135,25 +137,23 @@ class MNISTIter: public IIterator { return false; } } - const TBlobBatch &Value() const override { + const TBlobBatch& Value() const override { return out_; } private: - inline void GetPart(int count, int* start, int *end) { + inline void GetPart(int count, int* start, int* end) { CHECK_GE(param_.part_index, 0); CHECK_GT(param_.num_parts, 0); CHECK_GT(param_.num_parts, param_.part_index); - *start = static_cast( - static_cast(count) / param_.num_parts * 
param_.part_index); - *end = static_cast( - static_cast(count) / param_.num_parts * (param_.part_index+1)); + *start = static_cast(static_cast(count) / param_.num_parts * param_.part_index); + *end = + static_cast(static_cast(count) / param_.num_parts * (param_.part_index + 1)); } inline void LoadImage() { - dmlc::SeekStream* stdimg - = dmlc::SeekStream::CreateForRead(param_.image.c_str()); + dmlc::SeekStream* stdimg = dmlc::SeekStream::CreateForRead(param_.image.c_str()); ReadInt(stdimg); int image_count = ReadInt(stdimg); int image_rows = ReadInt(stdimg); @@ -166,7 +166,7 @@ class MNISTIter: public IIterator { stdimg->Seek(stdimg->Tell() + start * image_rows * image_cols); } - img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols); + img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols); img_.stride_ = img_.size(2); // allocate continuous memory @@ -185,8 +185,7 @@ class MNISTIter: public IIterator { delete stdimg; } inline void LoadLabel() { - dmlc::SeekStream* stdlabel - = dmlc::SeekStream::CreateForRead(param_.label.c_str()); + dmlc::SeekStream* stdlabel = dmlc::SeekStream::CreateForRead(param_.label.c_str()); ReadInt(stdlabel); int labels_count = ReadInt(stdlabel); @@ -221,10 +220,9 @@ class MNISTIter: public IIterator { } private: - inline static int ReadInt(dmlc::Stream *fi) { + inline static int ReadInt(dmlc::Stream* fi) { unsigned char buf[4]; - CHECK(fi->Read(buf, sizeof(buf)) == sizeof(buf)) - << "invalid mnist format"; + CHECK(fi->Read(buf, sizeof(buf)) == sizeof(buf)) << "invalid mnist format"; #ifdef _MSC_VER return (buf[0] << 24 | buf[1] << 16 | buf[2] << 8 | buf[3]); #else @@ -258,12 +256,10 @@ class MNISTIter: public IIterator { DMLC_REGISTER_PARAMETER(MNISTParam); MXNET_REGISTER_IO_ITER(MNISTIter) -.describe("Iterating on the MNIST dataset." ADD_FILELINE) -.add_arguments(MNISTParam::__FIELDS__()) -.add_arguments(PrefetcherParam::__FIELDS__()) -.set_body([]() { - return new PrefetcherIter(new MNISTIter()); - }); + .describe("Iterating on the MNIST dataset." 
ADD_FILELINE) + .add_arguments(MNISTParam::__FIELDS__()) + .add_arguments(PrefetcherParam::__FIELDS__()) + .set_body([]() { return new PrefetcherIter(new MNISTIter()); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_normalize.h b/src/io/iter_normalize.h index 4bc7d53d2b76..eda7af445e7e 100644 --- a/src/io/iter_normalize.h +++ b/src/io/iter_normalize.h @@ -47,9 +47,7 @@ namespace io { */ class ImageNormalizeIter : public IIterator { public: - explicit ImageNormalizeIter(IIterator *base) - : base_(base), meanfile_ready_(false) { - } + explicit ImageNormalizeIter(IIterator* base) : base_(base), meanfile_ready_(false) {} virtual void Init(const std::vector >& kwargs) { param_.InitAllowUnknown(kwargs); @@ -58,8 +56,7 @@ class ImageNormalizeIter : public IIterator { outimg_.set_pad(false); meanimg_.set_pad(false); if (param_.mean_img.length() != 0) { - std::unique_ptr fi( - dmlc::Stream::Create(param_.mean_img.c_str(), "r", true)); + std::unique_ptr fi(dmlc::Stream::Create(param_.mean_img.c_str(), "r", true)); if (fi.get() == nullptr) { this->CreateMeanImg(); } else { @@ -74,8 +71,7 @@ class ImageNormalizeIter : public IIterator { std::unique_ptr fi(dmlc::Stream::Create(param_.mean_img.c_str(), "r")); NDArray::Load(fi.get(), &data, &keys); } - CHECK_EQ(data.size(), 1U) - << "Invalid mean image file format"; + CHECK_EQ(data.size(), 1U) << "Invalid mean image file format"; data[0].WaitToRead(); mshadow::Tensor src = data[0].data().get(); meanimg_.Resize(src.shape_); @@ -94,7 +90,8 @@ class ImageNormalizeIter : public IIterator { } virtual bool Next(void) { - if (!this->Next_()) return false; + if (!this->Next_()) + return false; return true; } @@ -118,13 +115,14 @@ class ImageNormalizeIter : public IIterator { /*! \brief internal next function, inlined for fater processing. */ inline bool Next_(void) { - if (!base_->Next()) return false; - const DataInst &src = base_->Value(); + if (!base_->Next()) + return false; + const DataInst& src = base_->Value(); this->SetOutImg(src); out_.data.resize(2); - out_.data[0] = outimg_; - out_.data[1] = src.data[1]; - out_.index = src.index; + out_.data[0] = outimg_; + out_.data[1] = src.data[1]; + out_.index = src.index; out_.extra_data = src.extra_data; return true; } @@ -132,7 +130,7 @@ class ImageNormalizeIter : public IIterator { * \brief Set the output image, after augmentation and normalization. * \param src The source image. 
*/ - inline void SetOutImg(const DataInst &src) { + inline void SetOutImg(const DataInst& src) { using namespace mshadow::expr; // NOLINT(*) std::uniform_real_distribution rand_uniform(0, 1); @@ -150,59 +148,59 @@ class ImageNormalizeIter : public IIterator { switch (data.shape_[0]) { case 4: if (meanfile_ready_ && flip) { - outimg_[3] = mirror((data[3] - meanimg_[3]) * contrast + illumination) - * param_.scale / param_.std_a; + outimg_[3] = mirror((data[3] - meanimg_[3]) * contrast + illumination) * param_.scale / + param_.std_a; } else if (meanfile_ready_ && (!flip)) { - outimg_[3] = ((data[3] - meanimg_[3]) * contrast + illumination) - * param_.scale / param_.std_a; + outimg_[3] = + ((data[3] - meanimg_[3]) * contrast + illumination) * param_.scale / param_.std_a; } else if (!meanfile_ready_ && flip) { - outimg_[3] = mirror((data[3] - param_.mean_a) * contrast + illumination) - * param_.scale / param_.std_a; + outimg_[3] = mirror((data[3] - param_.mean_a) * contrast + illumination) * param_.scale / + param_.std_a; } else { - outimg_[3] = ((data[3] - param_.mean_a) * contrast + illumination) - * param_.scale / param_.std_a; + outimg_[3] = + ((data[3] - param_.mean_a) * contrast + illumination) * param_.scale / param_.std_a; } case 3: if (meanfile_ready_ && flip) { - outimg_[2] = mirror((data[2] - meanimg_[2]) * contrast + illumination) - * param_.scale / param_.std_b; + outimg_[2] = mirror((data[2] - meanimg_[2]) * contrast + illumination) * param_.scale / + param_.std_b; } else if (meanfile_ready_ && (!flip)) { - outimg_[2] = ((data[2] - meanimg_[2]) * contrast + illumination) - * param_.scale / param_.std_b; + outimg_[2] = + ((data[2] - meanimg_[2]) * contrast + illumination) * param_.scale / param_.std_b; } else if (!meanfile_ready_ && flip) { - outimg_[2] = mirror((data[2] - param_.mean_b) * contrast + illumination) - * param_.scale / param_.std_b; + outimg_[2] = mirror((data[2] - param_.mean_b) * contrast + illumination) * param_.scale / + param_.std_b; } else { - outimg_[2] = ((data[2] - param_.mean_b) * contrast + illumination) - * param_.scale / param_.std_b; + outimg_[2] = + ((data[2] - param_.mean_b) * contrast + illumination) * param_.scale / param_.std_b; } case 2: if (meanfile_ready_ && flip) { - outimg_[1] = mirror((data[1] - meanimg_[1]) * contrast + illumination) - * param_.scale / param_.std_g; + outimg_[1] = mirror((data[1] - meanimg_[1]) * contrast + illumination) * param_.scale / + param_.std_g; } else if (meanfile_ready_ && (!flip)) { - outimg_[1] = ((data[1] - meanimg_[1]) * contrast + illumination) - * param_.scale / param_.std_g; + outimg_[1] = + ((data[1] - meanimg_[1]) * contrast + illumination) * param_.scale / param_.std_g; } else if (!meanfile_ready_ && flip) { - outimg_[1] = mirror((data[1] - param_.mean_g) * contrast + illumination) - * param_.scale / param_.std_g; + outimg_[1] = mirror((data[1] - param_.mean_g) * contrast + illumination) * param_.scale / + param_.std_g; } else { - outimg_[1] = ((data[1] - param_.mean_g) * contrast + illumination) - * param_.scale / param_.std_g; + outimg_[1] = + ((data[1] - param_.mean_g) * contrast + illumination) * param_.scale / param_.std_g; } case 1: if (meanfile_ready_ && flip) { - outimg_[0] = mirror((data[0] - meanimg_[0]) * contrast + illumination) - * param_.scale / param_.std_r; + outimg_[0] = mirror((data[0] - meanimg_[0]) * contrast + illumination) * param_.scale / + param_.std_r; } else if (meanfile_ready_ && (!flip)) { - outimg_[0] = ((data[0] - meanimg_[0]) * contrast + illumination) - * param_.scale / 
param_.std_r; + outimg_[0] = + ((data[0] - meanimg_[0]) * contrast + illumination) * param_.scale / param_.std_r; } else if (!meanfile_ready_ && flip) { - outimg_[0] = mirror((data[0] - param_.mean_r) * contrast + illumination) - * param_.scale / param_.std_r; + outimg_[0] = mirror((data[0] - param_.mean_r) * contrast + illumination) * param_.scale / + param_.std_r; } else { - outimg_[0] = ((data[0] - param_.mean_r) * contrast + illumination) - * param_.scale / param_.std_r; + outimg_[0] = + ((data[0] - param_.mean_r) * contrast + illumination) * param_.scale / param_.std_r; } break; default: @@ -234,9 +232,7 @@ class ImageNormalizeIter : public IIterator { TBlob tmp = meanimg_; { std::unique_ptr fo(dmlc::Stream::Create(param_.mean_img.c_str(), "w")); - NDArray::Save(fo.get(), - {NDArray(tmp, 0)}, - {"mean_img"}); + NDArray::Save(fo.get(), {NDArray(tmp, 0)}, {"mean_img"}); } if (param_.verbose) { LOG(INFO) << "Save mean image to " << param_.mean_img << ".."; @@ -252,9 +248,7 @@ class ImageNormalizeIter : public IIterator { */ class ImageDetNormalizeIter : public IIterator { public: - explicit ImageDetNormalizeIter(IIterator *base) - : base_(base), meanfile_ready_(false) { - } + explicit ImageDetNormalizeIter(IIterator* base) : base_(base), meanfile_ready_(false) {} virtual void Init(const std::vector >& kwargs) { param_.InitAllowUnknown(kwargs); @@ -263,8 +257,7 @@ class ImageDetNormalizeIter : public IIterator { outimg_.set_pad(false); meanimg_.set_pad(false); if (param_.mean_img.length() != 0) { - std::unique_ptr fi( - dmlc::Stream::Create(param_.mean_img.c_str(), "r", true)); + std::unique_ptr fi(dmlc::Stream::Create(param_.mean_img.c_str(), "r", true)); if (fi.get() == nullptr) { this->CreateMeanImg(); } else { @@ -279,8 +272,7 @@ class ImageDetNormalizeIter : public IIterator { std::unique_ptr fi(dmlc::Stream::Create(param_.mean_img.c_str(), "r")); NDArray::Load(fi.get(), &data, &keys); } - CHECK_EQ(data.size(), 1) - << "Invalid mean image file format"; + CHECK_EQ(data.size(), 1) << "Invalid mean image file format"; data[0].WaitToRead(); mshadow::Tensor src = data[0].data().get(); meanimg_.Resize(src.shape_); @@ -299,7 +291,8 @@ class ImageDetNormalizeIter : public IIterator { } virtual bool Next(void) { - if (!this->Next_()) return false; + if (!this->Next_()) + return false; return true; } @@ -323,13 +316,14 @@ class ImageDetNormalizeIter : public IIterator { /*! \brief internal next function, inlined for fater processing. */ inline bool Next_(void) { - if (!base_->Next()) return false; - const DataInst &src = base_->Value(); + if (!base_->Next()) + return false; + const DataInst& src = base_->Value(); this->SetOutImg(src); out_.data.resize(2); - out_.data[0] = outimg_; - out_.data[1] = src.data[1]; - out_.index = src.index; + out_.data[0] = outimg_; + out_.data[1] = src.data[1]; + out_.index = src.index; out_.extra_data = src.extra_data; return true; } @@ -337,14 +331,14 @@ class ImageDetNormalizeIter : public IIterator { * \brief Set the output image, after augmentation and normalization. * \param src The source image. 
*/ - inline void SetOutImg(const DataInst &src) { + inline void SetOutImg(const DataInst& src) { using namespace mshadow::expr; // NOLINT(*) mshadow::Tensor data = src.data[0].get(); outimg_.Resize(data.shape_); - if (param_.mean_r > 0.0f || param_.mean_g > 0.0f || - param_.mean_b > 0.0f || param_.mean_a > 0.0f) { + if (param_.mean_r > 0.0f || param_.mean_g > 0.0f || param_.mean_b > 0.0f || + param_.mean_a > 0.0f) { // subtract mean per channel data[0] -= param_.mean_r; if (data.shape_[0] >= 3) { @@ -401,9 +395,7 @@ class ImageDetNormalizeIter : public IIterator { TBlob tmp = meanimg_; { std::unique_ptr fo(dmlc::Stream::Create(param_.mean_img.c_str(), "w")); - NDArray::Save(fo.get(), - {NDArray(tmp, 0)}, - {"mean_img"}); + NDArray::Save(fo.get(), {NDArray(tmp, 0)}, {"mean_img"}); } if (param_.verbose) { LOG(INFO) << "Save mean image to " << param_.mean_img << ".."; diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index c416c9d7b9be..66e18e5d3e32 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -51,7 +51,7 @@ class PrefetcherIter : public IIterator { ~PrefetcherIter() { while (recycle_queue_.size() != 0) { - DataBatch *batch = recycle_queue_.front(); + DataBatch* batch = recycle_queue_.front(); recycle_queue_.pop(); delete batch; } @@ -75,48 +75,49 @@ class PrefetcherIter : public IIterator { // use the kwarg to init batch loader loader_->Init(kwargs); length_hint_ = loader_->GetLenHint(); - iter.Init([this](DataBatch **dptr) { - if (!loader_->Next()) return false; - const TBlobBatch& batch = loader_->Value(); - if (*dptr == nullptr) { - // allocate databatch - *dptr = new DataBatch(); - (*dptr)->num_batch_padd = batch.num_batch_padd; - (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + iter.Init( + [this](DataBatch** dptr) { + if (!loader_->Next()) + return false; + const TBlobBatch& batch = loader_->Value(); + if (*dptr == nullptr) { + // allocate databatch + *dptr = new DataBatch(); + (*dptr)->num_batch_padd = batch.num_batch_padd; + (*dptr)->data.resize(batch.data.size()); + (*dptr)->index.resize(batch.batch_size); + for (size_t i = 0; i < batch.data.size(); ++i) { + auto dtype = param_.dtype ? param_.dtype.value() : batch.data[i].type_flag_; + auto ctx = ((param_.ctx == PrefetcherParam::kCPUPinned) && (param_.device_id >= 0)) + ? Context::CPUPinned(param_.device_id) + : Context::CPU(); + (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, ctx, false, dtype); + } + } + CHECK(batch.data.size() == (*dptr)->data.size()); + // copy data over for (size_t i = 0; i < batch.data.size(); ++i) { - auto dtype = param_.dtype - ? param_.dtype.value() - : batch.data[i].type_flag_; - auto ctx = ((param_.ctx == PrefetcherParam::kCPUPinned) && (param_.device_id >= 0)) ? 
- Context::CPUPinned(param_.device_id) : Context::CPU(); - (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, - ctx, false, - dtype); + if ((*dptr)->data.at(i).shape() != batch.data[i].shape_) { + // TODO(zhreshold): memory pool for dynamic shaped data + (*dptr)->data.at(i).ReshapeAndAlloc(batch.data[i].shape_); + } + CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_); + MSHADOW_TYPE_SWITCH(batch.data[i].type_flag_, DType, { + mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D(), + batch.data[i].FlatTo2D()); + }); + (*dptr)->num_batch_padd = batch.num_batch_padd; } - } - CHECK(batch.data.size() == (*dptr)->data.size()); - // copy data over - for (size_t i = 0; i < batch.data.size(); ++i) { - if ((*dptr)->data.at(i).shape() != batch.data[i].shape_) { - // TODO(zhreshold): memory pool for dynamic shaped data - (*dptr)->data.at(i).ReshapeAndAlloc(batch.data[i].shape_); + if (batch.inst_index) { + std::copy( + batch.inst_index, batch.inst_index + batch.batch_size, (*dptr)->index.begin()); } - CHECK_EQ((*dptr)->data.at(i).shape(), batch.data[i].shape_); - MSHADOW_TYPE_SWITCH(batch.data[i].type_flag_, DType, { - mshadow::Copy(((*dptr)->data)[i].data().FlatTo2D(), - batch.data[i].FlatTo2D()); - }); - (*dptr)->num_batch_padd = batch.num_batch_padd; - } - if (batch.inst_index) { - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, - (*dptr)->index.begin()); - } - return true; - }, - [this]() { loader_->BeforeFirst(); length_hint_ = loader_->GetLenHint();}); + return true; + }, + [this]() { + loader_->BeforeFirst(); + length_hint_ = loader_->GetLenHint(); + }); } virtual void BeforeFirst(void) { @@ -129,11 +130,12 @@ class PrefetcherIter : public IIterator { virtual bool Next(void) { if (out_ != nullptr) { - recycle_queue_.push(out_); out_ = nullptr; + recycle_queue_.push(out_); + out_ = nullptr; } // do recycle if (recycle_queue_.size() == param_.prefetch_buffer) { - DataBatch *old_batch = recycle_queue_.front(); + DataBatch* old_batch = recycle_queue_.front(); // can be more efficient on engine for (NDArray& arr : old_batch->data) { arr.WaitToWrite(); @@ -143,7 +145,7 @@ class PrefetcherIter : public IIterator { } return iter.Next(&out_); } - virtual const DataBatch &Value(void) const { + virtual const DataBatch& Value(void) const { return *out_; } @@ -157,7 +159,7 @@ class PrefetcherIter : public IIterator { private: /*! \brief output data */ - DataBatch *out_; + DataBatch* out_; /*! \brief queue to be recycled */ std::queue recycle_queue_; /*! 
\brief size hint cache */ diff --git a/src/io/iter_sampler.cc b/src/io/iter_sampler.cc index 049347dfd9cf..d38b8244128a 100644 --- a/src/io/iter_sampler.cc +++ b/src/io/iter_sampler.cc @@ -42,10 +42,8 @@ struct SequentialSamplerParam : public dmlc::Parameter { int start; // declare parameters DMLC_DECLARE_PARAMETER(SequentialSamplerParam) { - DMLC_DECLARE_FIELD(length) - .describe("Length of the sequence."); - DMLC_DECLARE_FIELD(start).set_default(0) - .describe("Start of the index."); + DMLC_DECLARE_FIELD(length).describe("Length of the sequence."); + DMLC_DECLARE_FIELD(start).set_default(0).describe("Start of the index."); } }; // struct SequentialSamplerParam @@ -70,15 +68,20 @@ class SequentialSampler : public IIterator { bool Next() override { if (pos_ < indices_.size()) { - int64_t *ptr = indices_.data() + pos_; - out_.data[0] = TBlob(ptr, TShape({1, }), cpu::kDevMask, 0); + int64_t* ptr = indices_.data() + pos_; + out_.data[0] = TBlob(ptr, + TShape({ + 1, + }), + cpu::kDevMask, + 0); ++pos_; return true; } return false; } - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } @@ -94,23 +97,18 @@ class SequentialSampler : public IIterator { }; // class SequentialSampler MXNET_REGISTER_IO_ITER(SequentialSampler) -.describe(R"code(Returns the sequential sampler iterator. + .describe(R"code(Returns the sequential sampler iterator. )code" ADD_FILELINE) -.add_arguments(SequentialSamplerParam::__FIELDS__()) -.add_arguments(BatchSamplerParam::__FIELDS__()) -.set_body([]() { - return - new BatchSampler( - new SequentialSampler()); - }); + .add_arguments(SequentialSamplerParam::__FIELDS__()) + .add_arguments(BatchSamplerParam::__FIELDS__()) + .set_body([]() { return new BatchSampler(new SequentialSampler()); }); struct RandomSamplerParam : public dmlc::Parameter { /*! \brief Length of the sequence. */ size_t length; // declare parameters DMLC_DECLARE_PARAMETER(RandomSamplerParam) { - DMLC_DECLARE_FIELD(length) - .describe("Length of the sequence."); + DMLC_DECLARE_FIELD(length).describe("Length of the sequence."); } }; // struct RandomSamplerParam @@ -122,8 +120,9 @@ class RandomSampler : public IIterator { param_.InitAllowUnknown(kwargs); indices_.resize(param_.length); std::iota(std::begin(indices_), std::end(indices_), 0); // fill like arange - mshadow::Random *ctx_rng = ResourceManager::Get()->Request( - Context::CPU(), ResourceRequest::kRandom).get_random(nullptr); + mshadow::Random* ctx_rng = ResourceManager::Get() + ->Request(Context::CPU(), ResourceRequest::kRandom) + .get_random(nullptr); rng_ = std::make_unique(ctx_rng->GetSeed()); out_.data.resize(1); BeforeFirst(); @@ -140,17 +139,23 @@ class RandomSampler : public IIterator { bool Next() override { if (pos_ < indices_.size()) { - int64_t *ptr = indices_.data() + pos_; - out_.data[0] = TBlob(ptr, TShape({1, }), cpu::kDevMask, 0); + int64_t* ptr = indices_.data() + pos_; + out_.data[0] = TBlob(ptr, + TShape({ + 1, + }), + cpu::kDevMask, + 0); ++pos_; return true; } return false; } - const DataInst &Value() const override { + const DataInst& Value() const override { return out_; } + private: /*! \brief Stored integer indices */ std::vector indices_; @@ -165,14 +170,11 @@ class RandomSampler : public IIterator { }; // class RandomSampler MXNET_REGISTER_IO_ITER(RandomSampler) -.describe(R"code(Returns the random sampler iterator. + .describe(R"code(Returns the random sampler iterator. 
)code" ADD_FILELINE) -.add_arguments(RandomSamplerParam::__FIELDS__()) -.add_arguments(BatchSamplerParam::__FIELDS__()) -.set_body([]() { - return new BatchSampler( - new RandomSampler()); - }); + .add_arguments(RandomSamplerParam::__FIELDS__()) + .add_arguments(BatchSamplerParam::__FIELDS__()) + .set_body([]() { return new BatchSampler(new RandomSampler()); }); } // namespace io } // namespace mxnet diff --git a/src/io/iter_sparse.h b/src/io/iter_sparse.h index 22b1836be419..e7f2bc526b25 100644 --- a/src/io/iter_sparse.h +++ b/src/io/iter_sparse.h @@ -32,7 +32,7 @@ namespace mxnet { * \brief iterator type * \param DType data type */ -template +template class SparseIIterator : public IIterator { public: /*! \brief storage type of the data or label */ diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h index c0d856df89ec..02e38fca4154 100644 --- a/src/io/iter_sparse_batchloader.h +++ b/src/io/iter_sparse_batchloader.h @@ -42,15 +42,14 @@ namespace io { /*! \brief create a batch iterator from single instance iterator */ class SparseBatchLoader : public BatchLoader, public SparseIIterator { public: - explicit SparseBatchLoader(SparseIIterator *base): - BatchLoader(base), sparse_base_(base) { - } + explicit SparseBatchLoader(SparseIIterator* base) + : BatchLoader(base), sparse_base_(base) {} virtual ~SparseBatchLoader(void) {} inline void Init(const std::vector >& kwargs) { BatchLoader::Init(kwargs); - data_stype_ = sparse_base_->GetStorageType(true); + data_stype_ = sparse_base_->GetStorageType(true); label_stype_ = sparse_base_->GetStorageType(false); if (param_.round_batch == 0) { LOG(FATAL) << "sparse batch loader doesn't support round_batch == false yet"; @@ -63,18 +62,21 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator virtual bool Next(void) { out_.num_batch_padd = 0; - out_.batch_size = param_.batch_size; - this->head_ = 0; + out_.batch_size = param_.batch_size; + this->head_ = 0; // if overflown from previous round, directly return false, until before first is called - if (num_overflow_ != 0) return false; + if (num_overflow_ != 0) + return false; size_t top = 0; offsets_.clear(); while (sparse_base_->Next()) { const DataInst& inst = sparse_base_->Value(); // initialize the data buffer, only called once - if (data_.size() == 0) this->InitData(inst); + if (data_.size() == 0) + this->InitData(inst); // initialize the number of elements in each buffer, called once per batch - if (offsets_.size() == 0) offsets_.resize(inst.data.size(), 0); + if (offsets_.size() == 0) + offsets_.resize(inst.data.size(), 0); CopyData(inst, top); if (++top >= param_.batch_size) { SetOutputShape(); @@ -83,7 +85,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator } if (top != 0) { CHECK_NE(param_.round_batch, 0) - << "round_batch = False is not supported for sparse data iterator"; + << "round_batch = False is not supported for sparse data iterator"; num_overflow_ = 0; sparse_base_->BeforeFirst(); for (; top < param_.batch_size; ++top, ++num_overflow_) { @@ -100,7 +102,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator return false; } - virtual const TBlobBatch &Value(void) const { + virtual const TBlobBatch& Value(void) const { return BatchLoader::Value(); } @@ -120,7 +122,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator private: /*! \brief base sparse iterator */ - SparseIIterator *sparse_base_; + SparseIIterator* sparse_base_; /*! 
\brief data storage type */ NDArrayStorageType data_stype_; /*! \brief data label type */ @@ -134,16 +136,15 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator // check whether ith position is the indptr tensor for a CSR tensor inline bool IsIndPtr(size_t i) { - auto data_num_aux = num_aux_data(data_stype_); - auto label_num_aux = num_aux_data(label_stype_); + auto data_num_aux = num_aux_data(data_stype_); + auto label_num_aux = num_aux_data(label_stype_); auto label_indptr_offset = data_num_aux + 1 + label_num_aux; // data indptr if (i == data_num_aux && data_stype_ == kCSRStorage) { return true; } // label indptr - if (i == label_indptr_offset && label_stype_ == kCSRStorage && - data_stype_ == kCSRStorage) { + if (i == label_indptr_offset && label_stype_ == kCSRStorage && data_stype_ == kCSRStorage) { return true; } return false; @@ -173,11 +174,11 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator // shape for indptr if (IsIndPtr(i)) { buff_sizes[i] = param_.batch_size + 1; - indptr_[i] = true; + indptr_[i] = true; } else { // estimated the size for the whole batch based on the first instance buff_sizes[i] = first_inst.data[i].Size() * param_.batch_size; - indptr_[i] = false; + indptr_[i] = false; } dtypes_[i] = first_inst.data[i].type_flag_; } @@ -195,8 +196,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator /* \brief set the shape of the outputs based on actual shapes */ inline void SetOutputShape() { for (size_t i = 0; i < out_.data.size(); i++) { - out_.data[i] = TBlob(data_[i].dptr_, mshadow::Shape1(offsets_[i]), - Context::kCPU, dtypes_[i]); + out_.data[i] = TBlob(data_[i].dptr_, mshadow::Shape1(offsets_[i]), Context::kCPU, dtypes_[i]); } } @@ -208,8 +208,8 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator mshadow::Copy(temp.get(), data_[i].get().Slice(0, src_size)); // increase the size of space exponentially size_t capacity = data_[i].Size(); - capacity = capacity * 2 + 1; - data_[i] = TBlobContainer(); + capacity = capacity * 2 + 1; + data_[i] = TBlobContainer(); data_[i].resize(mshadow::Shape1(capacity), dtypes_[i]); // copy back mshadow::Copy(data_[i].get().Slice(0, src_size), temp.get()); @@ -218,7 +218,7 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator /* \brief copy the data instance to data buffer */ void CopyData(const DataInst& inst, const size_t top) { - int64_t unit_size = 0; + int64_t unit_size = 0; out_.inst_index[top] = inst.index; for (size_t i = 0; i < inst.data.size(); ++i) { if (!indptr_[i]) { @@ -226,8 +226,8 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator unit_size = inst.data[i].shape_.Size(); MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { const size_t begin = offsets_[i]; - const size_t end = offsets_[i] + unit_size; - size_t capacity = data_[i].Size(); + const size_t end = offsets_[i] + unit_size; + size_t capacity = data_[i].Size(); // resize the data buffer if estimated space is not sufficient while (capacity < end) { ResizeBuffer(begin, i); @@ -241,9 +241,10 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator // indptr placeholder auto indptr = data_[i].get(); // initialize the first indptr, which is always 0 - if (top == 0) indptr[0] = 0; + if (top == 0) + indptr[0] = 0; indptr[top + 1] = indptr[top] + unit_size; - offsets_[i] = top + 2; + offsets_[i] = top + 2; } } } diff --git a/src/io/iter_sparse_prefetcher.h b/src/io/iter_sparse_prefetcher.h index 3f06052b0292..d6e258112c81 100644 --- 
a/src/io/iter_sparse_prefetcher.h +++ b/src/io/iter_sparse_prefetcher.h @@ -56,67 +56,68 @@ class SparsePrefetcherIter : public PrefetcherIter { PrefetcherIter::InitParams(kwargs); // use the kwarg to init batch loader sparse_loader_->Init(kwargs); - iter.Init([this](DataBatch **dptr) { - if (!sparse_loader_->Next()) return false; - const TBlobBatch& batch = sparse_loader_->Value(); - if (*dptr == nullptr) { - // allocate databatch - *dptr = new DataBatch(); - (*dptr)->num_batch_padd = batch.num_batch_padd; - // (*dptr)->data.at(0) => data - // (*dptr)->data.at(1) => label - (*dptr)->data.resize(2); - (*dptr)->index.resize(batch.batch_size); + iter.Init( + [this](DataBatch** dptr) { + if (!sparse_loader_->Next()) + return false; + const TBlobBatch& batch = sparse_loader_->Value(); + if (*dptr == nullptr) { + // allocate databatch + *dptr = new DataBatch(); + (*dptr)->num_batch_padd = batch.num_batch_padd; + // (*dptr)->data.at(0) => data + // (*dptr)->data.at(1) => label + (*dptr)->data.resize(2); + (*dptr)->index.resize(batch.batch_size); + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + bool is_data = i == 0; + auto stype = this->GetStorageType(is_data); + auto dtype = param_.dtype ? param_.dtype.value() : batch.data[data_iter].type_flag_; + if (stype == kDefaultStorage) { + (*dptr)->data.at(i) = + NDArray(batch.data[data_iter].shape_, Context::CPU(), false, dtype); + } else { + (*dptr)->data.at(i) = + NDArray(stype, this->GetShape(is_data), Context::CPU(), false, dtype); + } + data_iter += num_aux_data(stype) + 1; + } + } + // copy data over size_t data_iter = 0; for (size_t i = 0; i < (*dptr)->data.size(); ++i) { - bool is_data = i == 0; - auto stype = this->GetStorageType(is_data); - auto dtype = param_.dtype ? param_.dtype.value() : batch.data[data_iter].type_flag_; + auto& nd = ((*dptr)->data)[i]; + auto stype = nd.storage_type(); + auto& data_i = ((*dptr)->data)[i]; if (stype == kDefaultStorage) { - (*dptr)->data.at(i) = NDArray(batch.data[data_iter].shape_, - Context::CPU(), false, dtype); + CopyFromTo(data_i.data(), batch.data[data_iter]); + } else if (stype == kCSRStorage) { + auto& values = batch.data[data_iter]; + auto& indices = batch.data[data_iter + 1]; + auto& indptr = batch.data[data_iter + 2]; + // allocate memory + CHECK_EQ(indices.shape_.Size(), values.shape_.Size()); + nd.CheckAndAllocAuxData(csr::kIdx, indices.shape_); + nd.CheckAndAllocData(values.shape_); + nd.CheckAndAllocAuxData(csr::kIndPtr, indptr.shape_); + // copy values, indices and indptr + CopyFromTo(data_i.data(), values); + CopyFromTo(data_i.aux_data(csr::kIdx), indices); + CopyFromTo(data_i.aux_data(csr::kIndPtr), indptr); } else { - (*dptr)->data.at(i) = NDArray(stype, this->GetShape(is_data), - Context::CPU(), false, dtype); + LOG(FATAL) << "Storage type not implemented: " << stype; } data_iter += num_aux_data(stype) + 1; + (*dptr)->num_batch_padd = batch.num_batch_padd; } - } - // copy data over - size_t data_iter = 0; - for (size_t i = 0; i < (*dptr)->data.size(); ++i) { - auto& nd = ((*dptr)->data)[i]; - auto stype = nd.storage_type(); - auto& data_i = ((*dptr)->data)[i]; - if (stype == kDefaultStorage) { - CopyFromTo(data_i.data(), batch.data[data_iter]); - } else if (stype == kCSRStorage) { - auto& values = batch.data[data_iter]; - auto& indices = batch.data[data_iter + 1]; - auto& indptr = batch.data[data_iter + 2]; - // allocate memory - CHECK_EQ(indices.shape_.Size(), values.shape_.Size()); - nd.CheckAndAllocAuxData(csr::kIdx, indices.shape_); - 
nd.CheckAndAllocData(values.shape_); - nd.CheckAndAllocAuxData(csr::kIndPtr, indptr.shape_); - // copy values, indices and indptr - CopyFromTo(data_i.data(), values); - CopyFromTo(data_i.aux_data(csr::kIdx), indices); - CopyFromTo(data_i.aux_data(csr::kIndPtr), indptr); - } else { - LOG(FATAL) << "Storage type not implemented: " << stype; + if (batch.inst_index) { + std::copy( + batch.inst_index, batch.inst_index + batch.batch_size, (*dptr)->index.begin()); } - data_iter += num_aux_data(stype) + 1; - (*dptr)->num_batch_padd = batch.num_batch_padd; - } - if (batch.inst_index) { - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, - (*dptr)->index.begin()); - } - return true; - }, - [this]() { sparse_loader_->BeforeFirst(); }); + return true; + }, + [this]() { sparse_loader_->BeforeFirst(); }); } virtual void BeforeFirst(void) { @@ -126,7 +127,7 @@ class SparsePrefetcherIter : public PrefetcherIter { virtual bool Next(void) { return PrefetcherIter::Next(); } - virtual const DataBatch &Value(void) const { + virtual const DataBatch& Value(void) const { return PrefetcherIter::Value(); } diff --git a/src/io/opencv_compatibility.h b/src/io/opencv_compatibility.h index 7f42328497ed..76512258dff8 100644 --- a/src/io/opencv_compatibility.h +++ b/src/io/opencv_compatibility.h @@ -45,12 +45,12 @@ #define CV_RGB2BGR cv::COLOR_RGB2BGR #define CV_BGR2RGB cv::COLOR_BGR2RGB -#define CV_INTER_LINEAR cv::INTER_LINEAR +#define CV_INTER_LINEAR cv::INTER_LINEAR #define CV_INTER_NEAREST cv::INTER_NEAREST -#define CV_LOAD_IMAGE_COLOR cv::IMREAD_COLOR +#define CV_LOAD_IMAGE_COLOR cv::IMREAD_COLOR #define CV_IMWRITE_PNG_COMPRESSION cv::IMWRITE_PNG_COMPRESSION -#define CV_IMWRITE_JPEG_QUALITY cv::IMWRITE_JPEG_QUALITY +#define CV_IMWRITE_JPEG_QUALITY cv::IMWRITE_JPEG_QUALITY #endif // CV_VERSION_MAJOR >= 4 diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 75d76edfff5d..1e71daa50b0f 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -28,26 +28,23 @@ namespace mxnet { IntImm::IntImm(MXNetDataType dtype, int64_t value) { - CHECK(dtype.is_scalar()) - << "ValueError: IntImm can only take scalar."; - CHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm can only take scalar."; if (dtype.is_uint()) { CHECK_GE(value, 0U); } runtime::ObjectPtr node = make_object(); - node->dtype = dtype; - node->value = value; - data_ = std::move(node); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); } FloatImm::FloatImm(MXNetDataType dtype, double value) { - CHECK_EQ(dtype.lanes(), 1) - << "ValueError: FloatImm can only take scalar."; + CHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; runtime::ObjectPtr node = make_object(); - node->dtype = dtype; - node->value = value; - data_ = std::move(node); + node->dtype = dtype; + node->value = value; + data_ = std::move(node); } } // namespace mxnet diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index b03b74c73bad..4f4edcafac80 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -46,23 +46,25 @@ class Comm { Comm() { pinned_ctx_ = Context::CPUPinned(0); } - virtual ~Comm() { } + virtual ~Comm() {} /** * \brief init key with the data shape and storage shape */ - virtual void Init(int key, const NDArrayStorageType stype, - const mxnet::TShape& shape, int dtype = mshadow::kFloat32) = 0; + virtual void Init(int key, + const NDArrayStorageType stype, 
+ const mxnet::TShape& shape, + int dtype = mshadow::kFloat32) = 0; /** * \brief returns src[0] + .. + src[src.size()-1] */ - virtual const NDArray& Reduce( - int key, const std::vector& src, int priority) = 0; + virtual const NDArray& Reduce(int key, const std::vector& src, int priority) = 0; /** * \brief copy from src to dst[i] for every i */ - virtual void Broadcast( - int key, const NDArray& src, - const std::vector dst, int priority) = 0; + virtual void Broadcast(int key, + const NDArray& src, + const std::vector dst, + int priority) = 0; /** * \brief broadcast src to dst[i] with target row_ids for every i @@ -72,7 +74,8 @@ class Comm { where the row_ids are expected to be unique and sorted in row_id.data() * \param priority the priority of the operation */ - virtual void BroadcastRowSparse(int key, const NDArray& src, + virtual void BroadcastRowSparse(int key, + const NDArray& src, const std::vector>& dst, const int priority) = 0; @@ -105,23 +108,24 @@ class CommCPU : public Comm { public: CommCPU() { nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4); - bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); + bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); // TODO(junwu) delete the following data member, now for benchmark only is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0); } - virtual ~CommCPU() { } + virtual ~CommCPU() {} - void Init(int key, const NDArrayStorageType stype, const mxnet::TShape& shape, + void Init(int key, + const NDArrayStorageType stype, + const mxnet::TShape& shape, int type = mshadow::kFloat32) override { // Delayed allocation - the dense merged buffer might not be used at all if push() // only sees sparse arrays - bool delay_alloc = true; + bool delay_alloc = true; merge_buf_[key].merged = NDArray(shape, pinned_ctx_, delay_alloc, type); } - const NDArray& Reduce(int key, const std::vector& src, - int priority) override { - auto& buf = merge_buf_[key]; + const NDArray& Reduce(int key, const std::vector& src, int priority) override { + auto& buf = merge_buf_[key]; const auto stype = src[0].storage_type(); // avoid extra copy for single device, but it may bring problems for // abnormal usage of kvstore @@ -148,28 +152,32 @@ class CommCPU : public Comm { reduce[0] = buf_merged; if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()-1); + buf.copy_buf.resize(src.size() - 1); for (size_t j = 0; j < src.size() - 1; ++j) { // allocate copy buffer - buf.copy_buf[j] = NDArray( - src[0].shape(), pinned_ctx_, false, src[0].dtype()); + buf.copy_buf[j] = NDArray(src[0].shape(), pinned_ctx_, false, src[0].dtype()); } } CHECK(stype == buf.copy_buf[0].storage_type()) - << "Storage type mismatch detected. " << stype << "(src) vs. " - << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; + << "Storage type mismatch detected. " << stype << "(src) vs. 
" + << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; for (size_t i = 1; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); - reduce[i] = buf.copy_buf[i-1]; - const_vars[i-1] = reduce[i].var(); + CopyFromTo(src[i], &(buf.copy_buf[i - 1]), priority); + reduce[i] = buf.copy_buf[i - 1]; + const_vars[i - 1] = reduce[i].var(); } Engine::Get()->PushAsync( - [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { - ReduceSumCPU(reduce); - on_complete(); - }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority, "KVStoreReduce"); + [reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { + ReduceSumCPU(reduce); + on_complete(); + }, + Context::CPU(), + const_vars, + {reduce[0].var()}, + FnProperty::kCPUPrioritized, + priority, + "KVStoreReduce"); } else { // sparse reduce @@ -179,39 +187,47 @@ class CommCPU : public Comm { if (buf.copy_buf.empty()) { buf.copy_buf.resize(src.size()); for (size_t j = 0; j < src.size(); ++j) { - buf.copy_buf[j] = NDArray( - src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); + buf.copy_buf[j] = + NDArray(src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); } } CHECK(stype == buf.copy_buf[0].storage_type()) - << "Storage type mismatch detected. " << stype << "(src) vs. " - << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; + << "Storage type mismatch detected. " << stype << "(src) vs. " + << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; for (size_t i = 0; i < src.size(); ++i) { CopyFromTo(src[i], &(buf.copy_buf[i]), priority); - reduce[i] = buf.copy_buf[i]; + reduce[i] = buf.copy_buf[i]; const_vars[i] = reduce[i].var(); } Resource rsc = ResourceManager::Get()->Request(buf_merged.ctx(), - ResourceRequest(ResourceRequest::kTempSpace)); + ResourceRequest(ResourceRequest::kTempSpace)); Engine::Get()->PushAsync( - [reduce, buf_merged, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { - NDArray out = buf_merged; - is_serial_push_? - ReduceSumCPUExSerial(reduce, &out) - : mxnet::ndarray::ElementwiseSum(rctx.get_stream(), rsc, reduce, &out); - on_complete(); - }, Context::CPU(), const_vars, {buf_merged.var(), rsc.var}, - FnProperty::kCPUPrioritized, priority, "KVStoreReduce"); + [reduce, buf_merged, rsc, this](RunContext rctx, Engine::CallbackOnComplete on_complete) { + NDArray out = buf_merged; + is_serial_push_ + ? ReduceSumCPUExSerial(reduce, &out) + : mxnet::ndarray::ElementwiseSum(rctx.get_stream(), rsc, reduce, &out); + on_complete(); + }, + Context::CPU(), + const_vars, + {buf_merged.var(), rsc.var}, + FnProperty::kCPUPrioritized, + priority, + "KVStoreReduce"); } return buf_merged; } - void Broadcast(int key, const NDArray& src, - const std::vector dst, int priority) override { + void Broadcast(int key, + const NDArray& src, + const std::vector dst, + int priority) override { int mask = src.ctx().dev_mask(); if (mask == Context::kCPU) { - for (auto d : dst) CopyFromTo(src, d, priority); + for (auto d : dst) + CopyFromTo(src, d, priority); } else { // First copy data to pinned_ctx, then broadcast. // Note that kv.init initializes the data on pinned_ctx. @@ -220,31 +236,35 @@ class CommCPU : public Comm { // Also indicates that buffers are already initialized during push(). 
auto& buf = merge_buf_[key].merged_buf(src.storage_type()); CopyFromTo(src, &buf, priority); - for (auto d : dst) CopyFromTo(buf, d, priority); + for (auto d : dst) + CopyFromTo(buf, d, priority); } } - void BroadcastRowSparse(int key, const NDArray& src, + void BroadcastRowSparse(int key, + const NDArray& src, const std::vector>& dst, const int priority) override { using namespace mshadow; CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row-sparse src NDArray"; + << "BroadcastRowSparse expects row-sparse src NDArray"; CHECK_EQ(src.ctx().dev_mask(), Context::kCPU) - << "BroadcastRowSparse with src on gpu context not supported"; + << "BroadcastRowSparse with src on gpu context not supported"; for (const auto& dst_kv : dst) { - NDArray* out = dst_kv.first; + NDArray* out = dst_kv.first; NDArray row_id = dst_kv.second; CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row_sparse dst NDArray"; + << "BroadcastRowSparse expects row_sparse dst NDArray"; CHECK_EQ(row_id.ctx().dev_mask(), Context::kCPU) - << "BroadcastRowSparse with row_indices on gpu context not supported"; + << "BroadcastRowSparse with row_indices on gpu context not supported"; // retain according to unique indices const bool is_same_ctx = out->ctx() == src.ctx(); const bool is_diff_var = out->var() != src.var(); - NDArray retained_cpu = (is_same_ctx && is_diff_var) ? *out : - NDArray(kRowSparseStorage, src.shape(), src.ctx(), true, - src.dtype(), src.aux_types()); + NDArray retained_cpu = + (is_same_ctx && is_diff_var) + ? *out + : NDArray( + kRowSparseStorage, src.shape(), src.ctx(), true, src.dtype(), src.aux_types()); if (!is_diff_var) { common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) + "refers to the same NDArray as the one stored in KVStore." @@ -254,15 +274,19 @@ class CommCPU : public Comm { "consider create a new NDArray buffer to store the output."); } Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - const TBlob& indices = row_id.data(); - NDArray temp = retained_cpu; // get rid the of const qualifier - op::SparseRetainOpForwardRspImpl(rctx.get_stream(), - src, indices, kWriteTo, - &temp); - on_complete(); - }, Context::CPU(), {src.var(), row_id.var()}, {retained_cpu.var()}, - FnProperty::kNormal, priority, "KVStoreSparseRetain"); + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + const TBlob& indices = row_id.data(); + NDArray temp = retained_cpu; // get rid the of const qualifier + op::SparseRetainOpForwardRspImpl( + rctx.get_stream(), src, indices, kWriteTo, &temp); + on_complete(); + }, + Context::CPU(), + {src.var(), row_id.var()}, + {retained_cpu.var()}, + FnProperty::kNormal, + priority, + "KVStoreSparseRetain"); // if retained_cpu == out, CopyFromTo will ignore the copy operation CopyFromTo(retained_cpu, out, priority); } @@ -270,7 +294,7 @@ class CommCPU : public Comm { private: // reduce sum into val[0] - inline void ReduceSumCPU(const std::vector &in_data) { + inline void ReduceSumCPU(const std::vector& in_data) { MSHADOW_TYPE_SWITCH(in_data[0].dtype(), DType, { std::vector dptr(in_data.size()); for (size_t i = 0; i < in_data.size(); ++i) { @@ -284,13 +308,13 @@ class CommCPU : public Comm { } // serial implementation of reduce sum for row sparse NDArray. 
- inline void ReduceSumCPUExSerial(const std::vector &in, NDArray *out) { + inline void ReduceSumCPUExSerial(const std::vector& in, NDArray* out) { using namespace rowsparse; using namespace mshadow; auto stype = out->storage_type(); CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype; size_t total_num_rows = 0; - size_t num_in = in.size(); + size_t num_in = in.size(); // skip the ones with empty indices and values std::vector skip(num_in, false); // the values tensor of the inputs @@ -306,10 +330,10 @@ class CommCPU : public Comm { skip[i] = true; continue; } - auto size = in[i].aux_shape(kIdx).Size(); + auto size = in[i].aux_shape(kIdx).Size(); num_rows[i] = size; total_num_rows += size; - in_vals[i] = in[i].data().FlatTo2D(); + in_vals[i] = in[i].data().FlatTo2D(); in_indices[i] = in[i].aux_data(kIdx).FlatTo1D(); } std::vector indices; @@ -334,9 +358,10 @@ class CommCPU : public Comm { for (size_t i = 0; i < nnr; i++) { // copy indices back idx_data[i] = indices[i]; - bool zeros = true; + bool zeros = true; for (size_t j = 0; j < num_in; j++) { - if (skip[j]) continue; + if (skip[j]) + continue; size_t offset = offsets[j]; if (offset < num_rows[j]) { if (indices[i] == in_indices[j][offset]) { @@ -355,12 +380,11 @@ class CommCPU : public Comm { }); } - template - inline static void ReduceSumCPU( - const std::vector &dptr, size_t offset, index_t size) { + template + inline static void ReduceSumCPU(const std::vector& dptr, size_t offset, index_t size) { using namespace mshadow; // NOLINT(*) Tensor in_0(dptr[0] + offset, Shape1(size)); - for (size_t i = 1; i < dptr.size(); i+=4) { + for (size_t i = 1; i < dptr.size(); i += 4) { switch (dptr.size() - i) { case 1: { Tensor in_1(dptr[i] + offset, Shape1(size)); @@ -369,22 +393,22 @@ class CommCPU : public Comm { } case 2: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + offset, Shape1(size)); in_0 += in_1 + in_2; break; } case 3: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); - Tensor in_3(dptr[i+2] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + offset, Shape1(size)); + Tensor in_3(dptr[i + 2] + offset, Shape1(size)); in_0 += in_1 + in_2 + in_3; break; } default: { Tensor in_1(dptr[i] + offset, Shape1(size)); - Tensor in_2(dptr[i+1] + offset, Shape1(size)); - Tensor in_3(dptr[i+2] + offset, Shape1(size)); - Tensor in_4(dptr[i+3] + offset, Shape1(size)); + Tensor in_2(dptr[i + 1] + offset, Shape1(size)); + Tensor in_3(dptr[i + 2] + offset, Shape1(size)); + Tensor in_4(dptr[i + 3] + offset, Shape1(size)); in_0 += in_1 + in_2 + in_3 + in_4; break; } @@ -392,19 +416,20 @@ class CommCPU : public Comm { } } - template + template inline void ReduceSumCPUImpl(std::vector dptr, size_t total) { const size_t step = std::min(bigarray_bound_, static_cast(4 << 10)); - long ntask = (total + step - 1) / step; // NOLINT(*) + long ntask = (total + step - 1) / step; // NOLINT(*) if (total < bigarray_bound_ || nthread_reduction_ <= 1) { ReduceSumCPU(dptr, 0, total); } else { - #pragma omp parallel for schedule(static) num_threads(nthread_reduction_) - for (long j = 0; j < ntask; ++j) { // NOLINT(*) - size_t k = static_cast(j); +#pragma omp parallel for schedule(static) num_threads(nthread_reduction_) + for (long j = 0; j < ntask; ++j) { // NOLINT(*) + size_t k = static_cast(j); size_t begin = std::min(k * step, total); - size_t end = std::min((k + 1) * step, total); - if (j == ntask - 1) CHECK_EQ(end, total); + 
size_t end = std::min((k + 1) * step, total); + if (j == ntask - 1) + CHECK_EQ(end, total); ReduceSumCPU(dptr, begin, static_cast(end - begin)); } } @@ -425,8 +450,8 @@ class CommCPU : public Comm { // check if sparse_merged is initialized if (sparse_merged.is_none()) { CHECK(!merged.is_none()); - sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(), - true, merged.dtype()); + sparse_merged = + NDArray(kRowSparseStorage, merged.shape(), merged.ctx(), true, merged.dtype()); } return sparse_merged; } @@ -455,9 +480,11 @@ class CommDevice : public Comm { inited_ = false; } - virtual ~CommDevice() { } + virtual ~CommDevice() {} - void Init(int key, const NDArrayStorageType stype, const mxnet::TShape& shape, + void Init(int key, + const NDArrayStorageType stype, + const mxnet::TShape& shape, int dtype = mshadow::kFloat32) override { sorted_key_attrs_.emplace_back(key, shape, dtype); inited_ = false; @@ -476,13 +503,12 @@ class CommDevice : public Comm { } } - const NDArray& ReduceRowSparse(int key, const std::vector& src, - int priority) { + const NDArray& ReduceRowSparse(int key, const std::vector& src, int priority) { auto& buf = merge_buf_[key]; std::vector reduce(src.size()); const NDArrayStorageType stype = src[0].storage_type(); - NDArray& buf_merged = buf.merged_buf(stype); + NDArray& buf_merged = buf.merged_buf(stype); if (buf.copy_buf.empty()) { // initialize buffer for copying during reduce buf.copy_buf.resize(src.size()); @@ -491,8 +517,8 @@ class CommDevice : public Comm { } } CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type()) - << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. " - << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; + << "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. " + << buf.copy_buf[0].storage_type() << "(buf.copy_buf)"; for (size_t i = 0; i < src.size(); ++i) { CopyFromTo(src[i], &(buf.copy_buf[i]), priority); reduce[i] = buf.copy_buf[i]; @@ -501,8 +527,7 @@ class CommDevice : public Comm { return buf_merged; } - const NDArray& Reduce(int key, const std::vector& src, - int priority) override { + const NDArray& Reduce(int key, const std::vector& src, int priority) override { // when this reduce is called from kvstore_dist, gc is not set // we don't do compression twice in dist_sync_device if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { @@ -519,7 +544,7 @@ class CommDevice : public Comm { auto& buf = merge_buf_[key]; const NDArrayStorageType stype = src[0].storage_type(); - NDArray& buf_merged = buf.merged_buf(stype); + NDArray& buf_merged = buf.merged_buf(stype); // normal dense reduce if (stype == kDefaultStorage) { CopyFromTo(src[0], &buf_merged, priority); @@ -532,18 +557,18 @@ class CommDevice : public Comm { // such as the largest fullc in VGG. consider to do segment reduce with // NDArray.Slice or gpu direct memory access. 
for the latter, we need to // remove some ctx check, and also it reduces 20% perf - buf.copy_buf.resize(src.size()-1); + buf.copy_buf.resize(src.size() - 1); const std::string profiler_scope = profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "comm_dev:"; - for (size_t i = 0; i < src.size()-1; ++i) { - buf.copy_buf[i] = NDArray( - buf_merged.shape(), buf_merged.ctx(), false, buf_merged.dtype()); + for (size_t i = 0; i < src.size() - 1; ++i) { + buf.copy_buf[i] = + NDArray(buf_merged.shape(), buf_merged.ctx(), false, buf_merged.dtype()); buf.copy_buf[i].AssignStorageInfo(profiler_scope, "copy_buf"); } } - for (size_t i = 0; i < src.size()-1; ++i) { - CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority); - reduce[i+1] = buf.copy_buf[i]; + for (size_t i = 0; i < src.size() - 1; ++i) { + CopyFromTo(src[i + 1], &(buf.copy_buf[i]), priority); + reduce[i + 1] = buf.copy_buf[i]; } ElementwiseSum(reduce, &buf_merged, priority); } else { @@ -553,8 +578,7 @@ class CommDevice : public Comm { return buf_merged; } - const NDArray& ReduceCompressed(int key, const std::vector& src, - int priority) { + const NDArray& ReduceCompressed(int key, const std::vector& src, int priority) { InitBuffersAndComm(src); auto& buf = merge_buf_[key]; std::vector reduce(src.size()); @@ -567,19 +591,17 @@ class CommDevice : public Comm { const std::string profiler_scope = profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "comm_dev:"; for (size_t i = 0; i < src.size(); ++i) { - buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), - false, buf.merged.dtype()); + buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(), false, buf.merged.dtype()); buf.copy_buf[i].AssignStorageInfo(profiler_scope, "copy_buf"); - buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), - false, buf.merged.dtype()); + buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(), false, buf.merged.dtype()); buf.residual[i].AssignStorageInfo(profiler_scope, "residual"); - buf.residual[i] = 0; + buf.residual[i] = 0; int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size()); - buf.compressed_recv_buf[i] = NDArray(mxnet::TShape{small_size}, buf.merged.ctx(), - false, buf.merged.dtype()); + buf.compressed_recv_buf[i] = + NDArray(mxnet::TShape{small_size}, buf.merged.ctx(), false, buf.merged.dtype()); buf.compressed_recv_buf[i].AssignStorageInfo(profiler_scope, "compressed_recv_buf"); - buf.compressed_send_buf[i] = NDArray(mxnet::TShape{small_size}, src[i].ctx(), - false, buf.merged.dtype()); + buf.compressed_send_buf[i] = + NDArray(mxnet::TShape{small_size}, src[i].ctx(), false, buf.merged.dtype()); buf.compressed_send_buf[i].AssignStorageInfo(profiler_scope, "compressed_send_buf"); } } @@ -604,8 +626,10 @@ class CommDevice : public Comm { return buf.merged; } - void Broadcast(int key, const NDArray& src, - const std::vector dst, int priority) override { + void Broadcast(int key, + const NDArray& src, + const std::vector dst, + int priority) override { if (!inited_) { // copy to a random device first int dev_id = key % dst.size(); @@ -624,26 +648,30 @@ class CommDevice : public Comm { } } - void BroadcastRowSparse(int key, const NDArray& src, + void BroadcastRowSparse(int key, + const NDArray& src, const std::vector>& dst, const int priority) override { CHECK_EQ(src.storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row-sparse src NDArray"; + << "BroadcastRowSparse expects row-sparse src NDArray"; for (const auto& dst_kv : dst) { - NDArray* out = dst_kv.first; + NDArray* out = 
dst_kv.first; NDArray row_id = dst_kv.second; CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "BroadcastRowSparse expects row_sparse dst NDArray"; - CHECK_EQ(row_id.ctx(), src.ctx()) - << "row_id and src are expected to be on the same context"; + << "BroadcastRowSparse expects row_sparse dst NDArray"; + CHECK_EQ(row_id.ctx(), src.ctx()) << "row_id and src are expected to be on the same context"; // retain according to indices const bool is_same_ctx = out->ctx() == src.ctx(); const bool is_diff_var = out->var() != src.var(); - NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out : - NDArray(kRowSparseStorage, out->shape(), src.ctx(), true, - out->dtype(), out->aux_types()); + NDArray retained_gpu = (is_same_ctx && is_diff_var) ? *out + : NDArray(kRowSparseStorage, + out->shape(), + src.ctx(), + true, + out->dtype(), + out->aux_types()); if (!is_diff_var) { common::LogOnce("The output of row_sparse_pull() on key " + std::to_string(key) + "refers to the same NDArray as the one stored in KVStore." @@ -653,31 +681,37 @@ class CommDevice : public Comm { "consider create a new NDArray buffer to store the output."); } bool is_gpu = retained_gpu.ctx().dev_mask() == gpu::kDevMask; - Engine::Get()->PushAsync([=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - const TBlob& indices = row_id.data(); - using namespace mxnet::common; - NDArray temp = retained_gpu; - switch (temp.ctx().dev_mask()) { - case cpu::kDevMask: { - SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - break; - } + Engine::Get()->PushAsync( + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + const TBlob& indices = row_id.data(); + using namespace mxnet::common; + NDArray temp = retained_gpu; + switch (temp.ctx().dev_mask()) { + case cpu::kDevMask: { + SparseRetainOpForwardRspWrapper( + rctx.get_stream(), src, indices, kWriteTo, &temp); + break; + } #if MXNET_USE_CUDA - case gpu::kDevMask: { - SparseRetainOpForwardRspWrapper(rctx.get_stream(), - src, indices, kWriteTo, &temp); - // wait for GPU operations to complete - rctx.get_stream()->Wait(); - break; - } + case gpu::kDevMask: { + SparseRetainOpForwardRspWrapper( + rctx.get_stream(), src, indices, kWriteTo, &temp); + // wait for GPU operations to complete + rctx.get_stream()->Wait(); + break; + } #endif - default: LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - } - on_complete(); - }, retained_gpu.ctx(), {src.var(), row_id.var()}, {retained_gpu.var()}, - is_gpu ? FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, - priority, "KVStoreSparseRetain"); + default: + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + } + on_complete(); + }, + retained_gpu.ctx(), + {src.var(), row_id.var()}, + {retained_gpu.var()}, + is_gpu ? 
FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, + priority, + "KVStoreSparseRetain"); CopyFromTo(retained_gpu, out, priority); } } @@ -685,10 +719,11 @@ class CommDevice : public Comm { using KeyAttrs = std::tuple; // try to allocate buff on device evenly void InitMergeBuffer(const std::vector& devs) { - std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), []( - const KeyAttrs& a, const KeyAttrs& b) { - return std::get<1>(a).Size() > std::get<1>(b).Size(); - }); + std::sort(sorted_key_attrs_.begin(), + sorted_key_attrs_.end(), + [](const KeyAttrs& a, const KeyAttrs& b) { + return std::get<1>(a).Size() > std::get<1>(b).Size(); + }); std::unordered_map> ctx_info; for (auto d : devs) { @@ -699,16 +734,16 @@ class CommDevice : public Comm { profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:comm_dev:"; for (auto& sorted_key_attr : sorted_key_attrs_) { - const int key = std::get<0>(sorted_key_attr); + const int key = std::get<0>(sorted_key_attr); const mxnet::TShape& shape = std::get<1>(sorted_key_attr); - const int type = std::get<2>(sorted_key_attr); - auto& buf = merge_buf_[key]; + const int type = std::get<2>(sorted_key_attr); + auto& buf = merge_buf_[key]; Context ctx; size_t min_size = std::numeric_limits::max(); for (auto& ctx_info_kv : ctx_info) { size_t size = ctx_info_kv.second.second; if (size <= min_size) { - ctx = ctx_info_kv.second.first; + ctx = ctx_info_kv.second.first; min_size = size; } } @@ -716,7 +751,7 @@ class CommDevice : public Comm { // only sees sparse arrays if (buf.merged.is_none()) { bool delay_alloc = true; - buf.merged = NDArray(shape, ctx, delay_alloc, type); + buf.merged = NDArray(shape, ctx, delay_alloc, type); buf.merged.AssignStorageInfo(profiler_scope, "merge_buf_" + std::to_string(key)); } ctx_info[ctx.dev_id].second += shape.Size(); @@ -733,9 +768,9 @@ class CommDevice : public Comm { gpus.push_back(d.dev_id); } } - int n = static_cast(gpus.size()); + int n = static_cast(gpus.size()); int enabled = 0; - std::vector p2p(n*n); + std::vector p2p(n * n); for (int i = 0; i < n; ++i) { // Restores active device to what it was before EnableP2P @@ -747,21 +782,21 @@ class CommDevice : public Comm { cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0); if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) { ++enabled; - p2p[i*n+j] = 1; + p2p[i * n + j] = 1; } } } } - if (enabled != n*(n-1)) { + if (enabled != n * (n - 1)) { // print warning info if not fully enabled - LOG(WARNING) << "only " << enabled << " out of " - << n*(n-1) << " GPU pairs are enabled direct access. " + LOG(WARNING) << "only " << enabled << " out of " << n * (n - 1) + << " GPU pairs are enabled direct access. " << "It may affect the performance. " << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off"; std::string access(n, '.'); for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { - access[j] = p2p[i*n+j] ? 'v' : '.'; + access[j] = p2p[i * n + j] ? 
'v' : '.'; } LOG(WARNING) << access; } @@ -792,8 +827,8 @@ class CommDevice : public Comm { // check if sparse_merged is initialized if (sparse_merged.is_none()) { CHECK(!merged.is_none()); - sparse_merged = NDArray(kRowSparseStorage, merged.shape(), merged.ctx(), - true, merged.dtype()); + sparse_merged = + NDArray(kRowSparseStorage, merged.shape(), merged.ctx(), true, merged.dtype()); } return sparse_merged; } diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h index 07ffba53cb8b..f855845b297f 100644 --- a/src/kvstore/comm_tree.h +++ b/src/kvstore/comm_tree.h @@ -50,15 +50,17 @@ namespace kvstore { class CommDeviceTree : public CommDevice { public: CommDeviceTree() { - inited_ = false; - gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_ARRAY_BOUND", 10000000); - backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_BACKTRACK", 0); + inited_ = false; + gpuarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_ARRAY_BOUND", 10000000); + backtrack_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_BACKTRACK", 0); link_usage_penalty_ = dmlc::GetEnv("MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY", 0.7); } - virtual ~CommDeviceTree() { } + virtual ~CommDeviceTree() {} - void Init(int key, const NDArrayStorageType stype, const mxnet::TShape& shape, + void Init(int key, + const NDArrayStorageType stype, + const mxnet::TShape& shape, int dtype = mshadow::kFloat32) override { tree_sorted_key_attrs_.emplace_back(key, shape, dtype); sorted_key_attrs_.emplace_back(key, shape, dtype); @@ -88,23 +90,26 @@ class CommDeviceTree : public CommDevice { * \param merged_row is the id of the slice we are taking * \param priority the priority of the operation */ - const NDArray& ReduceInner(int key, const std::vector& src, int root, - int merged_row, int priority) { + const NDArray& ReduceInner(int key, + const std::vector& src, + int root, + int merged_row, + int priority) { std::vector> reduce(devs_.size()); - TreeBufferEntry& random_buf = tree_merge_buf_[0][key]; + TreeBufferEntry& random_buf = tree_merge_buf_[0][key]; const NDArrayStorageType stype = random_buf.merged[0].storage_type(); - std::vector& topology = topology_[root]; + std::vector& topology = topology_[root]; NDArray buf_slice; if (stype == kDefaultStorage) { // Copy everything into buf.merged for each gpu for (const auto& src_gpu_value : src) { int start = scan_[root][depth_]; - int end = scan_[root][depth_+1]; + int end = scan_[root][depth_ + 1]; for (int j = start; j < end; ++j) { - int topo_id = topology[j]; + int topo_id = topology[j]; TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; if (devs_[topo_id] == src_gpu_value.ctx()) { @@ -114,14 +119,14 @@ class CommDeviceTree : public CommDevice { } for (int level = depth_; level > 0; --level) { - int start = scan_[root][level ]; - int end = scan_[root][level+1]; + int start = scan_[root][level]; + int end = scan_[root][level + 1]; unsigned is_dest = 0; - int dest_id = 0; + int dest_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - dest_id = (is_dest == 0) ? topo_id : dest_id; + dest_id = (is_dest == 0) ? 
topo_id : dest_id; TreeBufferEntry& buf_dest = tree_merge_buf_[dest_id][key]; TreeBufferEntry& buf_from = tree_merge_buf_[topo_id][key]; @@ -133,26 +138,24 @@ class CommDeviceTree : public CommDevice { } else { if (dest_id != topo_id) { CopyFromTo(buf_from.merged[merged_row], - &(buf_dest.copy_buf[merged_row][is_dest-1]), + &(buf_dest.copy_buf[merged_row][is_dest - 1]), priority); - reduce[dest_id].push_back( - buf_dest.copy_buf[merged_row][is_dest-1]); + reduce[dest_id].push_back(buf_dest.copy_buf[merged_row][is_dest - 1]); } } - is_dest = (is_dest == static_cast(kBranch)-1) ? - 0 : is_dest+1; + is_dest = (is_dest == static_cast(kBranch) - 1) ? 0 : is_dest + 1; } - start = scan_[root][level-1]; - end = scan_[root][level]; + start = scan_[root][level - 1]; + end = scan_[root][level]; int source = end; for (int i = start; i < end; ++i) { int gpu_id = topology[i]; // source keeps track of 2 leaf nodes, while start keeps track of parent int dest_id = topology[source]; - int from_id = topology[source+1]; + int from_id = topology[source + 1]; source += 2; // conditional to detect whether operation must be done @@ -171,13 +174,12 @@ class CommDeviceTree : public CommDevice { LOG(FATAL) << "Only dense input supported for now"; } - int topo_id = topology[0]; + int topo_id = topology[0]; TreeBufferEntry& buf = tree_merge_buf_[topo_id][key]; return buf.merged[merged_row]; } - const NDArray& Reduce(int key, const std::vector& src, - int priority) override { + const NDArray& Reduce(int key, const std::vector& src, int priority) override { // when this reduce is called from kvstore_dist, gc is not set // we don't do compression twice in dist_sync_device if ((gc_ != nullptr) && (gc_->get_type() != CompressionType::kNone)) { @@ -191,22 +193,22 @@ class CommDeviceTree : public CommDevice { } InitBuffersAndComm(src); - std::vector> slice(devs_.size()); + std::vector> slice(devs_.size()); std::vector> broadcast_slice(devs_.size()); - std::vector slice_scan(devs_.size()+1); + std::vector slice_scan(devs_.size() + 1); - int total_size = src[0].shape().Size(); + int total_size = src[0].shape().Size(); unsigned first_size = src[0].shape()[0]; const NDArrayStorageType stype = src[0].storage_type(); // normal dense reduce if (stype == kDefaultStorage) { - if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { + if (total_size > gpuarray_bound_ && first_size >= 2 * devs_.size()) { // Find slice bounds - slice_scan[0] = 0; - int slice_size = first_size/devs_.size(); + slice_scan[0] = 0; + int slice_size = first_size / devs_.size(); for (unsigned i = 1; i < devs_.size(); ++i) { - slice_scan[i] = slice_scan[i-1] + slice_size; + slice_scan[i] = slice_scan[i - 1] + slice_size; } slice_scan[devs_.size()] = src[0].shape()[0]; @@ -215,8 +217,7 @@ class CommDeviceTree : public CommDevice { for (unsigned row = 0; row < devs_.size(); ++row) { for (unsigned col = 0; col < devs_.size(); ++col) { TreeBufferEntry& buf = tree_merge_buf_[col][key]; - NDArray curr_slice = src[col].Slice(slice_scan[row], - slice_scan[row+1]); + NDArray curr_slice = src[col].Slice(slice_scan[row], slice_scan[row + 1]); slice[row].push_back(curr_slice); broadcast_slice[row].push_back(&(buf.merged[row])); } @@ -249,9 +250,12 @@ class CommDeviceTree : public CommDevice { } } - void BroadcastInner(int key, const NDArray& src, - const std::vector& dst, int root, - int merged_row, int priority) { + void BroadcastInner(int key, + const NDArray& src, + const std::vector& dst, + int root, + int merged_row, + int priority) { // copy to root of 
tree std::vector& topology = topology_[root]; std::vector temp(devs_.size()); @@ -262,26 +266,28 @@ class CommDeviceTree : public CommDevice { for (int level = 1; level <= depth_; ++level) { int start = scan_[root][level]; - int end = scan_[root][level+1]; + int end = scan_[root][level + 1]; unsigned is_src = 0; - int src_id = 0; + int src_id = 0; for (int j = start; j < end; ++j) { int topo_id = topology[j]; - src_id = (is_src == 0) ? topo_id : src_id; + src_id = (is_src == 0) ? topo_id : src_id; if (is_src && src_id != topo_id) { CopyFromTo(temp[src_id], dst[topo_id], priority); temp[topo_id] = *dst[topo_id]; } - is_src = (is_src == static_cast(kBranch)-1) ? 0 : is_src+1; + is_src = (is_src == static_cast(kBranch) - 1) ? 0 : is_src + 1; } } } - void Broadcast(int key, const NDArray& src, - const std::vector dst, int priority) override { + void Broadcast(int key, + const NDArray& src, + const std::vector dst, + int priority) override { if (!inited_) { // copy to a random device first int dev_id = key % dst.size(); @@ -292,33 +298,34 @@ class CommDeviceTree : public CommDevice { } } } else { - int total_size = src.shape().Size(); - unsigned first_size = src.shape()[0]; + int total_size = src.shape().Size(); + unsigned first_size = src.shape()[0]; const NDArrayStorageType stype = src.storage_type(); // normal dense reduce if (stype == kDefaultStorage) { - if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { - std::vector slice_scan(devs_.size()+1); - slice_scan[0] = 0; - int slice_size = (dst[0]->shape()[0])/devs_.size(); - for (unsigned i = 1; i < devs_.size(); ++i) { - slice_scan[i] = slice_scan[i-1] + slice_size; - } - slice_scan[devs_.size()] = dst[0]->shape()[0]; - - for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { - TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; - for (unsigned i = 0; i < devs_.size(); ++i) { - if (devs_[gpu_id] == dst[gpu_id]->ctx()) { - NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i+1]); - CopyFromTo(buf.merged[i], &curr_slice, priority); + if (total_size > gpuarray_bound_ && first_size >= 2 * devs_.size()) { + std::vector slice_scan(devs_.size() + 1); + slice_scan[0] = 0; + int slice_size = (dst[0]->shape()[0]) / devs_.size(); + for (unsigned i = 1; i < devs_.size(); ++i) { + slice_scan[i] = slice_scan[i - 1] + slice_size; + } + slice_scan[devs_.size()] = dst[0]->shape()[0]; + + for (unsigned gpu_id = 0; gpu_id < dst.size(); ++gpu_id) { + TreeBufferEntry& buf = tree_merge_buf_[gpu_id][key]; + for (unsigned i = 0; i < devs_.size(); ++i) { + if (devs_[gpu_id] == dst[gpu_id]->ctx()) { + NDArray curr_slice = dst[gpu_id]->Slice(slice_scan[i], slice_scan[i + 1]); + CopyFromTo(buf.merged[i], &curr_slice, priority); + } } } + } else { + int root = 0; + BroadcastInner(key, src, dst, root, -1, priority); } } else { - int root = 0; - BroadcastInner(key, src, dst, root, -1, priority); - }} else { LOG(FATAL) << "Only dense input supported for now"; } } @@ -333,10 +340,10 @@ class CommDeviceTree : public CommDevice { gpus.push_back(d.dev_id); } } - int n = static_cast(gpus.size()); + int n = static_cast(gpus.size()); int enabled = 0; p2p->clear(); - p2p->resize(n*n, 0); + p2p->resize(n * n, 0); for (int i = 0; i < n; ++i) { mxnet::common::cuda::DeviceStore device_store(gpus[i]); for (int j = 0; j < n; j++) { @@ -346,21 +353,21 @@ class CommDeviceTree : public CommDevice { cudaError_t e = cudaDeviceEnablePeerAccess(gpus[j], 0); if (e == cudaSuccess || e == cudaErrorPeerAccessAlreadyEnabled) { ++enabled; - (*p2p)[i*n+j] = 1; + 
(*p2p)[i * n + j] = 1; } } } } - if (enabled != n*(n-1)) { + if (enabled != n * (n - 1)) { // print warning info if not fully enabled - LOG(WARNING) << "only " << enabled << " out of " - << n*(n-1) << " GPU pairs are enabled direct access. " + LOG(WARNING) << "only " << enabled << " out of " << n * (n - 1) + << " GPU pairs are enabled direct access. " << "It may affect the performance. " << "You can set MXNET_ENABLE_GPU_P2P=0 to turn it off"; std::string access(n, '.'); for (int i = 0; i < n; ++i) { for (int j = 0; j < n; ++j) { - access[j] = (*p2p)[i*n+j] ? 'v' : '.'; + access[j] = (*p2p)[i * n + j] ? 'v' : '.'; } LOG(WARNING) << access; } @@ -370,16 +377,15 @@ class CommDeviceTree : public CommDevice { void QueryTopology() { #if MXNET_USE_CUDA - std::vector link_matrix(devs_.size()*devs_.size()); - std::vector p2p_matrix(devs_.size()*devs_.size()); + std::vector link_matrix(devs_.size() * devs_.size()); + std::vector p2p_matrix(devs_.size() * devs_.size()); EnableP2P(&p2p_matrix); GetP2PWeight(devs_, p2p_matrix, &link_matrix); if (backtrack_) LOG(INFO) << "Using Backtracking to generate trees"; else LOG(INFO) << "Using Kernighan-Lin to generate trees"; - ComputeTrees(link_matrix, devs_.size(), link_usage_penalty_, backtrack_, - &topology_, &scan_); + ComputeTrees(link_matrix, devs_.size(), link_usage_penalty_, backtrack_, &topology_, &scan_); depth_ = ComputeDepth(devs_.size()); #endif @@ -404,9 +410,9 @@ class CommDeviceTree : public CommDevice { profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "comm_dev_tree:"; for (auto& tree_sorted_key_attr : tree_sorted_key_attrs_) { - const int key = std::get<0>(tree_sorted_key_attr); + const int key = std::get<0>(tree_sorted_key_attr); const mxnet::TShape& shape = std::get<1>(tree_sorted_key_attr); - const int type = std::get<2>(tree_sorted_key_attr); + const int type = std::get<2>(tree_sorted_key_attr); if (key_dist.find(shape.Size()) == key_dist.end()) key_dist[shape.Size()] = 1; @@ -414,7 +420,7 @@ class CommDeviceTree : public CommDevice { key_dist[shape.Size()]++; int start = scan_[0][depth_]; - int end = scan_[0][depth_+1]; + int end = scan_[0][depth_ + 1]; // In order to generalize to any number of GPUs in arbitrary order, we use // strategy of having found the mapping from 0, 1, ..., n_gpus to dev_id. 
@@ -442,29 +448,28 @@ class CommDeviceTree : public CommDevice { // 3) We use the mapping (devs_) to retrieve dev_id and device context for (int j = start; j < end; ++j) { int topo_id = topology_[0][j]; - auto& buf = tree_merge_buf_[topo_id][key]; + auto& buf = tree_merge_buf_[topo_id][key]; Context ctx = devs_[topo_id]; // buf.merged enforces that we only visit each GPU once if (buf.merged.empty()) { mxnet::TShape shape_copy = shape; - int total_size = shape.Size(); - unsigned first_size = shape[0]; - if (total_size > gpuarray_bound_ && first_size >= 2*devs_.size()) { + int total_size = shape.Size(); + unsigned first_size = shape[0]; + if (total_size > gpuarray_bound_ && first_size >= 2 * devs_.size()) { // Find slice bounds - int slice_size = first_size/devs_.size(); - int last_slice = first_size-(devs_.size()-1)*slice_size; - shape_copy[0] = slice_size; + int slice_size = first_size / devs_.size(); + int last_slice = first_size - (devs_.size() - 1) * slice_size; + shape_copy[0] = slice_size; buf.merged.resize(devs_.size()); for (unsigned row = 0; row < devs_.size(); ++row) { - if (row == devs_.size()-1) + if (row == devs_.size() - 1) shape_copy[0] = last_slice; buf.merged[row] = NDArray(shape_copy, ctx, delay_alloc, type); - buf.merged[row].AssignStorageInfo( - profiler_scope, "merged_" + std::to_string(key)); + buf.merged[row].AssignStorageInfo(profiler_scope, "merged_" + std::to_string(key)); buf.copy_buf.emplace_back(); if (buf.copy_buf[row].empty()) { - buf.copy_buf[row].resize(kBranch-1); + buf.copy_buf[row].resize(kBranch - 1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { buf.copy_buf[row][col] = NDArray(buf.merged[row].shape(), buf.merged[row].ctx(), @@ -476,16 +481,13 @@ class CommDeviceTree : public CommDevice { } } else { buf.merged.emplace_back(shape, ctx, false, type); - buf.merged.back().AssignStorageInfo( - profiler_scope, "merged_" + std::to_string(key)); + buf.merged.back().AssignStorageInfo(profiler_scope, "merged_" + std::to_string(key)); if (buf.copy_buf.empty()) { buf.copy_buf.emplace_back(); - buf.copy_buf[0].resize(kBranch-1); + buf.copy_buf[0].resize(kBranch - 1); for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) { - buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(), - buf.merged[0].ctx(), - delay_alloc, - buf.merged[0].dtype()); + buf.copy_buf[0][col] = NDArray( + buf.merged[0].shape(), buf.merged[0].ctx(), delay_alloc, buf.merged[0].dtype()); buf.copy_buf[0][col].AssignStorageInfo(profiler_scope, "copy_buf"); } } diff --git a/src/kvstore/gpu_topology.h b/src/kvstore/gpu_topology.h index 777fb47f9945..70a4f555008c 100644 --- a/src/kvstore/gpu_topology.h +++ b/src/kvstore/gpu_topology.h @@ -23,8 +23,8 @@ #ifndef MXNET_KVSTORE_GPU_TOPOLOGY_H_ #define MXNET_KVSTORE_GPU_TOPOLOGY_H_ #if MXNET_USE_CUDA - #include - #include +#include +#include #endif #include #include @@ -55,8 +55,10 @@ inline void PrintVector(const std::string& str, const std::vector& vec) { } template -inline void PrintMatrix(const std::string& str, const std::vector& matrix, - int num_rows, int num_cols) { +inline void PrintMatrix(const std::string& str, + const std::vector& matrix, + int num_rows, + int num_cols) { LOG(INFO) << str << ":"; int count = 0; for (int row = 0; row < num_rows; ++row) { @@ -68,16 +70,17 @@ inline void PrintMatrix(const std::string& str, const std::vector& matrix, } } -inline void PrintTopo(const std::string& str, const std::vector& topo_row, - std::vector scan_row) { +inline void PrintTopo(const std::string& str, + const std::vector& topo_row, + 
std::vector scan_row) { LOG(INFO) << str << ":"; - int depth = scan_row.size()-1; + int depth = scan_row.size() - 1; for (int row = 0; row < depth; ++row) { int start = scan_row[row]; - int end = scan_row[row+1]; + int end = scan_row[row + 1]; std::string output; for (; start < end; start++) { - for (int i = 0; i < (2 << (depth-row-2))+1; ++i) { + for (int i = 0; i < (2 << (depth - row - 2)) + 1; ++i) { output += " "; } output += std::to_string(topo_row[start]); @@ -86,7 +89,7 @@ inline void PrintTopo(const std::string& str, const std::vector& topo_ro } } -/** +/** * \brief Uses BFS to find whether undirected graph is connected or not given its * adjacency matrix * Note: only consider matrix values > 1, because we care about whether it is @@ -105,7 +108,7 @@ inline bool IsConnected(const std::vector& matrix, int num_gpus) { work_list.pop(); for (int i = 0; i < num_gpus; ++i) { - int neighbour = matrix[curr*num_gpus + i]; + int neighbour = matrix[curr * num_gpus + i]; if (i != curr && neighbour > 1 && visited[i] == false) { visited[i] = true; work_list.push(i); @@ -153,7 +156,7 @@ inline void GetP2PWeight(const std::vector& devs, for (int row = 0; row < num_gpus; ++row) { for (int col = 0; col < num_gpus; ++col) { if (row == col) { - (*matrix)[row*num_gpus+col] = 0; + (*matrix)[row * num_gpus + col] = 0; } else { int value; int row_gpu = zero_dev_id[row]; @@ -161,7 +164,7 @@ inline void GetP2PWeight(const std::vector& devs, cudaDeviceGetP2PAttribute(&value, attr, row_gpu, col_gpu); if (value > max[row]) max[row] = value; - (*matrix)[row*num_gpus+col] = static_cast(value)+1; + (*matrix)[row * num_gpus + col] = static_cast(value) + 1; } } } @@ -240,7 +243,7 @@ inline void GetP2PWeight(const std::vector& devs, } } else { for (auto& matrix_value : *matrix) { - matrix_value = (matrix_value == 1) ? 1./num_gpus : matrix_value; + matrix_value = (matrix_value == 1) ? 1. 
/ num_gpus : matrix_value; } } if (kLogTree) @@ -257,14 +260,13 @@ inline void GetP2PWeight(const std::vector& devs, * y = A*x (no accumulate) */ template -inline void gemv(const std::vector& A, const std::vector& x, - std::vector* y) { +inline void gemv(const std::vector& A, const std::vector& x, std::vector* y) { int nrows = x.size(); int count = 0; - for (int row=0; row < nrows; ++row) { + for (int row = 0; row < nrows; ++row) { (*y)[row] = 0; - for (int col=0; col < nrows; ++col) { - (*y)[row] += A[count]*static_cast(x[col]); + for (int col = 0; col < nrows; ++col) { + (*y)[row] += A[count] * static_cast(x[col]); count++; } } @@ -277,8 +279,8 @@ inline void gemv(const std::vector& A, const std::vector& x, template inline void ewisemult(const std::vector& u, T alpha, std::vector* w) { int nelem = u.size(); - for (int i=0; i < nelem; ++i) { - (*w)[i] *= alpha*static_cast(u[i]); + for (int i = 0; i < nelem; ++i) { + (*w)[i] *= alpha * static_cast(u[i]); } } @@ -294,17 +296,21 @@ inline void FindBestMove(const std::vector& W, const std::vector& P_temp, const std::vector& D, const std::unordered_set& used, - int* a, int* b, T* g) { + int* a, + int* b, + T* g) { int nrows = P_temp.size(); - *g = 0; - *a = -1; - *b = -1; - for (int row=0; row < nrows; ++row) { - if (P_temp[row] == 0 || used.find(row) != used.end()) continue; - for (int col=row+1; col < nrows; ++col) { - if (P_temp[col] == 0 || P_temp[row] == P_temp[col]) continue; - - T cost = D[row]+D[col]-2*W[row*nrows+col]; + *g = 0; + *a = -1; + *b = -1; + for (int row = 0; row < nrows; ++row) { + if (P_temp[row] == 0 || used.find(row) != used.end()) + continue; + for (int col = row + 1; col < nrows; ++col) { + if (P_temp[col] == 0 || P_temp[row] == P_temp[col]) + continue; + + T cost = D[row] + D[col] - 2 * W[row * nrows + col]; if (cost > *g) { *g = cost; *a = row; @@ -323,7 +329,8 @@ inline void FindBestMove(const std::vector& W, * the output of partitioning one large cluster */ template -inline bool KernighanLin(const std::vector& W, std::vector* P, +inline bool KernighanLin(const std::vector& W, + std::vector* P, int* num_partitions, std::vector>* cluster_pairs, std::mt19937* gen) { @@ -340,14 +347,13 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, } bool stop = true; - for (unsigned color=0; color < histogram.size(); ++color) { + for (unsigned color = 0; color < histogram.size(); ++color) { int partition_size = histogram[color]; // Save cluster in preparation for push to topo in GenerateBinaryTree() if (partition_size <= 2) { - cluster_pairs->push_back( - std::pair(static_cast(color), -partition_size)); + cluster_pairs->push_back(std::pair(static_cast(color), -partition_size)); - // Do Kernighan-Lin if clustering is necessary + // Do Kernighan-Lin if clustering is necessary } else { stop = false; @@ -355,8 +361,8 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, // Assign random balanced partition of it // -balanced is more important than random, so allocate first half to A // and rest to B - int first_partition = 0; - int target_partition = partition_size/2; + int first_partition = 0; + int target_partition = partition_size / 2; std::vector cluster_list; for (unsigned i = 0; i < P->size(); ++i) { @@ -374,18 +380,18 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, std::shuffle(cluster_list.begin(), cluster_list.end(), *gen); for (int cluster : cluster_list) { if (first_partition < target_partition) { - int dest = cluster; + int dest = cluster; P_temp[dest] = 1; first_partition++; } 
else { - int dest = cluster; + int dest = cluster; P_temp[dest] = -1; } } // 2) Do iterations of Kernighan-Lin until convergence - T g_max = 0; - int g_k = -1; + T g_max = 0; + int g_k = -1; unsigned count = 0; do { count++; @@ -404,7 +410,7 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, std::unordered_set used; - for (int iter=0; iter < partition_size/2; ++iter) { + for (int iter = 0; iter < partition_size / 2; ++iter) { // b) Find best move by looking through upper triangular of W matrix int a, b; T g; @@ -437,10 +443,10 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, // Recompute score g_max for (unsigned k = 0; k < gv.size(); ++k) { if (k > 0) - gv[k] += gv[k-1]; + gv[k] += gv[k - 1]; if (gv[k] > g_max) { g_max = gv[k]; - g_k = k + 1; + g_k = k + 1; } } @@ -448,9 +454,9 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, // Otherwise, rollback changes to P_temp2 if (g_max > 0) { for (int i = 0; i < g_k; i++) { - int a = av[i]; - int b = bv[i]; - int temp = P_temp2[a]; + int a = av[i]; + int b = bv[i]; + int temp = P_temp2[a]; P_temp2[a] = P_temp2[b]; P_temp2[b] = temp; @@ -463,14 +469,14 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, // 5) Update P using P_temp int moves = 0; - for (unsigned i=0; i < P->size(); ++i) { + for (unsigned i = 0; i < P->size(); ++i) { if (P_temp[i] == -1) { (*P)[i] = *num_partitions; moves++; } } - cluster_pairs->push_back(std::pair(static_cast(color), - static_cast(*num_partitions))); + cluster_pairs->push_back( + std::pair(static_cast(color), static_cast(*num_partitions))); (*num_partitions)++; } @@ -483,8 +489,7 @@ inline bool KernighanLin(const std::vector& W, std::vector* P, * \brief Returns root of a given color if found in roots * Returns -1 if it is not found */ -inline int GetRoot(const std::vector& P, int color, - const std::unordered_set& roots) { +inline int GetRoot(const std::vector& P, int color, const std::unordered_set& roots) { for (auto root : roots) { if (P[root] == color) return root; @@ -514,16 +519,21 @@ inline int GetChild(const std::vector& P, int color, int parent) { // g is weight of edge // Optimization: Only need to look at row a in matrix template -inline void FindBestEdge(const std::vector& W, const std::vector& P, - int parent, int dest_cluster, std::vector* b, T* g) { +inline void FindBestEdge(const std::vector& W, + const std::vector& P, + int parent, + int dest_cluster, + std::vector* b, + T* g) { int nrows = P.size(); - int row = parent; - *g = 0; + int row = parent; + *g = 0; b->push_back(-1); - for (int col=0; col < nrows; ++col) { - if (col == row || P[col] != dest_cluster) continue; + for (int col = 0; col < nrows; ++col) { + if (col == row || P[col] != dest_cluster) + continue; - T cost = W[row*nrows+col]; + T cost = W[row * nrows + col]; if (cost > *g) { b->clear(); } @@ -562,25 +572,27 @@ inline int KLGenerateBinaryTree(const std::vector& W, int parent, child = -1; if ((*cluster_pairs)[i].second == -2) { // Root must be color of pair.first - int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); - if (parent == -1) return 1; + int color = (*cluster_pairs)[i].first; + parent = GetRoot(P, color, *roots); + if (parent == -1) + return 1; child = GetChild(P, color, parent); } else if ((*cluster_pairs)[i].second == -1) { int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); - if (parent == -1) return 1; + parent = GetRoot(P, color, *roots); + if (parent == -1) + return 1; child = parent; } else { // Root must 
exist in either first or second element of pair int color = (*cluster_pairs)[i].first; - parent = GetRoot(P, color, *roots); - color = (parent == -1) ? (*cluster_pairs)[i].second : color; - parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; + parent = GetRoot(P, color, *roots); + color = (parent == -1) ? (*cluster_pairs)[i].second : color; + parent = (parent == -1) ? GetRoot(P, color, *roots) : parent; int from_cluster = color; - int dest_cluster = (from_cluster == (*cluster_pairs)[i].first) ? - (*cluster_pairs)[i].second : (*cluster_pairs)[i].first; + int dest_cluster = (from_cluster == (*cluster_pairs)[i].first) ? (*cluster_pairs)[i].second + : (*cluster_pairs)[i].first; std::vector candidates; T weight; @@ -605,8 +617,8 @@ inline int KLGenerateBinaryTree(const std::vector& W, } int depth = scan_row->size(); - int start = (*scan_row)[depth-2]; - int end = (*scan_row)[depth-1]; + int start = (*scan_row)[depth - 2]; + int end = (*scan_row)[depth - 1]; for (int i = start; i < end; ++i) { int parent = (*topo_row)[i]; @@ -614,7 +626,7 @@ inline int KLGenerateBinaryTree(const std::vector& W, // If not first, check previous level whether or not we are encountering // this root for the first time in this level of the tree - if (i != start && parent == static_cast((*topo_row)[i-1])) + if (i != start && parent == static_cast((*topo_row)[i - 1])) child = parent; else child = new_topo[parent]; @@ -635,7 +647,7 @@ inline int ComputeDepth(int n) { for (int depth = 0; depth < MXNET_KVSTORE_MAXDEPTH; ++depth) { int num = 2 << depth; if (n <= num) - return depth+1; + return depth + 1; } return 0; } @@ -646,17 +658,20 @@ inline int ComputeDepth(int n) { // -each edge in tree corresponds to link in network topology // -each edge in tree does not form self-loop template -inline bool IsValid(const std::vector& W, const std::vector& state, - int num_elements, int row, int depth) { +inline bool IsValid(const std::vector& W, + const std::vector& state, + int num_elements, + int row, + int depth) { // At each level of tree, check whether edge: // -corresponds to link in network topology // -corresponds to self-loop for (int i = 0; i < depth; ++i) { int stride = 1 << i; - for (int j = 0; j+stride < row; j += 2*stride) { + for (int j = 0; j + stride < row; j += 2 * stride) { int from = state[j]; - int dest = state[j+stride]; - if (W[from*num_elements + dest] == static_cast(0) && from != dest) { + int dest = state[j + stride]; + if (W[from * num_elements + dest] == static_cast(0) && from != dest) { return false; } } @@ -682,7 +697,7 @@ inline bool IsValid(const std::vector& W, const std::vector& state, // modifier is maximum number of repeats a single GPU can take // e.g. 
5 GPUs in 3-level binary tree => one GPU can repeat 3x // GPU0 GPU0 GPU0 GPU0 GPU1 GPU2 GPU3 GPU4 - int modifier = (1 << depth) - num_elements; + int modifier = (1 << depth) - num_elements; int num_found = found.size(); // So we know we have an invalid state if we find: @@ -693,8 +708,8 @@ inline bool IsValid(const std::vector& W, const std::vector& state, return false; } - // If we are at last recursive level, we can apply a more stringent check: - // -if some GPU is not found, then we are in invalid state + // If we are at last recursive level, we can apply a more stringent check: + // -if some GPU is not found, then we are in invalid state } else if (row == static_cast(state.size())) { for (int i = 0; i < num_elements; ++i) { if (found_vec[i] == 0) { @@ -729,7 +744,7 @@ inline void Postprocess(std::vector* result, int num_elements, int depth) { for (int level = depth - 1; level >= 0; --level) { int stride = 1 << level; std::vector histogram_above(num_elements, 0); - for (unsigned i = 0; i < result->size(); i += 2*stride) { + for (unsigned i = 0; i < result->size(); i += 2 * stride) { int val = (*result)[i]; histogram_above[val]++; } @@ -739,9 +754,9 @@ inline void Postprocess(std::vector* result, int num_elements, int depth) { histogram[val]++; } - for (int i = result->size()-stride; i-stride >= 0; i -= 2*stride) { + for (int i = result->size() - stride; i - stride >= 0; i -= 2 * stride) { int from = (*result)[i]; - int dest = (*result)[i-stride]; + int dest = (*result)[i - stride]; if ((histogram[from] > 1 || histogram_above[from] >= 1) && from != dest) { (*result)[i] = dest; histogram[from]--; @@ -756,29 +771,31 @@ inline void Postprocess(std::vector* result, int num_elements, int depth) { // -usually turned on when backtracking to get better solutions // -usually turned off when outside the penalty to get weight of tree template -inline T ComputeTreeWeight(const std::vector& W, const std::vector& result, - int num_elements, int depth, bool penalty) { +inline T ComputeTreeWeight(const std::vector& W, + const std::vector& result, + int num_elements, + int depth, + bool penalty) { T weight = 0.f; std::unordered_set links_used; for (int i = 0; i < depth; ++i) { int stride = 1 << i; std::vector nodes_used(num_elements, false); - for (unsigned j = 0; j+stride < result.size(); j += 2*stride) { + for (unsigned j = 0; j + stride < result.size(); j += 2 * stride) { int from = result[j]; - int dest = result[j+stride]; + int dest = result[j + stride]; if (from != dest) { - weight += W[from*num_elements+dest]; + weight += W[from * num_elements + dest]; // Penalize: (1) use of redundant edges in a single tree // (2) repeated use of a GPU in a single tree at the same // level above the leaf level - if (links_used.find(from*num_elements+dest) != links_used.end() - && penalty) { + if (links_used.find(from * num_elements + dest) != links_used.end() && penalty) { weight -= 100; } - links_used.insert(from*num_elements+dest); - links_used.insert(dest*num_elements+from); + links_used.insert(from * num_elements + dest); + links_used.insert(dest * num_elements + from); } nodes_used[from] = true; @@ -868,9 +885,9 @@ inline bool RecursiveBacktrack(const std::vector& W, bool stop = false; for (int j = 0; j < num_elements; ++j) { (*state)[row] = j; - if (IsValid(W, state, num_elements, row+1, depth)) - stop = RecursiveBacktrack(W, state, best_result, best_result_weight, - row+1, num_elements, depth, optimal); + if (IsValid(W, state, num_elements, row + 1, depth)) + stop = RecursiveBacktrack( + W, state, 
best_result, best_result_weight, row + 1, num_elements, depth, optimal); (*state)[row] = -1; if (stop) return stop; @@ -888,7 +905,7 @@ inline void IterativeBacktrack(const std::vector& W, int depth, bool optimal) { std::stack state_stack; - row = 1; + row = 1; int pos = 0; state_stack.push(pos); @@ -901,15 +918,16 @@ inline void IterativeBacktrack(const std::vector& W, pos = state_stack.top(); pos++; state_stack.pop(); - (*state)[state_stack.size()+1] = -1; + (*state)[state_stack.size() + 1] = -1; row--; } - if (state_stack.empty()) break; + if (state_stack.empty()) + break; (*state)[row] = pos; // If there is a valid position push the position to stack, set current // position to 0 and move to next row - if (IsValid(W, *state, num_elements, row+1, depth)) { + if (IsValid(W, *state, num_elements, row + 1, depth)) { state_stack.push(pos); pos = 0; row++; @@ -931,7 +949,8 @@ inline void IterativeBacktrack(const std::vector& W, std::swap(*best_result_weight, weight); *best_result = result; } - if (!optimal) break; + if (!optimal) + break; pos = state_stack.top(); pos++; @@ -947,20 +966,22 @@ inline void IterativeBacktrack(const std::vector& W, * by the spanning tree */ template -inline void UpdateWeight(std::vector* W, const std::vector& topo_row, - int num_elements, float alpha) { +inline void UpdateWeight(std::vector* W, + const std::vector& topo_row, + int num_elements, + float alpha) { for (unsigned i = 1; i < topo_row.size() - 1; i += 2) { unsigned parent = topo_row[i]; - unsigned child = topo_row[i+1]; - if (!(parent >= num_elements*num_elements || - child >= num_elements*num_elements) && (parent != child)) { - (*W)[parent*num_elements+child] *= alpha; - (*W)[child*num_elements+parent] *= alpha; + unsigned child = topo_row[i + 1]; + if (!(parent >= num_elements * num_elements || child >= num_elements * num_elements) && + (parent != child)) { + (*W)[parent * num_elements + child] *= alpha; + (*W)[child * num_elements + parent] *= alpha; } } } -/** +/** * \brief Do brute-force backtracking approach if Kernighan-Lin fails to find a binary * tree of height Log P. 
* @@ -986,7 +1007,7 @@ inline bool BacktrackGenerateBinaryTree(std::vector* W, // 7: 3 8 // 8: 3 8 // 9: 4 16 - int depth = ComputeDepth(num_elements); + int depth = ComputeDepth(num_elements); int depth_leaves = 1 << depth; // State vector @@ -1002,11 +1023,9 @@ inline bool BacktrackGenerateBinaryTree(std::vector* W, // For larger numbers of GPUs, settle for first tree found (non-optimal), but // this saves a lot of runtime, because Backtrack is exponential time if (depth <= 3) { - IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, - depth, true); + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, depth, true); } else { - IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, - depth, false); + IterativeBacktrack(*W, &state, &result, &result_weight, 1, num_elements, depth, false); } return FormTopology(result, topo_row, scan_row, depth); } @@ -1060,26 +1079,26 @@ inline void ComputeTreesFromRoot(std::vector* W, while (!backtrack && (!stop || reset)) { if (reset == 1) { cluster_pairs.clear(); - P_temp = P; + P_temp = P; num_partitions_temp = num_partitions; - roots_temp = roots; - topo_temp = *topo; - scan_temp = *scan; + roots_temp = roots; + topo_temp = *topo; + scan_temp = *scan; } // Run Kernighan-Lin to generate partition - stop = KernighanLin(*W, &P_temp, &num_partitions_temp, &cluster_pairs, - &gen); + stop = KernighanLin(*W, &P_temp, &num_partitions_temp, &cluster_pairs, &gen); // Use partitions found and a given root to find best inter-cluster edge for // each pair of clusters, and returns them as roots of next cluster // If reset is true, then rewind back to previous clustering - reset = KLGenerateBinaryTree(*W, P_temp, &cluster_pairs, &roots_temp, - &topo_temp, &scan_temp, &gen); + reset = + KLGenerateBinaryTree(*W, P_temp, &cluster_pairs, &roots_temp, &topo_temp, &scan_temp, &gen); if (reset) level++; - if (level > 10) break; + if (level > 10) + break; } bool success = true; @@ -1123,8 +1142,7 @@ inline void ComputeTrees(const std::vector& W, scan->push_back(std::vector()); (*topo)[i].push_back(i); (*scan)[i].push_back(0); - ComputeTreesFromRoot(&W_copy, num_elements, i, alpha, backtrack, - &((*topo)[i]), &((*scan)[i])); + ComputeTreesFromRoot(&W_copy, num_elements, i, alpha, backtrack, &((*topo)[i]), &((*scan)[i])); } // Note: must sum up adj matrix to show link usage before we readjust topo @@ -1132,21 +1150,20 @@ inline void ComputeTrees(const std::vector& W, std::vector adj(W.size(), 0); for (int row = 0; row < num_elements; ++row) { for (unsigned col = 1; col < (*topo)[0].size(); col += 2) { - int from = std::min((*topo)[row][col], (*topo)[row][col+1]); - int dest = std::max((*topo)[row][col], (*topo)[row][col+1]); + int from = std::min((*topo)[row][col], (*topo)[row][col + 1]); + int dest = std::max((*topo)[row][col], (*topo)[row][col + 1]); if (from != dest) { - adj.at(from*num_elements+dest) += 1; - adj.at(dest*num_elements+from) += 1; + adj.at(from * num_elements + dest) += 1; + adj.at(dest * num_elements + from) += 1; } } } - std::vector> topo_temp(num_elements, - std::vector()); + std::vector> topo_temp(num_elements, std::vector()); if (kLogTree) { for (int i = 0; i < num_elements; ++i) - PrintTopo("Tree "+std::to_string(i), (*topo)[i], (*scan)[i]); + PrintTopo("Tree " + std::to_string(i), (*topo)[i], (*scan)[i]); PrintMatrix("W", W, num_elements, num_elements); PrintMatrix("Links", adj, num_elements, num_elements); diff --git a/src/kvstore/gradient_compression-inl.h 
b/src/kvstore/gradient_compression-inl.h index 7b906fd9f234..f50cb71c5912 100644 --- a/src/kvstore/gradient_compression-inl.h +++ b/src/kvstore/gradient_compression-inl.h @@ -32,31 +32,35 @@ namespace mxnet { namespace kvstore { // these gpu functions are defined in gradient_compression.cu -void Quantize1BitImpl(mshadow::Stream *s, const std::vector &inputs, +void Quantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold); -void Dequantize1BitImpl(mshadow::Stream *s, const std::vector &inputs, +void Dequantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold); -void Quantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, +void Quantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold); -void Dequantize2BitImpl(mshadow::Stream *s, const std::vector &inputs, +void Dequantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold); struct quantize_1bit { MSHADOW_XINLINE static void Map(int out_byte_id, int original_size, - float *out, - float *grad, - float *residual, + float* out, + float* grad, + float* residual, const float threshold) { // this byte contains the compressed representation of // upto 8 values starting from (char*)out + out_byte_id - char *compr_byte = reinterpret_cast(out) + out_byte_id; + char* compr_byte = reinterpret_cast(out) + out_byte_id; // init to 0 *compr_byte = 0; // start and end are indices in original grad array const int start = out_byte_id << 3; - const int end = (start + 8 <= original_size) ? start + 8 : original_size; + const int end = (start + 8 <= original_size) ? start + 8 : original_size; // masks used to quantize data const uint8_t bits[] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01}; @@ -78,35 +82,33 @@ struct quantize_1bit { } }; -template -void Quantize1BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, +template +void Quantize1BitKernelLaunch(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[2].Size() * 4, // compressed array byte size - inputs[0].Size(), // original size - inputs[2].dptr(), // compressed array - inputs[0].dptr(), // original array - inputs[1].dptr(), // residual array - threshold); // threshold + mxnet::op::mxnet_op::Kernel::Launch( + s, + inputs[2].Size() * 4, // compressed array byte size + inputs[0].Size(), // original size + inputs[2].dptr(), // compressed array + inputs[0].dptr(), // original array + inputs[1].dptr(), // residual array + threshold); // threshold } struct dequantize_1bit { - MSHADOW_XINLINE static void Map(int i, - float *out, - float *in, - const float threshold) { + MSHADOW_XINLINE static void Map(int i, float* out, float* in, const float threshold) { // get position of dequantized value to fill - float *outval = out + i; + float* outval = out + i; // gets byte which holds quantized value for this position - char *ch_ptr = reinterpret_cast < char * > (in + (i >> 5)); + char* ch_ptr = reinterpret_cast(in + (i >> 5)); ch_ptr += ((i & 31) >> 3); // masks used to quantize data const uint8_t bits[] = {0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01}; // col denotes which bit of a byte is set for this value // col=0 implies the first bit, col=1 implies the second bit,... 
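// --- Illustrative sketch (not part of this patch): 1-bit addressing in this struct ---
// quantize_1bit packs the signs of 8 consecutive gradients into one output byte, and the
// compressed buffer is float-typed, so each 4-byte float word carries 32 gradients.
// dequantize_1bit recovers the bit for original element i with the arithmetic mirrored by
// this hypothetical helper: word = i >> 5, byte within word = (i & 31) >> 3,
// bit within byte = i & 7 (bit 0 is the most significant bit, bits[0] == 0x80).
struct SketchBit1Addr {
  int word;  // index into the float* compressed array
  int byte;  // byte within that 4-byte word
  int bit;   // bit within that byte
};
constexpr SketchBit1Addr SketchLocate1Bit(int i) {
  return SketchBit1Addr{i >> 5, (i & 31) >> 3, i & 7};
}
static_assert(SketchLocate1Bit(0).word == 0 && SketchLocate1Bit(0).byte == 0 &&
                  SketchLocate1Bit(0).bit == 0,
              "element 0 -> first bit of the first byte");
static_assert(SketchLocate1Bit(37).word == 1 && SketchLocate1Bit(37).byte == 0 &&
                  SketchLocate1Bit(37).bit == 5,
              "element 37 -> word 1, byte 0, bit 5");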
- const int col = i & 7; - const uint8_t mask = bits[col]; + const int col = i & 7; + const uint8_t mask = bits[col]; const uint8_t masked = *ch_ptr & mask; if (masked == mask) { *outval = +1; @@ -118,33 +120,34 @@ struct dequantize_1bit { } }; -template -void Dequantize1BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, +template +void Dequantize1BitKernelLaunch(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[1].Size(), // original size - inputs[1].dptr(), // out array - inputs[0].dptr(), // compressed array - threshold); // threshold + mxnet::op::mxnet_op::Kernel::Launch( + s, + inputs[1].Size(), // original size + inputs[1].dptr(), // out array + inputs[0].dptr(), // compressed array + threshold); // threshold } struct quantize_2bit { MSHADOW_XINLINE static void Map(int out_byte_id, int original_size, - float *out, - float *grad, - float *residual, + float* out, + float* grad, + float* residual, const float neg_threshold, const float pos_threshold) { // this block contains the compressed representation of // upto 4 values starting from (char*)out + out_byte_id - char *compr_byte = reinterpret_cast(out) + out_byte_id; + char* compr_byte = reinterpret_cast(out) + out_byte_id; // init to 0 *compr_byte = 0; // start and end are indices in original grad array const int start = out_byte_id << 2; - const int end = (start + 4 <= original_size) ? start + 4 : original_size; + const int end = (start + 4 <= original_size) ? start + 4 : original_size; // masks to set bits when value meets pos_threshold // 0xc0 is mask when value is to be represented by the first two bits in a char* @@ -169,40 +172,41 @@ struct quantize_2bit { } }; -template -void Quantize2BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, +template +void Quantize2BitKernelLaunch(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[2].Size() * 4, // compressed array byte size - inputs[0].Size(), // original size - inputs[2].dptr(), // compressed array - inputs[0].dptr(), // original array - inputs[1].dptr(), // residual array - -1 *threshold, // negative threshold - threshold); // positive threshold + mxnet::op::mxnet_op::Kernel::Launch( + s, + inputs[2].Size() * 4, // compressed array byte size + inputs[0].Size(), // original size + inputs[2].dptr(), // compressed array + inputs[0].dptr(), // original array + inputs[1].dptr(), // residual array + -1 * threshold, // negative threshold + threshold); // positive threshold } struct dequantize_2bit { MSHADOW_XINLINE static void Map(int i, - float *out, - float *in, + float* out, + float* in, const float neg_threshold, const float pos_threshold) { // get position of dequantized value to fill - float *outval = out + i; + float* outval = out + i; // gets byte which holds quantized value for this position - char *ch_ptr = reinterpret_cast(in + (i >> 4)); + char* ch_ptr = reinterpret_cast(in + (i >> 4)); ch_ptr += ((i & 15) >> 2); // masks used to quantize data const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03}; const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02}; // col denotes which two bits of a byte are set for this value // col=0 implies first two bits, col=3 implies last two bits,... 
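// --- Illustrative sketch (not part of this patch): the 2-bit value mapping ---
// Each output byte packs 4 gradients; slot col = i & 3 owns the two bits selected by
// posbits[col] / negbits[col]. The kernel launch passes (-threshold, +threshold), so a value
// reaching +threshold decodes back to +threshold, one reaching -threshold decodes to
// -threshold, and everything else decodes to 0 (its error stays in the residual). The
// encode-side comparison is assumed; the decode side mirrors the surrounding dequantize_2bit.
#include <cstdint>

inline uint8_t SketchEncode2Bit(float grad_plus_residual, float threshold, int col) {
  static const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};  // slot pattern 11
  static const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};  // slot pattern 10
  if (grad_plus_residual >= threshold)
    return posbits[col];
  if (grad_plus_residual <= -threshold)
    return negbits[col];
  return 0;  // slot pattern 00
}

inline float SketchDecode2Bit(uint8_t compr_byte, float threshold, int col) {
  static const uint8_t posbits[] = {0xc0, 0x30, 0x0c, 0x03};
  static const uint8_t negbits[] = {0x80, 0x20, 0x08, 0x02};
  const uint8_t masked = compr_byte & posbits[col];
  if (masked == posbits[col])
    return +threshold;
  if (masked == negbits[col])
    return -threshold;
  return 0.0f;
}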
- const int col = i & 3; - const uint8_t mask = posbits[col]; + const int col = i & 3; + const uint8_t mask = posbits[col]; const uint8_t negmask = negbits[col]; - const uint8_t masked = *ch_ptr & mask; + const uint8_t masked = *ch_ptr & mask; if (masked == mask) { *outval = pos_threshold; } else if (masked == negmask) { @@ -215,38 +219,39 @@ struct dequantize_2bit { } }; -template -void Dequantize2BitKernelLaunch(mshadow::Stream *s, const std::vector &inputs, +template +void Dequantize2BitKernelLaunch(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { - mxnet::op::mxnet_op::Kernel - ::Launch(s, - inputs[1].Size(), // original size - inputs[1].dptr(), // out array - inputs[0].dptr(), // compressed array - -1 *threshold, // negative threshold - threshold); // positive threshold + mxnet::op::mxnet_op::Kernel::Launch( + s, + inputs[1].Size(), // original size + inputs[1].dptr(), // out array + inputs[0].dptr(), // compressed array + -1 * threshold, // negative threshold + threshold); // positive threshold } -inline void Quantize1BitImpl(mshadow::Stream *s, - const std::vector &inputs, +inline void Quantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Quantize1BitKernelLaunch(s, inputs, threshold); } -inline void Dequantize1BitImpl(mshadow::Stream *s, - const std::vector &inputs, +inline void Dequantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Dequantize1BitKernelLaunch(s, inputs, threshold); } -inline void Quantize2BitImpl(mshadow::Stream *s, - const std::vector &inputs, +inline void Quantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Quantize2BitKernelLaunch(s, inputs, threshold); } -inline void Dequantize2BitImpl(mshadow::Stream *s, - const std::vector &inputs, +inline void Dequantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Dequantize2BitKernelLaunch(s, inputs, threshold); } diff --git a/src/kvstore/gradient_compression.cc b/src/kvstore/gradient_compression.cc index d59035a77470..bfec4fe16ee1 100644 --- a/src/kvstore/gradient_compression.cc +++ b/src/kvstore/gradient_compression.cc @@ -37,8 +37,8 @@ GradientCompression::GradientCompression() { type_ = CompressionType::kNone; } -void GradientCompression::SetParams(const std::vector > - & kwargs) { +void GradientCompression::SetParams( + const std::vector >& kwargs) { GradientCompressionParam params; params.InitAllowUnknown(kwargs); if (params.type == "1bit") { @@ -60,12 +60,12 @@ std::string GradientCompression::get_type_str() { } void GradientCompression::SetOneBitCompression(const float threshold) { - type_ = CompressionType::kOneBit; + type_ = CompressionType::kOneBit; threshold_ = threshold; } void GradientCompression::SetTwoBitCompression(const float threshold) { - type_ = CompressionType::kTwoBit; + type_ = CompressionType::kTwoBit; threshold_ = threshold; } @@ -78,7 +78,7 @@ std::string GradientCompression::EncodeParams() { return rval; } -void GradientCompression::DecodeParams(const std::string &s) { +void GradientCompression::DecodeParams(const std::string& s) { std::vector elems; mxnet::kvstore::split(s, ',', std::back_inserter(elems)); type_ = static_cast(stoi(elems[0])); @@ -102,32 +102,44 @@ int GradientCompression::GetCompressionFactor() { int64_t GradientCompression::GetCompressedSize(const int64_t original_size) { const int bits = GetCompressionFactor(); - return ((original_size % bits == 0) ? 
- original_size / bits : - original_size / bits + 1); + return ((original_size % bits == 0) ? original_size / bits : original_size / bits + 1); } -void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, - mxnet::NDArray *residual, const int priority) { +void GradientCompression::Quantize(const mxnet::NDArray& from, + mxnet::NDArray* to, + mxnet::NDArray* residual, + const int priority) { CHECK(shape_is_known(from.shape())) << "source operand has undefined shape"; CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape"; CHECK(shape_is_known(residual->shape())) << "residual operand has undefined shape"; - const int a = from.ctx().dev_mask(); - const int b = to->ctx().dev_mask(); + const int a = from.ctx().dev_mask(); + const int b = to->ctx().dev_mask(); const float threshold = threshold_; if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { if (type_ == CompressionType::kOneBit) { - mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize1BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, "QuantizeCPU"); + mxnet::Engine::Get()->PushSync( + [from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize1BitImpl(ctx.get_stream(), inputs, threshold); + }, + from.ctx(), + {from.var()}, + {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, + priority, + "QuantizeCPU"); } else if (type_ == CompressionType::kTwoBit) { - mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize2BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, "QuantizeCPU"); + mxnet::Engine::Get()->PushSync( + [from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize2BitImpl(ctx.get_stream(), inputs, threshold); + }, + from.ctx(), + {from.var()}, + {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, + priority, + "QuantizeCPU"); } else { LOG(FATAL) << "Unsupported quantization of type " << get_type_str(); } @@ -135,26 +147,38 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { #if MXNET_USE_CUDA if (type_ == CompressionType::kOneBit) { - mxnet::Engine::Get()->PushSync([from, to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize1BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, "QuantizeGPU"); + mxnet::Engine::Get()->PushSync( + [from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize1BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, + from.ctx(), + {from.var()}, + {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, + priority, + "QuantizeGPU"); } else if (type_ == CompressionType::kTwoBit) { - mxnet::Engine::Get()->PushSync([from, 
to, residual, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), residual->data(), to->data()}; - Quantize2BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var(), residual->var()}, - mxnet::FnProperty::kNormal, priority, "QuantizeGPU"); + mxnet::Engine::Get()->PushSync( + [from, to, residual, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), residual->data(), to->data()}; + Quantize2BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, + from.ctx(), + {from.var()}, + {to->var(), residual->var()}, + mxnet::FnProperty::kNormal, + priority, + "QuantizeGPU"); } else { LOG(FATAL) << "Unsupported quantization of type " << get_type_str(); } #else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } else { LOG(FATAL) << "Unknown device mask, from device mask " << a << " to device mask " << b; @@ -162,26 +186,39 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t } } -void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, +void GradientCompression::Dequantize(const mxnet::NDArray& from, + mxnet::NDArray* to, const int priority) { CHECK(shape_is_known(from.shape())) << "source operand has undefined shape"; CHECK(shape_is_known(to->shape())) << "destination operand has undefined shape"; - const int a = from.ctx().dev_mask(); - const int b = to->ctx().dev_mask(); + const int a = from.ctx().dev_mask(); + const int b = to->ctx().dev_mask(); const float threshold = threshold_; if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) { if (type_ == CompressionType::kOneBit) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize1BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, "DequantizeCPU"); + mxnet::Engine::Get()->PushSync( + [from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize1BitImpl(ctx.get_stream(), inputs, threshold); + }, + from.ctx(), + {from.var()}, + {to->var()}, + mxnet::FnProperty::kNormal, + priority, + "DequantizeCPU"); } else if (type_ == CompressionType::kTwoBit) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, "DequantizeCPU"); + mxnet::Engine::Get()->PushSync( + [from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); + }, + from.ctx(), + {from.var()}, + {to->var()}, + mxnet::FnProperty::kNormal, + priority, + "DequantizeCPU"); } else { LOG(FATAL) << "Unsupported dequantization of type " << get_type_str(); } @@ -189,26 +226,38 @@ void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) { #if MXNET_USE_CUDA if (type_ == CompressionType::kOneBit) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize1BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel 
to complete - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, "DequantizeGPU"); + mxnet::Engine::Get()->PushSync( + [from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize1BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + }, + from.ctx(), + {from.var()}, + {to->var()}, + mxnet::FnProperty::kNormal, + priority, + "DequantizeGPU"); } else if (type_ == CompressionType::kTwoBit) { - mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) { - std::vector inputs = {from.data(), to->data()}; - Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); - // Wait GPU kernel to completes - ctx.get_stream()->Wait(); - }, from.ctx(), {from.var()}, {to->var()}, - mxnet::FnProperty::kNormal, priority, "DequantizeGPU"); + mxnet::Engine::Get()->PushSync( + [from, to, threshold](mxnet::RunContext ctx) { + std::vector inputs = {from.data(), to->data()}; + Dequantize2BitImpl(ctx.get_stream(), inputs, threshold); + // Wait GPU kernel to completes + ctx.get_stream()->Wait(); + }, + from.ctx(), + {from.var()}, + {to->var()}, + mxnet::FnProperty::kNormal, + priority, + "DequantizeGPU"); } else { LOG(FATAL) << "Unsupported dequantization of type " << get_type_str(); } #else - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif } else { LOG(FATAL) << "Unknown device mask, from device mask " << a << " to device mask " << b; @@ -217,4 +266,3 @@ void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray } } // namespace kvstore } // namespace mxnet - diff --git a/src/kvstore/gradient_compression.cu b/src/kvstore/gradient_compression.cu index c5bacc227306..389623570665 100644 --- a/src/kvstore/gradient_compression.cu +++ b/src/kvstore/gradient_compression.cu @@ -27,22 +27,26 @@ namespace mxnet { namespace kvstore { -void Quantize1BitImpl(mshadow::Stream* s, const std::vector& inputs, +void Quantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Quantize1BitKernelLaunch(s, inputs, threshold); } -void Dequantize1BitImpl(mshadow::Stream* s, const std::vector& inputs, +void Dequantize1BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Dequantize1BitKernelLaunch(s, inputs, threshold); } -void Quantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, +void Quantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Quantize2BitKernelLaunch(s, inputs, threshold); } -void Dequantize2BitImpl(mshadow::Stream* s, const std::vector& inputs, +void Dequantize2BitImpl(mshadow::Stream* s, + const std::vector& inputs, const float threshold) { Dequantize2BitKernelLaunch(s, inputs, threshold); } diff --git a/src/kvstore/gradient_compression.h b/src/kvstore/gradient_compression.h index 5496ada31bba..efd449f7eda1 100644 --- a/src/kvstore/gradient_compression.h +++ b/src/kvstore/gradient_compression.h @@ -34,18 +34,16 @@ namespace mxnet { namespace kvstore { -enum class CompressionType { - kNone, kOneBit, kTwoBit -}; +enum class CompressionType { kNone, kOneBit, kTwoBit }; struct GradientCompressionParam : public dmlc::Parameter { std::string type; float threshold; DMLC_DECLARE_PARAMETER(GradientCompressionParam) { - DMLC_DECLARE_FIELD(type) - .describe("Type of gradient compression to use, like `2bit` for example"); - DMLC_DECLARE_FIELD(threshold).set_default(0.5) - 
.describe("Threshold to use for 2bit gradient compression"); + DMLC_DECLARE_FIELD(type).describe( + "Type of gradient compression to use, like `2bit` for example"); + DMLC_DECLARE_FIELD(threshold).set_default(0.5).describe( + "Threshold to use for 2bit gradient compression"); } }; @@ -92,7 +90,7 @@ class GradientCompression { /*! * \brief decodes parameters of gc from a string and assigns them to member variables */ - void DecodeParams(const std::string &s); + void DecodeParams(const std::string& s); /*! * \brief returns compression factor, which is the factor by which size of gradient @@ -106,25 +104,27 @@ class GradientCompression { int64_t GetCompressedSize(const int64_t original_size); /*! - * \brief Issues quantize operation to be scheduled by the engine - * Compresses `from` into `to` and accumulates the quantization error - * into 'residual', using the quantization of type `type_` - * \param from the ndarray containing original data to be quantized - * \param to the target ndarray which contains quantized data - * \param residual the ndarray which accumulates quantization error - * \param priority Priority of the action. - */ - void Quantize(const mxnet::NDArray &from, mxnet::NDArray *to, - mxnet::NDArray *residual, const int priority); + * \brief Issues quantize operation to be scheduled by the engine + * Compresses `from` into `to` and accumulates the quantization error + * into 'residual', using the quantization of type `type_` + * \param from the ndarray containing original data to be quantized + * \param to the target ndarray which contains quantized data + * \param residual the ndarray which accumulates quantization error + * \param priority Priority of the action. + */ + void Quantize(const mxnet::NDArray& from, + mxnet::NDArray* to, + mxnet::NDArray* residual, + const int priority); /*! - * \brief Issues dequantize operation to be scheduled by the engine - * Decompresses `from` into `to` using current parameters of `type` and `threshold` - * \param from the ndarray containing quantized data - * \param to the target ndarray which contains final dequantized data - * \param priority Priority of the action. - */ - void Dequantize(const mxnet::NDArray &from, mxnet::NDArray *to, const int priority); + * \brief Issues dequantize operation to be scheduled by the engine + * Decompresses `from` into `to` using current parameters of `type` and `threshold` + * \param from the ndarray containing quantized data + * \param to the target ndarray which contains final dequantized data + * \param priority Priority of the action. + */ + void Dequantize(const mxnet::NDArray& from, mxnet::NDArray* to, const int priority); private: /*! 
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index 87daad60ccc2..fb8d31f7ae26 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -39,12 +39,12 @@ std::atomic mxnet::kvstore::KVStoreDist::customer_id_{0}; namespace mxnet { -KVStore* KVStore::Create(const char *type_name) { +KVStore* KVStore::Create(const char* type_name) { std::string tname = type_name; std::transform(tname.begin(), tname.end(), tname.begin(), ::tolower); - KVStore* kv = nullptr; + KVStore* kv = nullptr; bool use_device_comm = false; - auto has = [tname](const std::string& pattern) { + auto has = [tname](const std::string& pattern) { return tname.find(pattern) != std::string::npos; }; if (has("device")) { @@ -77,7 +77,7 @@ KVStore* KVStore::Create(const char *type_name) { return nullptr; #endif } else { - kv = new kvstore::KVStoreLocal(use_device_comm); + kv = new kvstore::KVStoreLocal(use_device_comm); } } kv->type_ = tname; diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 28bf19be561d..876888611b9a 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -47,16 +47,15 @@ class KVStoreDist : public KVStoreLocal { : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) { if (IsWorkerNode()) { int new_customer_id = GetNewCustomerId(); - ps_worker_ = new ps::KVWorker(0, new_customer_id); + ps_worker_ = new ps::KVWorker(0, new_customer_id); ps::StartAsync(new_customer_id, "mxnet\0"); if (!ps::Postoffice::Get()->is_recovery()) { - ps::Postoffice::Get()->Barrier( - new_customer_id, - ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); + ps::Postoffice::Get()->Barrier(new_customer_id, + ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); } } bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); - log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); + log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } virtual ~KVStoreDist() { @@ -84,8 +83,8 @@ class KVStoreDist : public KVStoreLocal { } } - void SetGradientCompression(const std::vector > - & kwargs) override { + void SetGradientCompression( + const std::vector>& kwargs) override { KVStoreLocal::SetGradientCompression(kwargs); if (get_rank() == 0) { SendCommandToServers(static_cast(CommandType::kSetGradientCompression), @@ -101,28 +100,31 @@ class KVStoreDist : public KVStoreLocal { } } - void Barrier() override { ps::Postoffice::Get()->Barrier(ps_worker_->get_customer()->customer_id(), ps::kWorkerGroup); } - void SendCommandToServers(int cmd_id, - const std::string& cmd_body) override { + void SendCommandToServers(int cmd_id, const std::string& cmd_body) override { CHECK_NOTNULL(ps_worker_); ps_worker_->Wait(ps_worker_->Request(cmd_id, cmd_body, ps::kServerGroup)); } - int get_group_size() const override { return ps::NumWorkers(); } + int get_group_size() const override { + return ps::NumWorkers(); + } - int get_rank() const override { return ps::MyRank(); } + int get_rank() const override { + return ps::MyRank(); + } int get_num_dead_node(int node_id, int timeout) const override { - int number = 0; - auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout); + int number = 0; + auto dead_nodes = ps::Postoffice::Get()->GetDeadNodes(timeout); const auto& watch_nodes = ps::Postoffice::Get()->GetNodeIDs(node_id); std::unordered_set watch_set(watch_nodes.begin(), watch_nodes.end()); for (int r : dead_nodes) { - if (watch_set.find(r) != watch_set.end()) number++; + if (watch_set.find(r) != watch_set.end()) + 
number++; } return number; } @@ -136,10 +138,10 @@ class KVStoreDist : public KVStoreLocal { ps::StartAsync(0, "mxnet_server\0"); if (!ps::Postoffice::Get()->is_recovery()) { - ps::Postoffice::Get()->Barrier(0, - ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); + ps::Postoffice::Get()->Barrier(0, ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler); } - if (server_) server_->Run(); + if (server_) + server_->Run(); ps::Finalize(0, true); delete server_; server_ = nullptr; @@ -161,7 +163,7 @@ class KVStoreDist : public KVStoreLocal { */ struct PSKV { ps::SArray keys; // n keys - ps::SArray lens; // the length of the i-th value + ps::SArray lens; // the length of the i-th value int size; }; @@ -190,8 +192,7 @@ class KVStoreDist : public KVStoreLocal { return customer_id_++; } - void InitImpl(const std::vector& keys, - const std::vector& values) override { + void InitImpl(const std::vector& keys, const std::vector& values) override { CheckUnique(keys); for (size_t i = 0; i < keys.size(); ++i) { InitKV(keys[i], values[i]); @@ -227,13 +228,11 @@ class KVStoreDist : public KVStoreLocal { GroupKVPairsPush(vkeys, values, &uniq_vkeys, &grouped_vals, false); GroupKVPairsPull(okeys, outputs, &uniq_okeys, &grouped_outs, true); - CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size()) - << "List of push and pull keys are different"; + CHECK_EQ(uniq_vkeys.size(), uniq_okeys.size()) << "List of push and pull keys are different"; for (size_t i = 0; i < uniq_vkeys.size(); ++i) { - CHECK_EQ(uniq_vkeys[i], uniq_okeys[i]) - << "Mismatch in push and pull key"; - int key = uniq_vkeys[i]; + CHECK_EQ(uniq_vkeys[i], uniq_okeys[i]) << "Mismatch in push and pull key"; + int key = uniq_vkeys[i]; const auto& vals = grouped_vals[i]; const auto& outs = grouped_outs[i]; @@ -241,16 +240,14 @@ class KVStoreDist : public KVStoreLocal { const auto push_stype = merged.storage_type(); const auto pull_stype = outs[0]->storage_type(); - CHECK_EQ(push_stype, kDefaultStorage) - << "Expected push_stype of value to be kDefaultStorage"; - CHECK_EQ(pull_stype, kDefaultStorage) - << "Expected pull_stype of value to be kDefaultStorage"; + CHECK_EQ(push_stype, kDefaultStorage) << "Expected push_stype of value to be kDefaultStorage"; + CHECK_EQ(pull_stype, kDefaultStorage) << "Expected pull_stype of value to be kDefaultStorage"; const int push_dtype = merged.dtype(); const int pull_dtype = outs[0]->dtype(); CHECK_EQ(push_dtype, pull_dtype) << "Output buffer dtype is different"; - auto &comm_buf = comm_buf_[key]; + auto& comm_buf = comm_buf_[key]; if (merged.ctx().dev_mask() == cpu::kDevMask) { comm_buf = merged; // avoid memory copy } else { @@ -261,7 +258,7 @@ class KVStoreDist : public KVStoreLocal { } CHECK(gradient_compression_->get_type() == CompressionType::kNone) - << "Compression not supported with PushPull"; + << "Compression not supported with PushPull"; PushPullDefault(key, comm_buf, priority); comm_->Broadcast(key, comm_buf, outs, priority); } @@ -275,24 +272,24 @@ class KVStoreDist : public KVStoreLocal { void PullImpl(const std::vector& keys, const std::vector& values, - int priority, bool ignore_sparse) override { + int priority, + bool ignore_sparse) override { CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False"; std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true); for (size_t i = 0; i < uniq_keys.size(); ++i) { int key = uniq_keys[i]; // use the same array for merging to guarantee that pull always happens 
// after the previous push on this key - auto& recv_buf = comm_buf_[key]; + auto& recv_buf = comm_buf_[key]; const auto storage_type = grouped_vals[i][0]->storage_type(); - CHECK_EQ(storage_type, kDefaultStorage) - << "Expected stype of value to be kDefaultStorage"; + CHECK_EQ(storage_type, kDefaultStorage) << "Expected stype of value to be kDefaultStorage"; if (recv_buf.is_none()) { // it may happen for the first time a no-rank-0 worker pull the weight. - recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, - true, grouped_vals[i][0]->dtype()); + recv_buf = + NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, true, grouped_vals[i][0]->dtype()); } PullDefault(key, recv_buf, priority); @@ -311,20 +308,23 @@ class KVStoreDist : public KVStoreLocal { int key = uniq_keys[i]; // use the same array for merging to guarantee that pull always happens // after the previous push on this key - auto& recv_buf = comm_buf_[key]; + auto& recv_buf = comm_buf_[key]; auto& grouped_val_rowid = grouped_val_rowids[i]; const auto storage_type = grouped_val_rowid[0].first->storage_type(); CHECK_EQ(storage_type, kRowSparseStorage) - << "expected kRowSparseStorage, but got " << storage_type; + << "expected kRowSparseStorage, but got " << storage_type; if (recv_buf.is_none()) { // it may happen for the first time a no-rank-0 worker pull the weight. - recv_buf = NDArray(storage_type, grouped_val_rowid[0].first->shape(), - pinned_ctx_, true, grouped_val_rowid[0].first->dtype()); + recv_buf = NDArray(storage_type, + grouped_val_rowid[0].first->shape(), + pinned_ctx_, + true, + grouped_val_rowid[0].first->dtype()); } - auto &target_val_rowids = grouped_val_rowids[i]; - const size_t num_vals = target_val_rowids.size(); + auto& target_val_rowids = grouped_val_rowids[i]; + const size_t num_vals = target_val_rowids.size(); for (size_t i = 0; i < num_vals; i++) { - auto &row_id = target_val_rowids[i].second; + auto& row_id = target_val_rowids[i].second; target_val_rowids[i].second = Unique(row_id, pinned_ctx_, 0); } CHECK_EQ(num_vals, 1) << "RowSparsePull with multiple values is not supported yet"; @@ -334,8 +334,8 @@ class KVStoreDist : public KVStoreLocal { // Directly broadcast w/o rowids if num_vals == 1 auto get_val = [](const std::pair& p) { return p.first; }; std::vector grouped_val(grouped_val_rowid.size()); - std::transform(grouped_val_rowid.begin(), grouped_val_rowid.end(), - grouped_val.begin(), get_val); + std::transform( + grouped_val_rowid.begin(), grouped_val_rowid.end(), grouped_val.begin(), get_val); comm_->Broadcast(key, recv_buf, grouped_val, priority); } } @@ -346,17 +346,17 @@ class KVStoreDist : public KVStoreLocal { bool do_merge) { // first aggregate the values over keys std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { // merge over devices - int key = uniq_keys[i]; + int key = uniq_keys[i]; const auto& vals = grouped_vals[i]; - NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; + NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; const auto storage_type = merged.storage_type(); - auto &comm_buf = comm_buf_[key]; + auto& comm_buf = comm_buf_[key]; if (merged.ctx().dev_mask() == cpu::kDevMask) { // Start of a push doesn't guarantee that the previous pushes are completed. 
// This shouldn't affect training of networks though because training involves @@ -373,7 +373,7 @@ class KVStoreDist : public KVStoreLocal { } CopyFromTo(merged, &comm_buf); } - const int dtype = merged.dtype(); + const int dtype = merged.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); // push to servers if (storage_type == kDefaultStorage) { @@ -387,7 +387,7 @@ class KVStoreDist : public KVStoreLocal { // detect whether the push is initialization of a key or not. // is_active is false when push is initialization of key bool is_active = do_merge; - PSKV &pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active, num_bytes); + PSKV& pskv = EncodeCompressedKey(key, comm_buf.shape().Size(), is_active, num_bytes); // Returns push_pskv if active, else pull_pskv // we want inactive gc to send uncompressed gradients, // but sharded in the same way as later pushes would when gc becomes active @@ -399,7 +399,7 @@ class KVStoreDist : public KVStoreLocal { } } else if (storage_type == kRowSparseStorage) { CHECK(gradient_compression_->get_type() == CompressionType::kNone) - << "Gradient compression for row sparse storage type is not supported"; + << "Gradient compression for row sparse storage type is not supported"; PushRowSparse(key, comm_buf, priority); } else { LOG(FATAL) << "unknown storage type"; @@ -408,197 +408,196 @@ class KVStoreDist : public KVStoreLocal { } virtual void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) { - auto &small_buf = compr_buf_[key]; - auto &res_buf = residual_[key]; + auto& small_buf = compr_buf_[key]; + auto& res_buf = residual_[key]; const size_t original_size = comm_buf.shape().Size(); - const int dtype = comm_buf.dtype(); + const int dtype = comm_buf.dtype(); // Init the small buffer and residual_ buffer for quantize if (small_buf.is_none()) { small_buf = NDArray(mxnet::TShape{pskv.size}, comm_buf.ctx(), false, dtype); - res_buf = NDArray(mxnet::TShape{static_cast(original_size)}, - comm_buf.ctx(), false, dtype); + res_buf = + NDArray(mxnet::TShape{static_cast(original_size)}, comm_buf.ctx(), false, dtype); res_buf = 0; } gradient_compression_->Quantize(comm_buf, &small_buf, &res_buf, priority); - auto push_to_servers = - [this, key, dtype, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) { - size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); - char* data = static_cast (small_buf.data().dptr_); - // do push. false means no delete - ps::SArray vals(data, size, false); - int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype); - CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); }); - }; + auto push_to_servers = [this, key, dtype, pskv, small_buf](RunContext rctx, + Engine::CallbackOnComplete cb) { + size_t size = small_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); + char* data = static_cast(small_buf.data().dptr_); + // do push. 
false means no delete + ps::SArray vals(data, size, false); + int cmd = GetCommandType(RequestType::kCompressedPushPull, dtype); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); }); + }; // acquire locks on both comm_buf and small_buf so that // pull (which uses comm_buf) for the same key waits till push finishes - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {small_buf.var(), comm_buf.var()}, - {}, - FnProperty::kNormal, - priority, - "KVStoreDistCompressedPush"); + Engine::Get()->PushAsync(push_to_servers, + pinned_ctx_, + {small_buf.var(), comm_buf.var()}, + {}, + FnProperty::kNormal, + priority, + "KVStoreDistCompressedPush"); } - virtual void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) { - auto push_to_servers = - [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { - const int dtype = send_buf.dtype(); - // convert to ps keys - const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); - char* data = static_cast(send_buf.data().dptr_); - // do push. false means no delete - ps::SArray vals(data, size, false); - int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); - CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, - cmd, [cb]() { cb(); }); - }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - "KVStoreDistDefaultPush"); + virtual void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) { + auto push_to_servers = [this, key, pskv, send_buf](RunContext rctx, + Engine::CallbackOnComplete cb) { + const int dtype = send_buf.dtype(); + // convert to ps keys + const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); + char* data = static_cast(send_buf.data().dptr_); + // do push. 
false means no delete + ps::SArray vals(data, size, false); + int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); }); + }; + Engine::Get()->PushAsync(push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + "KVStoreDistDefaultPush"); } // push row sparse gradient - virtual void PushRowSparse(int key, const NDArray &send_buf, int priority) { + virtual void PushRowSparse(int key, const NDArray& send_buf, int priority) { using namespace rowsparse; - auto push_to_servers = [this, key, send_buf] - (RunContext rctx, Engine::CallbackOnComplete cb) { - char* data = static_cast(send_buf.data().dptr_); + auto push_to_servers = [this, key, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + char* data = static_cast(send_buf.data().dptr_); const int64_t num_rows = send_buf.aux_shape(kIdx)[0]; - const auto offsets = send_buf.aux_data(kIdx).dptr(); - const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim()); - const int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype()); - const int64_t size = num_rows * unit_len; - // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, - unit_len, send_buf.shape()[0], num_bytes); + const auto offsets = send_buf.aux_data(kIdx).dptr(); + const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim()); + const int num_bytes = mshadow::mshadow_sizeof(send_buf.dtype()); + const int64_t size = num_rows * unit_len; + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey( + key, size, num_rows, offsets, unit_len, send_buf.shape()[0], num_bytes); if (this->log_verbose_) { - LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens << " keys: " - << pskv.keys << " size: " << size; + LOG(INFO) << "worker " << get_rank() << " push lens: " << pskv.lens + << " keys: " << pskv.keys << " size: " << size; } ps::SArray vals(data, size * num_bytes, false); const int cmd = GetCommandType(RequestType::kRowSparsePushPull, send_buf.dtype()); CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, cmd, [cb]() { cb(); }); }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - "KVStoreDistRowSparsePush"); + Engine::Get()->PushAsync(push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + "KVStoreDistRowSparsePush"); } - virtual void PullDefault(int key, const NDArray &recv_buf, int priority) { - auto pull_from_servers = [this, key, recv_buf]( - RunContext rctx, Engine::CallbackOnComplete cb) { + virtual void PullDefault(int key, const NDArray& recv_buf, int priority) { + auto pull_from_servers = [this, key, recv_buf](RunContext rctx, Engine::CallbackOnComplete cb) { // convert to ps keys - size_t size = recv_buf.shape().Size(); - const int dtype = recv_buf.dtype(); + size_t size = recv_buf.shape().Size(); + const int dtype = recv_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); - PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ? - EncodeDefaultKey(key, size, num_bytes) : - EncodeCompressedKey(key, size, false, num_bytes); - char* data = static_cast (recv_buf.data().dptr_); + PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) + ? 
EncodeDefaultKey(key, size, num_bytes) + : EncodeCompressedKey(key, size, false, num_bytes); + char* data = static_cast(recv_buf.data().dptr_); // false means not to delete data when SArray is deleted auto vals = new ps::SArray(data, size * num_bytes, false); // issue pull - RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ? - RequestType::kCompressedPushPull : RequestType::kDefaultPushPull; - const int cmd = GetCommandType(mode, dtype); - CHECK_NOTNULL(ps_worker_)->ZPull( - pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); + RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) + ? RequestType::kCompressedPushPull + : RequestType::kDefaultPushPull; + const int cmd = GetCommandType(mode, dtype); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, cmd, [vals, cb]() { + delete vals; + cb(); + }); }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - "KVStoreDistDefaultStoragePull"); + CHECK_NOTNULL(Engine::Get()) + ->PushAsync(pull_from_servers, + pinned_ctx_, + {}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + "KVStoreDistDefaultStoragePull"); } // pull row sparse weight into `recv_buf` based on indices given by `indices` - virtual void PullRowSparse_(const int key, const NDArray& recv_buf, - const NDArray& indices, int priority) { + virtual void PullRowSparse_(const int key, + const NDArray& recv_buf, + const NDArray& indices, + int priority) { using namespace rowsparse; - auto pull_from_servers = [this, key, recv_buf, indices] - (RunContext rctx, Engine::CallbackOnComplete cb) { + auto pull_from_servers = [this, key, recv_buf, indices](RunContext rctx, + Engine::CallbackOnComplete cb) { // allocate memory for the buffer CHECK_EQ(indices.dtype(), mshadow::kInt64); - const TBlob idx_data = indices.data(); + const TBlob idx_data = indices.data(); const size_t num_rows = idx_data.shape_.Size(); recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)}); - const int dtype = recv_buf.dtype(); - char* data = static_cast(recv_buf.data().dptr_); - const auto offsets = idx_data.dptr(); + const int dtype = recv_buf.dtype(); + char* data = static_cast(recv_buf.data().dptr_); + const auto offsets = idx_data.dptr(); const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim()); - const int64_t size = num_rows * unit_len; + const int64_t size = num_rows * unit_len; const int num_bytes = mshadow::mshadow_sizeof(dtype); // convert to ps keys in row sparse format - PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, - unit_len, recv_buf.shape()[0], - num_bytes); + PSKV& pskv = EncodeRowSparseKey( + key, size, num_rows, offsets, unit_len, recv_buf.shape()[0], num_bytes); if (this->log_verbose_) { - LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens << " keys: " - << pskv.keys << " size: " << size; + LOG(INFO) << "worker " << get_rank() << " pull lens: " << pskv.lens + << " keys: " << pskv.keys << " size: " << size; } - auto vals = new ps::SArray(data, size * num_bytes, false); + auto vals = new ps::SArray(data, size * num_bytes, false); const int cmd = GetCommandType(RequestType::kRowSparsePushPull, recv_buf.dtype()); // copy indices to recv_buf. this needs to be done before ZPull // because after pull is done, the callback function returns and locks are released. 
// at this point, later functions may access the indices variable while copy happens mshadow::Copy(recv_buf.aux_data(kIdx).FlatTo1D(), idx_data.FlatTo1D()); - CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, - cmd, - [vals, cb]() { delete vals; cb(); }); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, cmd, [vals, cb]() { + delete vals; + cb(); + }); }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {indices.var()}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - "KVStoreDistRowSparsePull"); + CHECK_NOTNULL(Engine::Get()) + ->PushAsync(pull_from_servers, + pinned_ctx_, + {indices.var()}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + "KVStoreDistRowSparsePull"); } - virtual void PushPullDefault(int key, const NDArray &comm_buf, int priority) { - auto pushpull = [this, key, comm_buf]( - RunContext rctx, Engine::CallbackOnComplete cb) { - size_t size = comm_buf.shape().Size(); - const int dtype = comm_buf.dtype(); + virtual void PushPullDefault(int key, const NDArray& comm_buf, int priority) { + auto pushpull = [this, key, comm_buf](RunContext rctx, Engine::CallbackOnComplete cb) { + size_t size = comm_buf.shape().Size(); + const int dtype = comm_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); - const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); char* data = static_cast(comm_buf.data().dptr_); - auto vals = new ps::SArray(data, size * num_bytes, false); + auto vals = new ps::SArray(data, size * num_bytes, false); - CHECK_NOTNULL(ps_worker_)->ZPushPull( - pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); }); + CHECK_NOTNULL(ps_worker_)->ZPushPull(pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb]() { + delete vals; + cb(); + }); }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pushpull, - pinned_ctx_, - {}, - {comm_buf.var()}, - FnProperty::kNormal, - priority, - "KVStoreDistDefaultStoragePushPull"); + CHECK_NOTNULL(Engine::Get()) + ->PushAsync(pushpull, + pinned_ctx_, + {}, + {comm_buf.var()}, + FnProperty::kNormal, + priority, + "KVStoreDistDefaultStoragePushPull"); } /** @@ -606,7 +605,7 @@ class KVStoreDist : public KVStoreLocal { */ void CheckUnique(const std::vector& keys) { auto keys_copy = keys; - auto last = std::unique(keys_copy.begin(), keys_copy.end()); + auto last = std::unique(keys_copy.begin(), keys_copy.end()); CHECK_EQ(static_cast(std::distance(keys_copy.begin(), last)), static_cast(keys.size())); } @@ -618,7 +617,8 @@ class KVStoreDist : public KVStoreLocal { * \param num_bytes size of each element in number of bytes * \return PSKV used for both push and pull */ - virtual inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems, + virtual inline PSKV& EncodeDefaultKey(const int key, + const size_t num_arr_elems, const int num_bytes) { mu_.lock(); PSKV& pskv = ps_kv_[key]; @@ -626,16 +626,16 @@ class KVStoreDist : public KVStoreLocal { size_t pskv_size = num_arr_elems * num_bytes; if (!pskv.keys.empty()) { CHECK_EQ(static_cast(pskv.size), pskv_size) - << "The value size cannot be changed " << pskv_size << ". Key is " << key; + << "The value size cannot be changed " << pskv_size << ". 
Key is " << key; } else { - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); const int num_servers = krs.size(); CHECK_GT(num_servers, 0); // a simple heuristic for load balance if (num_arr_elems < bigarray_bound_) { // send it to a single random picked server - int server = (key * 9973) % num_servers; + int server = (key * 9973) % num_servers; ps::Key ps_key = krs[server].begin() + key; CHECK_LT(ps_key, krs[server].end()); pskv.keys.push_back(ps_key); @@ -647,8 +647,9 @@ class KVStoreDist : public KVStoreLocal { pskv.size = 0; for (int i = 0; i < num_servers; ++i) { size_t part_size = - static_cast(round(static_cast(num_arr_elems)/num_servers*(i+1))) - - static_cast(round(static_cast(num_arr_elems)/num_servers*i)); + static_cast( + round(static_cast(num_arr_elems) / num_servers * (i + 1))) - + static_cast(round(static_cast(num_arr_elems) / num_servers * i)); ps::Key ps_key = krs[i].begin() + key; CHECK_LT(ps_key, krs[i].end()); pskv.keys.push_back(ps_key); @@ -672,9 +673,11 @@ class KVStoreDist : public KVStoreLocal { * \param num_bytes size of each element in number of bytes * \return PSKV used for both push and pull */ - virtual inline PSKV& EncodeCompressedKey(const int key, const size_t original_num_elem, - const bool is_push, const int num_bytes) { - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + virtual inline PSKV& EncodeCompressedKey(const int key, + const size_t original_num_elem, + const bool is_push, + const int num_bytes) { + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); const int num_servers = krs.size(); CHECK_GT(num_servers, 0); @@ -686,8 +689,8 @@ class KVStoreDist : public KVStoreLocal { if (!pskv.keys.empty()) { const size_t num_elem = (is_push) ? compr_num_elem : original_num_elem; - CHECK_EQ(static_cast(pskv.size), num_elem * num_bytes) - << "The value size can't be changed. For key " << key; + CHECK_EQ(static_cast(pskv.size), num_elem * num_bytes) + << "The value size can't be changed. 
For key " << key; } else { // populate both pull and push pskvs // push pskv has sizes corresponding to compressed data @@ -701,7 +704,7 @@ class KVStoreDist : public KVStoreLocal { // a simple heuristic for load balancing // send it to a single random picked server const int server = (key * 9973) % num_servers; - ps::Key ps_key = krs[server].begin() + key; + ps::Key ps_key = krs[server].begin() + key; CHECK_LT(ps_key, krs[server].end()); // meta info push_pskv.keys.push_back(krs[server].begin() + original_num_elem); @@ -709,7 +712,7 @@ class KVStoreDist : public KVStoreLocal { // data push_pskv.keys.push_back(ps_key); pull_pskv.keys.push_back(ps_key); - const int compr_size = compr_num_elem * num_bytes; + const int compr_size = compr_num_elem * num_bytes; const int original_size = original_num_elem * num_bytes; push_pskv.lens.push_back(compr_size); pull_pskv.lens.push_back(original_size); @@ -722,13 +725,14 @@ class KVStoreDist : public KVStoreLocal { for (int i = 0; i < num_servers; ++i) { size_t part_compr, part_orig; - if (i == num_servers-1) { + if (i == num_servers - 1) { part_compr = compr_num_elem - push_pskv.size; - part_orig = original_num_elem - pull_pskv.size; + part_orig = original_num_elem - pull_pskv.size; } else { part_compr = - static_cast (round(static_cast(compr_num_elem)/num_servers*(i+1))) - - static_cast (round(static_cast(compr_num_elem)/num_servers*(i))); + static_cast( + round(static_cast(compr_num_elem) / num_servers * (i + 1))) - + static_cast(round(static_cast(compr_num_elem) / num_servers * (i))); part_orig = part_compr * gradient_compression_->GetCompressionFactor(); } @@ -755,15 +759,18 @@ class KVStoreDist : public KVStoreLocal { push_pskv.size *= num_bytes; pull_pskv.size *= num_bytes; CHECK_EQ(push_pskv.lens.size(), num_servers * 2); - } } + } return pskv; } // Note: this encoding method for row sparse keys doesn't allow cross-layer batching - virtual inline PSKV& EncodeRowSparseKey(const int key, const int64_t num_elem, - const int64_t num_rows, const int64_t *offsets, - const size_t unit_len, const int64_t total_num_rows, + virtual inline PSKV& EncodeRowSparseKey(const int key, + const int64_t num_elem, + const int64_t num_rows, + const int64_t* offsets, + const size_t unit_len, + const int64_t total_num_rows, const int num_bytes) { using namespace common; mu_.lock(); @@ -772,12 +779,12 @@ class KVStoreDist : public KVStoreLocal { pskv.keys.clear(); pskv.lens.clear(); // TODO(haibin) cache this information - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); const int num_servers = krs.size(); CHECK_GT(num_servers, 0); if (total_num_rows * unit_len >= bigarray_bound_) { - pskv.size = 0; + pskv.size = 0; int64_t start_row = 0; // parition it to all servers for (int i = 0; i < num_servers; ++i) { @@ -787,8 +794,8 @@ class KVStoreDist : public KVStoreLocal { if (offsets && num_elem > 0) { // calculate partition ranges int64_t part_num_rows = - llround(static_cast(total_num_rows) / num_servers * (i + 1)) - - llround(static_cast(total_num_rows) / num_servers * i); + llround(static_cast(total_num_rows) / num_servers * (i + 1)) - + llround(static_cast(total_num_rows) / num_servers * i); auto end_row = start_row + part_num_rows; // search for offsets in [start_row, end_row) auto lb = std::lower_bound(offsets, offsets + num_rows, start_row); @@ -807,7 +814,7 @@ class KVStoreDist : public KVStoreLocal { CHECK_EQ(static_cast(pskv.size), num_elem * num_bytes); } else { // send it to a single random 
picked server - const int server = (key * 9973) % num_servers; + const int server = (key * 9973) % num_servers; ps::Key master_key = krs[server].begin() + key; pskv.keys.push_back(master_key); pskv.lens.push_back(0); @@ -853,5 +860,4 @@ class KVStoreDist : public KVStoreLocal { } // namespace kvstore } // namespace mxnet - #endif // MXNET_KVSTORE_KVSTORE_DIST_H_ diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 1dc222c0d7da..caa094dadd6a 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -44,13 +44,15 @@ namespace kvstore { // maintain same order in frontend. enum class CommandType { - kController, kSetMultiPrecision, kStopServer, kSyncMode, - kSetGradientCompression, kSetProfilerParams + kController, + kSetMultiPrecision, + kStopServer, + kSyncMode, + kSetGradientCompression, + kSetProfilerParams }; -enum class RequestType { - kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull -}; +enum class RequestType { kDefaultPushPull, kRowSparsePushPull, kCompressedPushPull }; struct DataHandleType { RequestType requestType; @@ -77,7 +79,7 @@ static int GetCommandType(RequestType requestType, int d) { * \return DataHandleType */ static DataHandleType DepairDataHandleType(int cmd) { - int w = std::floor((std::sqrt(8 * cmd + 1) - 1)/2); + int w = std::floor((std::sqrt(8 * cmd + 1) - 1) / 2); int t = ((w * w) + w) / 2; int y = cmd - t; int x = w - y; @@ -85,7 +87,7 @@ static DataHandleType DepairDataHandleType(int cmd) { CHECK_GE(y, 0); DataHandleType type; type.requestType = static_cast(x); - type.dtype = y; + type.dtype = y; return type; } @@ -100,7 +102,7 @@ class Executor { void Start() { std::unique_lock lk(mu_); while (true) { - cond_.wait(lk, [this]{return !queue_.empty();}); + cond_.wait(lk, [this] { return !queue_.empty(); }); Block blk = std::move(queue_.front()); queue_.pop(); lk.unlock(); @@ -109,7 +111,8 @@ class Executor { blk.f(); blk.p->set_value(); } else { - blk.p->set_value(); break; + blk.p->set_value(); + break; } lk.lock(); } @@ -143,7 +146,7 @@ class Executor { private: struct Block { - explicit Block(const Func& func) : f(func), p(std::make_shared>()) { } + explicit Block(const Func& func) : f(func), p(std::make_shared>()) {} Func f; std::shared_ptr> p; }; @@ -157,13 +160,12 @@ class KVStoreDistServer { KVStoreDistServer() { using namespace std::placeholders; ps_server_ = new ps::KVServer(0); - static_cast(ps_server_)->set_request_handle( - std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2)); - ps_server_->set_request_handle( - std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); - sync_mode_ = false; + static_cast(ps_server_) + ->set_request_handle(std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2)); + ps_server_->set_request_handle(std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); + sync_mode_ = false; gradient_compression_ = std::make_shared(); - log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); + log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } ~KVStoreDistServer() { @@ -176,7 +178,7 @@ class KVStoreDistServer { controller_ = controller; } - void set_updater(const KVStore::Updater& updater) { + void set_updater(const KVStore::Updater& updater) { CHECK(updater); updater_ = updater; } @@ -210,9 +212,8 @@ class KVStoreDistServer { break; case CommandType::kSetProfilerParams: // last char is the type of profiler command - ProcessServerProfilerCommands(static_cast - (recved.body.back() - '0'), - 
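DepairDataHandleType above is the inverse of the classic Cantor pairing, which, judging by GetCommandType(requestType, dtype), is how a request type and a dtype are packed into the single ps-lite command int. A round-trip sketch, assuming the forward direction is the standard (x+y)(x+y+1)/2 + y (local names, not MXNet code):

#include <cassert>
#include <cmath>

int Pair(int x, int y) {  // Cantor pairing of (requestType, dtype)
  return (x + y) * (x + y + 1) / 2 + y;
}

// Inverse, matching the formula in DepairDataHandleType above.
void Depair(int cmd, int* x, int* y) {
  int w = static_cast<int>(std::floor((std::sqrt(8.0 * cmd + 1) - 1) / 2));
  int t = (w * w + w) / 2;
  *y = cmd - t;
  *x = w - *y;
}

int main() {
  for (int request_type = 0; request_type < 3; ++request_type) {
    for (int dtype = 0; dtype < 12; ++dtype) {
      int x = 0, y = 0;
      Depair(Pair(request_type, dtype), &x, &y);
      assert(x == request_type && y == dtype);  // round trip is lossless
    }
  }
}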
recved.body); + ProcessServerProfilerCommands( + static_cast(recved.body.back() - '0'), recved.body); break; case CommandType::kSetMultiPrecision: // uses value 1 for message id from frontend @@ -225,9 +226,9 @@ class KVStoreDistServer { // this uses value 0 for message id from frontend // let the main thread to execute ctrl, which is necessary for python exec_.Exec([this, recved]() { - CHECK(controller_); - controller_(recved.head, recved.body); - }); + CHECK(controller_); + controller_(recved.head, recved.body); + }); break; } app->Response(recved); @@ -239,36 +240,39 @@ class KVStoreDistServer { * some keys are initialized before optimizer is set. */ void CreateMultiPrecisionCopies() { - for (auto const &stored_entry : store_) { - const int key = stored_entry.first; - const NDArray &stored = stored_entry.second; + for (auto const& stored_entry : store_) { + const int key = stored_entry.first; + const NDArray& stored = stored_entry.second; if (stored.dtype() != mshadow::kFloat32) { - auto &stored_realt = store_realt_[key]; + auto& stored_realt = store_realt_[key]; if (stored.storage_type() == kRowSparseStorage) { - stored_realt = NDArray(kRowSparseStorage, stored.shape(), stored.ctx(), - true, mshadow::kFloat32); + stored_realt = + NDArray(kRowSparseStorage, stored.shape(), stored.ctx(), true, mshadow::kFloat32); } else { stored_realt = NDArray(stored.shape(), stored.ctx(), false, mshadow::kFloat32); } - auto &update = update_buf_[key]; + auto& update = update_buf_[key]; if (!update.merged.is_none()) { if (update.merged.storage_type() == kRowSparseStorage) { - update.merged = NDArray(kRowSparseStorage, update.merged.shape(), update.merged.ctx(), - true, mshadow::kFloat32); - } else { - update.merged = NDArray(update.merged.shape(), update.merged.ctx(), false, + update.merged = NDArray(kRowSparseStorage, + update.merged.shape(), + update.merged.ctx(), + true, mshadow::kFloat32); + } else { + update.merged = + NDArray(update.merged.shape(), update.merged.ctx(), false, mshadow::kFloat32); } } CHECK(update.request.size() == 0) - << ps::MyRank() << "Multiprecision mode can not be set while pushes are underway." - << "Please set optimizer before pushing keys." << key << " " << update.request.size(); + << ps::MyRank() << "Multiprecision mode can not be set while pushes are underway." + << "Please set optimizer before pushing keys." 
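The exec_.Exec(...) calls above funnel controller and updater work onto the thread that runs Executor::Start, which matters because those callbacks may re-enter Python. A stripped-down, self-contained version of that queue-plus-promise pattern (MiniExecutor is illustrative; the real Exec/Stop bodies sit outside the hunk, so their exact shape here is an assumption):

#include <condition_variable>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

class MiniExecutor {
 public:
  using Func = std::function<void()>;

  // Run by the dispatch thread; processes blocks until an empty Func arrives.
  void Start() {
    std::unique_lock<std::mutex> lk(mu_);
    while (true) {
      cond_.wait(lk, [this] { return !queue_.empty(); });
      Block blk = std::move(queue_.front());
      queue_.pop();
      lk.unlock();
      if (blk.f) {
        blk.f();
        blk.p->set_value();
      } else {
        blk.p->set_value();  // empty function acts as the stop signal
        break;
      }
      lk.lock();
    }
  }

  // Called from handler threads; blocks until the dispatch thread ran f.
  void Exec(const Func& f) {
    Block blk(f);
    auto fut = blk.p->get_future();
    {
      std::lock_guard<std::mutex> lk(mu_);
      queue_.push(std::move(blk));
      cond_.notify_one();
    }
    fut.wait();
  }

  void Stop() { Exec(Func()); }

 private:
  struct Block {
    explicit Block(const Func& func) : f(func), p(std::make_shared<std::promise<void>>()) {}
    Func f;
    std::shared_ptr<std::promise<void>> p;
  };
  std::mutex mu_;
  std::condition_variable cond_;
  std::queue<Block> queue_;
};

int main() {
  MiniExecutor exec;
  std::thread dispatch([&exec] { exec.Start(); });
  exec.Exec([] { std::cout << "ran on the dispatch thread\n"; });
  exec.Stop();
  dispatch.join();
}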
<< key << " " << update.request.size(); CopyFromTo(stored, stored_realt); } } - for (auto const &stored_realt_entry : store_realt_) { + for (auto const& stored_realt_entry : store_realt_) { stored_realt_entry.second.WaitToRead(); } } @@ -298,12 +302,12 @@ class KVStoreDistServer { ckeys.reserve(elems.size()); cvals.reserve(elems.size()); - for (size_t i=0; i < elems.size(); i++) { + for (size_t i = 0; i < elems.size(); i++) { std::vector parts; mxnet::kvstore::split(elems[i], ':', std::back_inserter(parts)); CHECK_EQ(parts.size(), 2) << "Improper profiler config passed from worker"; CHECK(!parts[0].empty()) << "ProfilerConfig parameter is empty"; - CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter "<< parts[0]; + CHECK(!parts[1].empty()) << "ProfilerConfig value is empty for parameter " << parts[0]; if (parts[0] == "filename") { parts[1] = "rank" + std::to_string(ps::MyRank()) + "_" + parts[1]; } @@ -316,7 +320,7 @@ class KVStoreDistServer { cvals.push_back(cval); } MXSetProfilerConfig(elems.size(), &ckeys[0], &cvals[0]); - for (size_t i=0; i < ckeys.size(); i++) { + for (size_t i = 0; i < ckeys.size(); i++) { delete[] ckeys[i]; delete[] cvals[i]; } @@ -343,15 +347,17 @@ class KVStoreDistServer { return multi_precision_ && type.dtype != mshadow::kFloat32; } - inline void ApplyUpdates(const DataHandleType type, const int key, - const ps::KVPairs& req_data, UpdateBuf *update_buf, + inline void ApplyUpdates(const DataHandleType type, + const int key, + const ps::KVPairs& req_data, + UpdateBuf* update_buf, ps::KVServer* server) { - if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) { + if (!sync_mode_ || update_buf->request.size() == (size_t)ps::NumWorkers()) { // let the main thread to execute updater_, which is necessary for python auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : store_[key]; - auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array; + auto& update = sync_mode_ ? 
update_buf->merged : update_buf->temp_array; if (updater_) { - exec_.Exec([this, key, &update, &stored](){ + exec_.Exec([this, key, &update, &stored]() { CHECK(updater_); updater_(key, update, &stored); }); @@ -361,7 +367,7 @@ class KVStoreDistServer { CopyFromTo(update_buf->merged, &stored); } - if (log_verbose_) { + if (log_verbose_) { LOG(INFO) << "sent response to " << update_buf->request.size() << " workers"; } /** @@ -375,7 +381,8 @@ class KVStoreDistServer { } if (has_pull) { // if there is a pull request, perform WaitToRead() once before DefaultStorageResponse - if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); + if (has_multi_precision_copy(type)) + CopyFromTo(stored, store_[key]); stored.WaitToRead(); for (const auto& req : update_buf->request) { if (req.pull) { @@ -389,7 +396,8 @@ class KVStoreDistServer { server->Response(req); } update_buf->request.clear(); - if (has_multi_precision_copy(type)) CopyFromTo(stored, store_[key]); + if (has_multi_precision_copy(type)) + CopyFromTo(stored, store_[key]); stored.WaitToRead(); } } else { @@ -397,12 +405,14 @@ class KVStoreDistServer { } } - void DecodeRowIds(const ps::SArray &keys, int64_t *indices, - const int64_t master_key, const int64_t num_rows) { + void DecodeRowIds(const ps::SArray& keys, + int64_t* indices, + const int64_t master_key, + const int64_t num_rows) { indices[0] = 0; for (int64_t i = 1; i <= num_rows; i++) { - int key = DecodeKey(keys[i]); - auto row_id = key - master_key; + int key = DecodeKey(keys[i]); + auto row_id = key - master_key; indices[i - 1] = row_id; } } @@ -410,19 +420,28 @@ class KVStoreDistServer { void AccumulateRowSparseGrads(const DataHandleType type, const NDArray& recved, UpdateBuf* updateBuf) { - NDArray out(kRowSparseStorage, updateBuf->merged.shape(), Context(), true, + NDArray out(kRowSparseStorage, + updateBuf->merged.shape(), + Context(), + true, has_multi_precision_copy(type) ? mshadow::kFloat32 : type.dtype); - if (has_multi_precision_copy(type)) CopyFromTo(recved, updateBuf->temp_array); + if (has_multi_precision_copy(type)) + CopyFromTo(recved, updateBuf->temp_array); const NDArray& to_merge = has_multi_precision_copy(type) ? 
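The ApplyUpdates hunk above encodes the server's synchronization rule: in async mode every push is applied as it arrives, while in sync mode the merged update is applied (and the queued requests answered) only once every worker has pushed the key. A one-function stand-in for that predicate, with simplified types (the real UpdateBuf holds KVMeta requests and NDArrays):

#include <cstddef>
#include <vector>

struct MiniUpdateBuf {
  std::vector<int> request;  // one entry per worker request queued for this key
};

// Mirrors the check in ApplyUpdates: apply immediately when async, otherwise
// wait until all workers have contributed their push for this key.
bool ShouldApply(bool sync_mode, const MiniUpdateBuf& buf, int num_workers) {
  return !sync_mode || buf.request.size() == static_cast<size_t>(num_workers);
}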
updateBuf->temp_array : recved; // accumulate row_sparse gradients using namespace mshadow; Engine::Get()->PushAsync( - [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) { - op::ElemwiseBinaryOp::ComputeEx( - {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out}); - on_complete(); - }, to_merge.ctx(), {to_merge.var(), updateBuf->merged.var()}, {out.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + [to_merge, updateBuf, out](RunContext ctx, Engine::CallbackOnComplete on_complete) { + op::ElemwiseBinaryOp::ComputeEx( + {}, {}, {to_merge, updateBuf->merged}, {kWriteTo}, {out}); + on_complete(); + }, + to_merge.ctx(), + {to_merge.var(), updateBuf->merged.var()}, + {out.var()}, + FnProperty::kNormal, + 0, + PROFILER_MESSAGE_FUNCNAME); CopyFromTo(out, &(updateBuf->merged), 0); updateBuf->merged.WaitToRead(); } @@ -433,7 +452,8 @@ class KVStoreDistServer { const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { - if (log_verbose_) LOG(INFO) << "pull: " << master_key; + if (log_verbose_) + LOG(INFO) << "pull: " << master_key; ps::KVPairs response; if (num_rows == 0) { std::vector lens(req_data.keys.size(), 0); @@ -443,23 +463,24 @@ class KVStoreDistServer { return; } const NDArray& stored = store_[master_key]; - if (has_multi_precision_copy(type)) stored.WaitToRead(); + if (has_multi_precision_copy(type)) + stored.WaitToRead(); CHECK(!stored.is_none()) << "init " << master_key << " first"; - auto shape = stored.shape(); - auto unit_len = shape.ProdShape(1, shape.ndim()); + auto shape = stored.shape(); + auto unit_len = shape.ProdShape(1, shape.ndim()); const int num_bytes = mshadow::mshadow_sizeof(type.dtype); const int unit_size = unit_len * num_bytes; - const char* data = static_cast (stored.data().dptr_); - auto len = num_rows * unit_size; + const char* data = static_cast(stored.data().dptr_); + auto len = num_rows * unit_size; // concat values response.vals.resize(len); - #pragma omp parallel for +#pragma omp parallel for for (size_t i = 1; i <= num_rows; i++) { - int key = DecodeKey(req_data.keys[i]); + int key = DecodeKey(req_data.keys[i]); int64_t row_id = key - master_key; const auto src = data + row_id * unit_size; - auto begin = (i - 1) * unit_size; - auto end = i * unit_size; + auto begin = (i - 1) * unit_size; + auto end = i * unit_size; response.vals.segment(begin, end).CopyFrom(src, unit_size); } // setup response @@ -476,12 +497,12 @@ class KVStoreDistServer { const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { - auto& stored = has_multi_precision_copy(type) ? store_realt_[master_key] : store_[master_key]; - int dtype = type.dtype; + auto& stored = has_multi_precision_copy(type) ? store_realt_[master_key] : store_[master_key]; + int dtype = type.dtype; int num_bytes = mshadow::mshadow_sizeof(dtype); auto unit_len = req_data.lens[1] / num_bytes; CHECK_GT(unit_len, 0); - size_t ds[] = {num_rows, (size_t) unit_len}; + size_t ds[] = {num_rows, (size_t)unit_len}; mxnet::TShape dshape(ds, ds + 2); CHECK_EQ(req_data.vals.size(), num_rows * unit_len * num_bytes); TBlob recv_blob; @@ -489,28 +510,36 @@ class KVStoreDistServer { recv_blob = TBlob(reinterpret_cast(req_data.vals.data()), dshape, cpu::kDevMask); }) NDArray recved = NDArray(recv_blob, 0); - stored = NDArray(kRowSparseStorage, dshape, Context(), true, + stored = NDArray(kRowSparseStorage, + dshape, + Context(), + true, has_multi_precision_copy(type) ? 
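The row_sparse pull response above is assembled by gathering each requested row (unit_size bytes apiece) from the stored dense buffer into a packed reply, offset by the decoded row id. A plain-C++ stand-in for that gather (GatherRows is a hypothetical name; std::vector<char> replaces ps::SArray):

#include <cstdint>
#include <cstring>
#include <vector>

// Copy the requested rows (unit_size bytes each) out of the stored buffer
// into a packed response, in the order the row ids were requested.
std::vector<char> GatherRows(const std::vector<char>& stored,
                             const std::vector<int64_t>& row_ids,
                             size_t unit_size) {
  std::vector<char> response(row_ids.size() * unit_size);
#pragma omp parallel for
  for (int64_t i = 0; i < static_cast<int64_t>(row_ids.size()); ++i) {
    const char* src = stored.data() + static_cast<size_t>(row_ids[i]) * unit_size;
    std::memcpy(response.data() + static_cast<size_t>(i) * unit_size, src, unit_size);
  }
  return response;
}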
mshadow::kFloat32 : type.dtype); if (has_multi_precision_copy(type)) { store_[master_key] = NDArray(kRowSparseStorage, dshape, Context(), true, type.dtype); } Engine::Get()->PushAsync( - [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete on_complete) { - NDArray rsp = stored; - stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])}); - mshadow::Stream *s = ctx.get_stream(); - using namespace mxnet::op; - nnvm::dim_t nnr = rsp.shape()[0]; - MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { - IType* idx = rsp.aux_data(rowsparse::kIdx).dptr(); - mxnet_op::Kernel::Launch(s, nnr, idx); - }); - TBlob rsp_data = rsp.data(); - // copies or casts as appropriate - ndarray::Copy(recved.data(), &rsp_data, Context(), Context(), RunContext()); - on_complete(); - }, recved.ctx(), {recved.var()}, {stored.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + [this, recved, stored, type](RunContext ctx, Engine::CallbackOnComplete on_complete) { + NDArray rsp = stored; + stored.CheckAndAlloc({mshadow::Shape1(recved.shape()[0])}); + mshadow::Stream* s = ctx.get_stream(); + using namespace mxnet::op; + nnvm::dim_t nnr = rsp.shape()[0]; + MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + IType* idx = rsp.aux_data(rowsparse::kIdx).dptr(); + mxnet_op::Kernel::Launch(s, nnr, idx); + }); + TBlob rsp_data = rsp.data(); + // copies or casts as appropriate + ndarray::Copy(recved.data(), &rsp_data, Context(), Context(), RunContext()); + on_complete(); + }, + recved.ctx(), + {recved.var()}, + {stored.var()}, + FnProperty::kNormal, + 0, + PROFILER_MESSAGE_FUNCNAME); if (has_multi_precision_copy(type)) { CopyFromTo(stored, store_[master_key]); store_[master_key].WaitToRead(); @@ -519,31 +548,37 @@ class KVStoreDistServer { server->Response(req_meta); } - void DataHandleRowSparse(const DataHandleType type, const ps::KVMeta& req_meta, + void DataHandleRowSparse(const DataHandleType type, + const ps::KVMeta& req_meta, const ps::KVPairs& req_data, ps::KVServer* server) { int master_key = DecodeKey(req_data.keys[0]); - auto num_rows = req_data.keys.size() - 1; - auto& stored = store_[master_key]; + auto num_rows = req_data.keys.size() - 1; + auto& stored = store_[master_key]; if (req_meta.push) { CHECK_GT(req_data.lens.size(), 0) << "req_data.lens cannot be empty"; CHECK_EQ(req_data.lens[0], 0); if (stored.is_none()) { - if (log_verbose_) LOG(INFO) << "initial push: " << master_key; + if (log_verbose_) + LOG(INFO) << "initial push: " << master_key; // initialization CHECK_GT(num_rows, 0) << "init with empty data is not supported"; InitRowSparseStored(type, master_key, num_rows, req_meta, req_data, server); return; } else { - if (log_verbose_) LOG(INFO) << "push: " << master_key << " " << req_data.keys; + if (log_verbose_) + LOG(INFO) << "push: " << master_key << " " << req_data.keys; auto& updates = update_buf_[master_key]; if (sync_mode_ && updates.merged.is_none()) { - updates.merged = NDArray(kRowSparseStorage, stored.shape(), Context(), true, + updates.merged = NDArray(kRowSparseStorage, + stored.shape(), + Context(), + true, has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype); } if (has_multi_precision_copy(type) && updates.temp_array.is_none()) { - updates.temp_array = NDArray(kRowSparseStorage, stored.shape(), Context(), false, - mshadow::kFloat32); + updates.temp_array = + NDArray(kRowSparseStorage, stored.shape(), Context(), false, mshadow::kFloat32); } if (num_rows == 0) { @@ -551,8 +586,8 @@ class KVStoreDistServer { if (updates.request.empty()) { // reset to zeros int merged_dtype = has_multi_precision_copy(type) ? mshadow::kFloat32 : type.dtype; - updates.merged = NDArray(kRowSparseStorage, stored.shape(), Context(), - true, merged_dtype); + updates.merged = + NDArray(kRowSparseStorage, stored.shape(), Context(), true, merged_dtype); } // else nothing to aggregate updates.request.push_back(req_meta); ApplyUpdates(type, master_key, req_data, &updates, server); @@ -568,12 +603,12 @@ class KVStoreDistServer { // data TBlob idx_blob(indices.data(), mshadow::Shape1(num_rows), cpu::kDevMask); - size_t ds[] = {(size_t) num_rows, (size_t) unit_len}; + size_t ds[] = {(size_t)num_rows, (size_t)unit_len}; mxnet::TShape dshape(ds, ds + 2); TBlob recv_blob; MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, { - recv_blob = TBlob(reinterpret_cast(req_data.vals.data()), - dshape, cpu::kDevMask); + recv_blob = + TBlob(reinterpret_cast(req_data.vals.data()), dshape, cpu::kDevMask); }) // row_sparse NDArray NDArray recved(kRowSparseStorage, stored.shape(), recv_blob, {idx_blob}, 0); @@ -605,16 +640,17 @@ class KVStoreDistServer { void DefaultStorageResponse(const DataHandleType type, const int key, const ps::KVMeta& req_meta, - const ps::KVPairs &req_data, + const ps::KVPairs& req_data, ps::KVServer* server) { ps::KVPairs response; const NDArray& stored = store_[key]; CHECK(!stored.is_none()) << "init " << key << " first"; // as server returns when store_realt is ready in this case - if (has_multi_precision_copy(type)) stored.WaitToRead(); + if (has_multi_precision_copy(type)) + stored.WaitToRead(); - auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype()); + auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype()); response.keys = req_data.keys; response.lens = {len}; // TODO(mli) try to remove this CopyFrom @@ -624,10 +660,10 @@ class KVStoreDistServer { void DataHandleCompressed(const DataHandleType type, const ps::KVMeta& req_meta, - const ps::KVPairs &req_data, + const ps::KVPairs& req_data, ps::KVServer* server) { CHECK_EQ(type.dtype, mshadow::kFloat32) - << "Gradient compression is currently supported for fp32 only"; + << "Gradient compression is currently supported for fp32 only"; if (req_meta.push) { // there used several WaitToRead, this is because \a recved's memory // could be deallocated when this function returns. 
so we need to make sure @@ -639,8 +675,8 @@ class KVStoreDistServer { CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[1]); int original_size = DecodeKey(req_data.keys[0]); - int key = DecodeKey(req_data.keys[1]); - auto& stored = store_[key]; + int key = DecodeKey(req_data.keys[1]); + auto& stored = store_[key]; size_t ds[] = {(size_t)req_data.lens[1] / mshadow::mshadow_sizeof(type.dtype)}; mxnet::TShape dshape(ds, ds + 1); @@ -648,7 +684,7 @@ class KVStoreDistServer { NDArray recved = NDArray(recv_blob, 0); NDArray decomp_buf = decomp_buf_[key]; - dshape = mxnet::TShape{(int64_t) original_size}; + dshape = mxnet::TShape{(int64_t)original_size}; if (decomp_buf.is_none()) { decomp_buf = NDArray(dshape, Context()); @@ -683,7 +719,7 @@ class KVStoreDistServer { server->Response(req_meta); stored.WaitToRead(); } - } else { // pull + } else { // pull CHECK_EQ(req_data.keys.size(), (size_t)1); CHECK_EQ(req_data.lens.size(), (size_t)0); int key = DecodeKey(req_data.keys[0]); @@ -691,8 +727,9 @@ class KVStoreDistServer { } } - void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta, - const ps::KVPairs &req_data, + void DataHandleDefault(const DataHandleType type, + const ps::KVMeta& req_meta, + const ps::KVPairs& req_data, ps::KVServer* server) { // do some check CHECK_EQ(req_data.keys.size(), (size_t)1); @@ -700,13 +737,13 @@ class KVStoreDistServer { CHECK_EQ(req_data.lens.size(), (size_t)1); CHECK_EQ(req_data.vals.size(), (size_t)req_data.lens[0]); } - int key = DecodeKey(req_data.keys[0]); + int key = DecodeKey(req_data.keys[0]); auto& stored = has_multi_precision_copy(type) ? store_realt_[key] : store_[key]; // there used several WaitToRead, this is because \a recved's memory // could be deallocated when this function returns. so we need to make sure // the operators with \a NDArray are actually finished if (req_meta.push) { - size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)}; + size_t ds[] = {(size_t)req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)}; mxnet::TShape dshape(ds, ds + 1); TBlob recv_blob; MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, { @@ -715,21 +752,25 @@ class KVStoreDistServer { NDArray recved = NDArray(recv_blob, 0); if (stored.is_none()) { // initialization - stored = NDArray(dshape, Context(), false, + stored = NDArray(dshape, + Context(), + false, has_multi_precision_copy(type) ? mshadow::kFloat32 : type.dtype); CopyFromTo(recved, &stored, 0); server->Response(req_meta); if (has_multi_precision_copy(type)) { auto& stored_dtype = store_[key]; - stored_dtype = NDArray(dshape, Context(), false, type.dtype); + stored_dtype = NDArray(dshape, Context(), false, type.dtype); CopyFromTo(stored, stored_dtype); stored_dtype.WaitToRead(); } stored.WaitToRead(); } else { - auto &updates = update_buf_[key]; + auto& updates = update_buf_[key]; if (sync_mode_ && updates.merged.is_none()) { - updates.merged = NDArray(dshape, Context(), false, + updates.merged = NDArray(dshape, + Context(), + false, has_multi_precision_copy(type) ? 
mshadow::kFloat32 : type.dtype); } if (has_multi_precision_copy(type) && updates.temp_array.is_none()) { @@ -767,7 +808,6 @@ class KVStoreDistServer { return key - kr.begin(); } - /** * \brief user defined mode for push */ diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index bc4e9337568b..6362a4fc7d90 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -48,8 +48,8 @@ namespace kvstore { * \param delim char to split string around * \param result container for tokens extracted after splitting */ -template -void split(const std::string &s, const char delim, Out result) { +template +void split(const std::string& s, const char delim, Out result) { std::stringstream ss; ss.str(s); std::string item; @@ -58,11 +58,7 @@ void split(const std::string &s, const char delim, Out result) { } } -enum KeyType { - kUndefinedKey = -1, - kStringKey, - kIntKey -}; +enum KeyType { kUndefinedKey = -1, kStringKey, kIntKey }; /** * \brief store data in local machine @@ -83,7 +79,7 @@ class KVStoreLocal : public KVStore { } else { comm_ = new CommCPU(); } - pinned_ctx_ = comm_->pinned_ctx(); + pinned_ctx_ = comm_->pinned_ctx(); gradient_compression_ = std::make_shared(); } @@ -92,25 +88,23 @@ class KVStoreLocal : public KVStore { comm_ = nullptr; } - void Init(const std::vector& keys, - const std::vector& values) override { + void Init(const std::vector& keys, const std::vector& values) override { SetKeyType(kIntKey); InitImpl(keys, values); } - void Init(const std::vector& str_keys, - const std::vector& values) override { + void Init(const std::vector& str_keys, const std::vector& values) override { SetKeyType(kStringKey); std::vector keys(str_keys.size()); for (size_t i = 0; i < str_keys.size(); ++i) { - auto &str_key = str_keys[i]; + auto& str_key = str_keys[i]; CHECK(str_key_dict_.find(str_key) == str_key_dict_.end()) - << "duplicate init of key " << str_key; - auto key = next_str_key_++; + << "duplicate init of key " << str_key; + auto key = next_str_key_++; str_key_dict_[str_key] = key; // record reverse mapping from int to string reverse_str_key_dict_[key] = str_key; - keys[i] = key; + keys[i] = key; } InitImpl(keys, values); } @@ -183,14 +177,14 @@ class KVStoreLocal : public KVStore { std::vector vkeys(str_vkeys.size()); std::vector okeys(str_okeys.size()); for (size_t i = 0; i < str_vkeys.size(); ++i) { - auto &str_key = str_vkeys[i]; + auto& str_key = str_vkeys[i]; CHECK(str_key_dict_.find(str_key) == str_key_dict_.end()) - << "duplicate init of key " << str_key; - auto key = next_str_key_++; + << "duplicate init of key " << str_key; + auto key = next_str_key_++; str_key_dict_[str_key] = key; // record reverse mapping from int to string reverse_str_key_dict_[key] = str_key; - vkeys[i] = key; + vkeys[i] = key; } LookupKeys(str_okeys, &okeys); BroadcastImpl(vkeys, okeys, values, outs, priority); @@ -218,14 +212,13 @@ class KVStoreLocal : public KVStore { PullRowSparseImpl(keys, val_rowids, priority); } - void SetGradientCompression(const std::vector > - & kwargs) override { + void SetGradientCompression( + const std::vector>& kwargs) override { gradient_compression_->SetParams(kwargs); } private: - virtual void InitImpl(const std::vector& keys, - const std::vector& values) { + virtual void InitImpl(const std::vector& keys, const std::vector& values) { for (size_t i = 0; i < keys.size(); ++i) { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i] @@ -241,12 +234,12 @@ class KVStoreLocal : public KVStore { const std::vector& 
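split() above is the small tokenizer that ProcessServerProfilerCommands uses to break worker-supplied "parameter:value" strings. Its loop body falls outside the hunk context, so the version below fills it in with the usual std::getline idiom (an assumption, though it matches the declaration shown) and adds a usage example:

#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

template <typename Out>
void split(const std::string& s, const char delim, Out result) {
  std::stringstream ss;
  ss.str(s);
  std::string item;
  while (std::getline(ss, item, delim)) {  // assumed body: emit one token per delimiter
    *(result++) = item;
  }
}

int main() {
  // The profiler command handler splits "parameter:value" pairs like this
  // before prefixing filenames with the server rank.
  std::vector<std::string> parts;
  split("filename:profile.json", ':', std::back_inserter(parts));
  std::cout << parts[0] << " -> " << parts[1] << '\n';  // filename -> profile.json
}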
values, int priority) { std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { - int key = uniq_keys[i]; + int key = uniq_keys[i]; const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority); - NDArray& local = local_[key]; + NDArray& local = local_[key]; if (key_type_ == kStringKey) { local.AssignStorageInfo( profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:push:", @@ -259,8 +252,7 @@ class KVStoreLocal : public KVStore { if (updater_ != nullptr) { CHECK(!local.is_none()) << "key " << key << " has not been inited"; // if merged is on gpu, we may need copy weight from cpu to gpu - if (merged.ctx().dev_mask() != cpu::kDevMask && - local.ctx().dev_mask() == cpu::kDevMask) { + if (merged.ctx().dev_mask() != cpu::kDevMask && local.ctx().dev_mask() == cpu::kDevMask) { local = local.Copy(merged.ctx()); } // call the updater with string keys @@ -269,11 +261,11 @@ class KVStoreLocal : public KVStore { if (key_type_ == kStringKey && str_updater_ != nullptr) { // TODO(haibin) CHECK(str_updater_ != nullptr) if use_str_key // after all language bindings picks up string interface changes - const std::string &str_key = reverse_str_key_dict_[key]; + const std::string& str_key = reverse_str_key_dict_[key]; // TODO(haibin) avoid reverse key lookup if use_str_key - str_updater_(str_key, merged, &local); + str_updater_(str_key, merged, &local); } else { - updater_(key, merged, &local); + updater_(key, merged, &local); } } else { if (merged.storage_type() != local.storage_type()) { @@ -290,16 +282,17 @@ class KVStoreLocal : public KVStore { int priority, bool ignore_sparse) { std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, ignore_sparse); for (size_t i = 0; i < uniq_keys.size(); ++i) { - int key = uniq_keys[i]; + int key = uniq_keys[i]; const NDArray& local = local_[key]; CHECK(!local.is_none()) << "key " << key << " has not been inited"; comm_->Broadcast(key, local, grouped_vals[i], priority); for (std::vector::iterator iter = grouped_vals[i].begin(); - iter != grouped_vals[i].end(); ++iter) { + iter != grouped_vals[i].end(); + ++iter) { if (key_type_ == kStringKey) { (*iter)->AssignStorageInfo( profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:pull:", @@ -320,15 +313,15 @@ class KVStoreLocal : public KVStore { std::vector>> grouped_val_rowids; GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids, false); for (size_t i = 0; i < uniq_keys.size(); ++i) { - int key = uniq_keys[i]; + int key = uniq_keys[i]; const NDArray& local = local_[key]; CHECK(!local.is_none()) << "key " << key << " has not been inited"; CHECK_EQ(local.storage_type(), kRowSparseStorage) - << "PullRowSparse expects row_sparse src NDArray"; - auto &target_val_rowids = grouped_val_rowids[i]; - const size_t num_vals = target_val_rowids.size(); + << "PullRowSparse expects row_sparse src NDArray"; + auto& target_val_rowids = grouped_val_rowids[i]; + const size_t num_vals = target_val_rowids.size(); for (size_t j = 0; j < num_vals; j++) { - auto &row_id = target_val_rowids[j].second; + auto& row_id = target_val_rowids[j].second; target_val_rowids[j].second = Unique(row_id, local.ctx(), 0); } comm_->BroadcastRowSparse(key, local, grouped_val_rowids[i], priority); @@ -342,7 +335,8 @@ class KVStoreLocal : public KVStore { * If the key type is already 
defined, check if it matches the provided key type */ void SetKeyType(const KeyType key_type) { - if (key_type_ == kUndefinedKey) key_type_ = key_type; + if (key_type_ == kUndefinedKey) + key_type_ = key_type; CHECK_EQ(key_type_, key_type) << "Mixed key types are not allowed"; } @@ -369,15 +363,16 @@ class KVStoreLocal : public KVStore { */ virtual void GroupKVPairsPush(const std::vector& keys, const std::vector& values, - std::vector *uniq_keys, - std::vector> *grouped_vals, + std::vector* uniq_keys, + std::vector>* grouped_vals, bool ignore_sparse) { // check if the storage type of a value is valid auto validator = [](const int key, const NDArray& nd, bool ignore_sparse) -> bool { CHECK(!ignore_sparse) << "Cannot ignore sparse arrays for push"; auto stype = nd.storage_type(); // valid NDArray - if (stype == kDefaultStorage || stype == kRowSparseStorage) return true; + if (stype == kDefaultStorage || stype == kRowSparseStorage) + return true; // invalid NDArray, abort LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype; return false; @@ -389,13 +384,14 @@ class KVStoreLocal : public KVStore { */ virtual void GroupKVPairsPull(const std::vector& keys, const std::vector& values, - std::vector *uniq_keys, - std::vector> *grouped_vals, + std::vector* uniq_keys, + std::vector>* grouped_vals, bool ignore_sparse) { // check if the storage type of a value is valid auto validator = [this](const int key, const NDArray* nd, bool ignore_sparse) -> bool { // valid - if (nd->storage_type() == kDefaultStorage || !ignore_sparse) return true; + if (nd->storage_type() == kDefaultStorage || !ignore_sparse) + return true; // invalid, print warning messages once if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) { LOG(INFO) << "Warning: non-default weights detected during kvstore pull. 
" @@ -414,19 +410,21 @@ class KVStoreLocal : public KVStore { */ virtual void GroupKVPairsPullRsp(const std::vector& keys, const std::vector& values, - std::vector *uniq_keys, - std::vector> *grouped_vals, + std::vector* uniq_keys, + std::vector>* grouped_vals, bool ignore_sparse) { // check if the storage type of a value is valid auto validator = [](const int key, const RSPVal& val_rowid, bool ignore_sparse) -> bool { CHECK(!ignore_sparse) << "Cannot ignore sparse arrays in row_sparse_pull"; - auto val_stype = val_rowid.first->storage_type(); + auto val_stype = val_rowid.first->storage_type(); auto rowid_stype = val_rowid.second.storage_type(); // check storage types - CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for " - << "row_sparse_pull values, but detected storage type " << val_stype; - CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for " - << "row_sparse_pull rowids, but detected storage type " << rowid_stype; + CHECK_EQ(val_stype, kRowSparseStorage) + << "Expected row_sparse storage type for " + << "row_sparse_pull values, but detected storage type " << val_stype; + CHECK_EQ(rowid_stype, kDefaultStorage) + << "Expected default storage type for " + << "row_sparse_pull rowids, but detected storage type " << rowid_stype; return true; }; GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator, ignore_sparse); @@ -440,7 +438,7 @@ class KVStoreLocal : public KVStore { void GroupKVPairs(const std::vector& keys, const std::vector& values, std::vector* uniq_keys, - std::vector >* grouped_vals, + std::vector>* grouped_vals, const FValidate& is_valid, bool ignore_sparse) { CHECK_EQ(keys.size(), values.size()); @@ -448,11 +446,10 @@ class KVStoreLocal : public KVStore { using Idx = std::pair; std::vector idx(keys.size()); for (size_t i = 0; i < keys.size(); ++i) { - idx[i].first = keys[i]; idx[i].second = i; + idx[i].first = keys[i]; + idx[i].second = i; } - std::sort(idx.begin(), idx.end(), [](const Idx& a, const Idx& b) { - return a.first < b.first; - }); + std::sort(idx.begin(), idx.end(), [](const Idx& a, const Idx& b) { return a.first < b.first; }); int pre_key = idx[0].first - 1; for (auto i : idx) { @@ -468,12 +465,11 @@ class KVStoreLocal : public KVStore { } } - void LookupKeys(const std::vector& str_keys, - std::vector *keys) { + void LookupKeys(const std::vector& str_keys, std::vector* keys) { for (size_t i = 0; i < str_keys.size(); ++i) { - auto &str_key = str_keys[i]; + auto& str_key = str_keys[i]; CHECK(str_key_dict_.find(str_key) != str_key_dict_.end()) - << "key " << str_key << " doesn't exist. Did you init?"; + << "key " << str_key << " doesn't exist. Did you init?"; keys->at(i) = str_key_dict_[str_key]; } } @@ -487,49 +483,50 @@ class KVStoreLocal : public KVStore { * \param ctx the target context * \param priority the priority of the operation */ - NDArray Unique(const NDArray &data, Context ctx, int priority) { + NDArray Unique(const NDArray& data, Context ctx, int priority) { // create kRowSparseStorage output ndarray const size_t num_elements = data.shape().Size(); - NDArray out(kRowSparseStorage, mshadow::Shape2(num_elements, 1), - ctx, true, mshadow::kInt64); - bool diff_ctx = data.ctx() != ctx; + NDArray out(kRowSparseStorage, mshadow::Shape2(num_elements, 1), ctx, true, mshadow::kInt64); + bool diff_ctx = data.ctx() != ctx; NDArray data_in_ctx = diff_ctx ? 
NDArray(data.shape(), ctx, true, data.dtype()) : data; // if data == data_in_ctx, CopyFromTo is smart enough to skip the copy CopyFromTo(data, &data_in_ctx, priority); // GPU requires temp resources bool is_gpu = out.ctx().dev_mask() == gpu::kDevMask; Engine::Get()->PushAsync( - [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { - // copy data.data() to out.data() - out.CheckAndAlloc({mshadow::Shape1(num_elements)}); - TBlob out_data = out.data(); - NDArray workspace; - switch (out.ctx().dev_mask()) { - case cpu::kDevMask: { - mshadow::Stream *s = rctx.get_stream(); - ndarray::Copy(data_in_ctx.data(), &out_data, - ctx, ctx, rctx); - UniqueImpl(&workspace, s, out); - break; - } - #if MXNET_USE_CUDA - case gpu::kDevMask: { - mshadow::Stream *s = rctx.get_stream(); - ndarray::Copy(data_in_ctx.data(), &out_data, - ctx, ctx, rctx); - UniqueImpl(&workspace, s, out); - // wait for GPU operations to complete - s->Wait(); - break; + [=](RunContext rctx, Engine::CallbackOnComplete on_complete) { + // copy data.data() to out.data() + out.CheckAndAlloc({mshadow::Shape1(num_elements)}); + TBlob out_data = out.data(); + NDArray workspace; + switch (out.ctx().dev_mask()) { + case cpu::kDevMask: { + mshadow::Stream* s = rctx.get_stream(); + ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); + UniqueImpl(&workspace, s, out); + break; + } +#if MXNET_USE_CUDA + case gpu::kDevMask: { + mshadow::Stream* s = rctx.get_stream(); + ndarray::Copy(data_in_ctx.data(), &out_data, ctx, ctx, rctx); + UniqueImpl(&workspace, s, out); + // wait for GPU operations to complete + s->Wait(); + break; + } +#endif + default: + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; } - #endif - default: - LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - } - on_complete(); - }, out.ctx(), {data_in_ctx.var()}, {out.var()}, - is_gpu ? FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, - priority, "KVStoreUnique"); + on_complete(); + }, + out.ctx(), + {data_in_ctx.var()}, + {out.var()}, + is_gpu ? 
FnProperty::kGPUPrioritized : FnProperty::kCPUPrioritized, + priority, + "KVStoreUnique"); return out; } diff --git a/src/kvstore/kvstore_nccl.h b/src/kvstore/kvstore_nccl.h index 09bd880bfd68..736794ab1ddf 100644 --- a/src/kvstore/kvstore_nccl.h +++ b/src/kvstore/kvstore_nccl.h @@ -63,9 +63,9 @@ class KVStoreNCCL : public KVStoreLocal { public: KVStoreNCCL() : KVStoreLocal() { // Due to aggregation, we do not use the Comm interface - comm_ = nullptr; + comm_ = nullptr; pinned_ctx_ = Context::CPUPinned(0); - inited_ = false; + inited_ = false; } virtual ~KVStoreNCCL() { @@ -76,11 +76,9 @@ class KVStoreNCCL : public KVStoreLocal { } private: - void InitImpl(const std::vector& keys, - const std::vector& values) override { + void InitImpl(const std::vector& keys, const std::vector& values) override { for (size_t i = 0; i < keys.size(); ++i) { - CHECK(local_.find(keys[i]) == local_.end()) - << "duplicate init of key " << keys[i]; + CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; local_[keys[i]] = values[i].Copy(pinned_ctx_); InitKey(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } @@ -90,7 +88,7 @@ class KVStoreNCCL : public KVStoreLocal { const std::vector& values, int priority) override { std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; // nccl kvstore doesn't support sparse ndarray GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true); @@ -106,13 +104,12 @@ class KVStoreNCCL : public KVStoreLocal { // We issued NCCL kernels, need to synchronize nccl_called = true; } - auto& merged = *(merged_ptrs[i]); + auto& merged = *(merged_ptrs[i]); NDArray& local = local_[key]; if (updater_ != nullptr) { CHECK(!local.is_none()) << "key " << key << " has not been inited"; // if merged is on gpu, we may need copy weight from cpu to gpu - if (merged.ctx().dev_mask() != cpu::kDevMask && - local.ctx().dev_mask() == cpu::kDevMask) { + if (merged.ctx().dev_mask() != cpu::kDevMask && local.ctx().dev_mask() == cpu::kDevMask) { local = local.Copy(merged.ctx()); } } @@ -125,8 +122,8 @@ class KVStoreNCCL : public KVStoreLocal { } for (size_t i = 0; i < uniq_keys.size(); ++i) { - int key = uniq_keys[i]; - auto& merged = *(merged_ptrs[i]); + int key = uniq_keys[i]; + auto& merged = *(merged_ptrs[i]); NDArray& local = *(local_ptrs[i]); if (updater_ != nullptr) { // call the updater with string keys @@ -134,10 +131,10 @@ class KVStoreNCCL : public KVStoreLocal { // otherwise fallback to updater_ which uses int key interface if (key_type_ == kStringKey && str_updater_ != nullptr) { // after all language bindings picks up string interface changes - const std::string &str_key = reverse_str_key_dict_[key]; - str_updater_(str_key, merged, &local); + const std::string& str_key = reverse_str_key_dict_[key]; + str_updater_(str_key, merged, &local); } else { - updater_(key, merged, &local); + updater_(key, merged, &local); } } else { local = merged; @@ -147,16 +144,17 @@ class KVStoreNCCL : public KVStoreLocal { void PullImpl(const std::vector& keys, const std::vector& values, - int priority, bool ignore_sparse) override { + int priority, + bool ignore_sparse) override { CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False"; std::vector uniq_keys; - std::vector > grouped_vals; + std::vector> grouped_vals; GroupKVPairsHelper(keys, values, &uniq_keys, &grouped_vals, true); std::vector locals; bool nccl_called = false; for (size_t i = 0; i < uniq_keys.size(); ++i) { - int key = uniq_keys[i]; + int 
key = uniq_keys[i]; const NDArray& local = local_[key]; locals.push_back(local_[key]); CHECK(!local.is_none()) << "key " << key << " has not been inited"; @@ -180,8 +178,8 @@ class KVStoreNCCL : public KVStoreLocal { LOG(FATAL) << "NCCL kvstore does not support sparse storage type"; } - void SetGradientCompression(const std::vector > - & kwargs) override { + void SetGradientCompression( + const std::vector>& kwargs) override { LOG(FATAL) << "NCCL kvstore does not support gradient compression"; } @@ -192,15 +190,16 @@ class KVStoreNCCL : public KVStoreLocal { template void GroupKVPairsHelper(const std::vector& keys, const std::vector& values, - std::vector *uniq_keys, - std::vector> *grouped_vals, + std::vector* uniq_keys, + std::vector>* grouped_vals, bool ignore_sparse) { // check if the storage type of a value is valid auto validator = [this](const int key, const T nd, bool ignore_sparse) -> bool { CHECK(ignore_sparse) << "nccl kvstore pull doesn't support ignore_sparse=False"; auto stype = ptr(nd)->storage_type(); // valid NDArray - if (stype == kDefaultStorage) return true; + if (stype == kDefaultStorage) + return true; // invalid NDArray, abort LOG(FATAL) << "NCCL kvstore does not support sparse storage type"; return false; @@ -221,8 +220,8 @@ class KVStoreNCCL : public KVStoreLocal { std::vector mutate_vars; for (size_t k = 0; k < keys.size(); ++k) { - auto& key = keys[k]; - auto& src = srcs[k]; + auto& key = keys[k]; + auto& src = srcs[k]; auto& root_id = root_ids[k]; // avoid extra copy for single device, but it may bring problems for @@ -250,10 +249,10 @@ class KVStoreNCCL : public KVStoreLocal { CHECK(device_ids_ == dev_ids) << "NCCL KVStore supports only single set of devices"; auto& buf = merge_buf_[key]; - int root = buf.merged.ctx().dev_id; - root_id = FindRootId(src, root); + int root = buf.merged.ctx().dev_id; + root_id = FindRootId(src, root); - auto& reduce = buf.merged; + auto& reduce = buf.merged; (*merged_ptrs)[k] = &reduce; // Need to pass NDArrays by value to the engine reduces[k] = reduce; @@ -264,70 +263,73 @@ class KVStoreNCCL : public KVStoreLocal { mutate_vars.push_back(reduce.var()); } - Engine::Get()->PushSync([srcs, reduces, root_ids, this](RunContext rctx) { - std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); + Engine::Get()->PushSync( + [srcs, reduces, root_ids, this](RunContext rctx) { + std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); #if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1)) - ncclGroupStart(); -#endif - for (size_t k = 0; k < srcs.size(); ++k) { - auto& src = srcs[k]; - auto& root_id = root_ids[k]; - auto& reduce = reduces[k]; - if (src.size() <= 1) { - continue; - } - int root = nccl_data_[src[root_id].ctx().dev_id].rank; ncclGroupStart(); - for (size_t i = 0; i < src.size(); ++i) { - NCCLEntry cur = nccl_data_[src[i].ctx().dev_id]; - if (i == root_id) { - MSHADOW_TYPE_SWITCH(src[i].dtype(), DType, - ncclReduce(src[i].data().dptr(), - reduce.data().dptr(), - src[i].shape().Size(), - GetNCCLType(src[i].dtype()), - ncclSum, - root, - cur.comm, - cur.stream);); - } else { - MSHADOW_TYPE_SWITCH(src[i].dtype(), DType, - ncclReduce(src[i].data().dptr(), - nullptr, - src[i].shape().Size(), - GetNCCLType(src[i].dtype()), - ncclSum, - root, - cur.comm, - cur.stream);); +#endif + for (size_t k = 0; k < srcs.size(); ++k) { + auto& src = srcs[k]; + auto& root_id = root_ids[k]; + auto& reduce = reduces[k]; + if (src.size() <= 1) { + continue; + } + int root = nccl_data_[src[root_id].ctx().dev_id].rank; + ncclGroupStart(); + 
for (size_t i = 0; i < src.size(); ++i) { + NCCLEntry cur = nccl_data_[src[i].ctx().dev_id]; + if (i == root_id) { + MSHADOW_TYPE_SWITCH(src[i].dtype(), + DType, + ncclReduce(src[i].data().dptr(), + reduce.data().dptr(), + src[i].shape().Size(), + GetNCCLType(src[i].dtype()), + ncclSum, + root, + cur.comm, + cur.stream);); + } else { + MSHADOW_TYPE_SWITCH(src[i].dtype(), + DType, + ncclReduce(src[i].data().dptr(), + nullptr, + src[i].shape().Size(), + GetNCCLType(src[i].dtype()), + ncclSum, + root, + cur.comm, + cur.stream);); + } } + ncclGroupEnd(); } - ncclGroupEnd(); - } #if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1)) - ncclGroupEnd(); + ncclGroupEnd(); #endif - }, - Context::CPU(), - const_vars, - mutate_vars, - FnProperty::kCPUPrioritized, - priority, - "KVStoreReduce"); + }, + Context::CPU(), + const_vars, + mutate_vars, + FnProperty::kCPUPrioritized, + priority, + "KVStoreReduce"); } virtual void Broadcast(const std::vector keys, - const std::vector& srcs, - const std::vector>& dsts, - int priority) { + const std::vector& srcs, + const std::vector>& dsts, + int priority) { std::vector root_ids(keys.size()); std::vector const_vars; std::vector mutable_vars; for (size_t k = 0; k < keys.size(); ++k) { - auto& key = keys[k]; - auto& src = srcs[k]; - auto& dst = dsts[k]; + auto& key = keys[k]; + auto& src = srcs[k]; + auto& dst = dsts[k]; auto& root_id = root_ids[k]; if (!inited_) { @@ -341,7 +343,7 @@ class KVStoreNCCL : public KVStoreLocal { } } else { auto& buf = merge_buf_[key]; - int root = src.ctx().dev_id; + int root = src.ctx().dev_id; assert(root == buf.merged.ctx().dev_id); root_id = FindRootId(dst, root); @@ -357,7 +359,7 @@ class KVStoreNCCL : public KVStoreLocal { // On root perform simple copy to the output CopyFromTo(src, *dst[root_id], priority); for (size_t i = 0; i < dst.size(); ++i) { - if ( i != root_id) + if (i != root_id) mutable_vars.push_back(dst[i]->var()); } const_vars.push_back(src.var()); @@ -380,44 +382,46 @@ class KVStoreNCCL : public KVStoreLocal { } } - Engine::Get()->PushSync([srcs, broadcasts, root_ids, this](RunContext rctx) { - std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); + Engine::Get()->PushSync( + [srcs, broadcasts, root_ids, this](RunContext rctx) { + std::lock_guard l(Storage::Get()->GetMutex(Context::kGPU)); #if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1)) - ncclGroupStart(); + ncclGroupStart(); #endif - for (size_t k = 0; k < srcs.size(); ++k) { - auto& src = srcs[k]; - auto& dst = broadcasts[k]; - auto& root_id = root_ids[k]; - if (dst.size() <= 1) { - continue; - } + for (size_t k = 0; k < srcs.size(); ++k) { + auto& src = srcs[k]; + auto& dst = broadcasts[k]; + auto& root_id = root_ids[k]; + if (dst.size() <= 1) { + continue; + } - int root = nccl_data_[src.ctx().dev_id].rank; - ncclGroupStart(); - for (size_t i = 0; i < dst.size(); ++i) { - auto& bcast = (i == root_id) ? src : dst[i]; - NCCLEntry cur = nccl_data_[bcast.ctx().dev_id]; - MSHADOW_TYPE_SWITCH(bcast.dtype(), DType, - ncclBcast(bcast.data().dptr(), - bcast.shape().Size(), - GetNCCLType(bcast.dtype()), - root, - cur.comm, - cur.stream);); + int root = nccl_data_[src.ctx().dev_id].rank; + ncclGroupStart(); + for (size_t i = 0; i < dst.size(); ++i) { + auto& bcast = (i == root_id) ? 
src : dst[i]; + NCCLEntry cur = nccl_data_[bcast.ctx().dev_id]; + MSHADOW_TYPE_SWITCH(bcast.dtype(), + DType, + ncclBcast(bcast.data().dptr(), + bcast.shape().Size(), + GetNCCLType(bcast.dtype()), + root, + cur.comm, + cur.stream);); + } + ncclGroupEnd(); } - ncclGroupEnd(); - } #if (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR > 1)) - ncclGroupEnd(); + ncclGroupEnd(); #endif - }, - Context::CPU(), - const_vars, - mutable_vars, - FnProperty::kCPUPrioritized, - priority, - "KVStoreBCast"); + }, + Context::CPU(), + const_vars, + mutable_vars, + FnProperty::kCPUPrioritized, + priority, + "KVStoreBCast"); } // Function that waits for NCCL collective to complete @@ -425,26 +429,29 @@ class KVStoreNCCL : public KVStoreLocal { void CommSync(const std::vector& dst, int priority) { std::vector mutate_vars; for (size_t i = 0; i < dst.size(); ++i) { - mutate_vars.push_back(ptr(dst[i])->var()); + mutate_vars.push_back(ptr(dst[i])->var()); } - Engine::Get()->PushSync([this](RunContext rctx) { - mxnet::common::cuda::DeviceStore device_store; - for (auto cur : nccl_data_) { - device_store.SetDevice(cur.second.dev_id); - CUDA_CALL(cudaStreamSynchronize(cur.second.stream)); - } - }, - Context::CPU(), - {}, - mutate_vars, - FnProperty::kCPUPrioritized, - priority, - "KVStoreStreamSync"); + Engine::Get()->PushSync( + [this](RunContext rctx) { + mxnet::common::cuda::DeviceStore device_store; + for (auto cur : nccl_data_) { + device_store.SetDevice(cur.second.dev_id); + CUDA_CALL(cudaStreamSynchronize(cur.second.stream)); + } + }, + Context::CPU(), + {}, + mutate_vars, + FnProperty::kCPUPrioritized, + priority, + "KVStoreStreamSync"); } // Initialize single key - void InitKey(int key, const NDArrayStorageType stype, const mxnet::TShape& shape, - int dtype = mshadow::kFloat32) { + void InitKey(int key, + const NDArrayStorageType stype, + const mxnet::TShape& shape, + int dtype = mshadow::kFloat32) { if (stype == kDefaultStorage) { key_attrs_.push_back(std::make_tuple(key, shape, dtype)); } else { @@ -484,8 +491,8 @@ class KVStoreNCCL : public KVStoreLocal { for (size_t i = 0; i < devs.size(); ++i) { NCCLEntry e; e.dev_id = device_ids_[i]; - e.comm = comms[i]; - e.rank = i; + e.comm = comms[i]; + e.rank = i; device_store.SetDevice(e.dev_id); cudaStreamCreate(&(e.stream)); nccl_data_[device_ids_[i]] = e; @@ -495,10 +502,10 @@ class KVStoreNCCL : public KVStoreLocal { using KeyAttrs = std::tuple; void InitMergeBuffer(const std::vector& devs) { for (size_t i = 0; i < key_attrs_.size(); ++i) { - int key = std::get<0>(key_attrs_[i]); + int key = std::get<0>(key_attrs_[i]); mxnet::TShape s = std::get<1>(key_attrs_[i]); - int type = std::get<2>(key_attrs_[i]); - auto& buf = merge_buf_[key]; + int type = std::get<2>(key_attrs_[i]); + auto& buf = merge_buf_[key]; // always use devs[0] as root buf.merged = NDArray(s, devs[0], false, type); } @@ -507,11 +514,15 @@ class KVStoreNCCL : public KVStoreLocal { // Functions that enable templates to work on both references // and pointers - template - const T * ptr(const T & obj) { return &obj; } + template + const T* ptr(const T& obj) { + return &obj; + } - template - const T * ptr(T * obj) { return obj; } + template + const T* ptr(T* obj) { + return obj; + } // Find which element of the vector // corresponds to root dev_id diff --git a/src/kvstore/kvstore_utils.cc b/src/kvstore/kvstore_utils.cc index b53eca433e97..57cb0b058e2e 100644 --- a/src/kvstore/kvstore_utils.cc +++ b/src/kvstore/kvstore_utils.cc @@ -28,21 +28,19 @@ namespace mxnet { namespace kvstore { -template<> 
-void UniqueImpl(NDArray* workspace, mshadow::Stream *s, - const NDArray& out) { +template <> +void UniqueImpl(NDArray* workspace, mshadow::Stream* s, const NDArray& out) { const size_t num_elements = out.shape().Size(); CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected"; MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, { - IType *dptr = out.data().dptr(); - common::ParallelSort(dptr, dptr + num_elements, - engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); + IType* dptr = out.data().dptr(); + common::ParallelSort( + dptr, dptr + num_elements, engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); const size_t num_selected_out = std::unique(dptr, dptr + num_elements) - dptr; // set the shape of data/aux_data according to the number of unique values out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out)); }); } - } // namespace kvstore } // namespace mxnet diff --git a/src/kvstore/kvstore_utils.cu b/src/kvstore/kvstore_utils.cu index bcecaea75fb6..bca6e7be36d5 100644 --- a/src/kvstore/kvstore_utils.cu +++ b/src/kvstore/kvstore_utils.cu @@ -40,68 +40,69 @@ namespace mxnet { namespace kvstore { -template -size_t UniqueImplGPU(NDArray *workspace, mshadow::Stream *s, - IType *dptr, const size_t size, Context ctx) { +template +size_t UniqueImplGPU(NDArray* workspace, + mshadow::Stream* s, + IType* dptr, + const size_t size, + Context ctx) { // estimate unique temp space. The first byte is reserved to store the number // of unique values selected const size_t num_selected_bytes = sizeof(size_t); - size_t unique_temp_bytes = 0; - size_t *null_ptr = nullptr; - size_t *null_dptr = nullptr; - cudaStream_t stream = mshadow::Stream::GetStream(s); - cub::DeviceSelect::Unique(nullptr, unique_temp_bytes, null_dptr, null_dptr, - null_ptr, size, stream); + size_t unique_temp_bytes = 0; + size_t* null_ptr = nullptr; + size_t* null_dptr = nullptr; + cudaStream_t stream = mshadow::Stream::GetStream(s); + cub::DeviceSelect::Unique( + nullptr, unique_temp_bytes, null_dptr, null_dptr, null_ptr, size, stream); // estimate sort temp space const size_t sort_output_bytes = size * sizeof(IType); - size_t sort_temp_bytes = 0; + size_t sort_temp_bytes = 0; #ifndef SORT_WITH_THRUST // The least-significant bit index (inclusive) needed for key comparison const int begin_bit = 0; // The most-significant bit index (exclusive) needed for key comparison const int end_bit = sizeof(IType) * 8; - cub::DeviceRadixSort::SortKeys(nullptr, sort_temp_bytes, null_dptr, null_dptr, - size, begin_bit, end_bit, stream); + cub::DeviceRadixSort::SortKeys( + nullptr, sort_temp_bytes, null_dptr, null_dptr, size, begin_bit, end_bit, stream); #else // sort_temp_bytes remains 0 because thrust request memory by itself #endif // request temp storage - const size_t total_workspace = num_selected_bytes + sort_output_bytes + - std::max(sort_temp_bytes, unique_temp_bytes); - *workspace = NDArray(mshadow::Shape1((total_workspace + 3) / 4), ctx, false); + const size_t total_workspace = + num_selected_bytes + sort_output_bytes + std::max(sort_temp_bytes, unique_temp_bytes); + *workspace = NDArray(mshadow::Shape1((total_workspace + 3) / 4), ctx, false); char* workspace_dptr = reinterpret_cast(workspace->data().dptr_); // temp space layout: num_selected_ptr, sort_output_bytes, unique/sort_temp_storage size_t* num_selected_ptr = reinterpret_cast(workspace_dptr); - IType* sort_output_ptr = reinterpret_cast(workspace_dptr + num_selected_bytes); - void *temp_storage = static_cast(workspace_dptr + - 
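The CPU UniqueImpl above reduces the gathered row ids to a sorted, duplicate-free prefix: sort (ParallelSort in MXNet), std::unique, then shrink the visible length via set_aux_shape. The same idea on a plain vector (a sketch, not the MXNet signature):

#include <algorithm>
#include <cstdint>
#include <vector>

// Sort the row ids, drop adjacent duplicates, and report how many survive;
// MXNet shrinks the aux shape instead of resizing a vector.
size_t UniqueRowIds(std::vector<int64_t>* row_ids) {
  std::sort(row_ids->begin(), row_ids->end());
  auto new_end = std::unique(row_ids->begin(), row_ids->end());
  const size_t num_unique = static_cast<size_t>(new_end - row_ids->begin());
  row_ids->resize(num_unique);
  return num_unique;
}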
num_selected_bytes + sort_output_bytes); + IType* sort_output_ptr = reinterpret_cast(workspace_dptr + num_selected_bytes); + void* temp_storage = static_cast(workspace_dptr + num_selected_bytes + sort_output_bytes); // execute the sort kernel #ifndef SORT_WITH_THRUST - cub::DeviceRadixSort::SortKeys(temp_storage, sort_temp_bytes, dptr, sort_output_ptr, - size, begin_bit, end_bit, stream); + cub::DeviceRadixSort::SortKeys( + temp_storage, sort_temp_bytes, dptr, sort_output_ptr, size, begin_bit, end_bit, stream); #else - thrust::sort(thrust::cuda::par.on(stream), - dptr, dptr + size, thrust::greater()); - CUDA_CALL(cudaMemcpyAsync(sort_output_ptr, dptr, sort_output_bytes, - cudaMemcpyDeviceToDevice, stream)); + thrust::sort(thrust::cuda::par.on(stream), dptr, dptr + size, thrust::greater()); + CUDA_CALL( + cudaMemcpyAsync(sort_output_ptr, dptr, sort_output_bytes, cudaMemcpyDeviceToDevice, stream)); #endif // execute unique kernel - cub::DeviceSelect::Unique(temp_storage, unique_temp_bytes, sort_output_ptr, dptr, - num_selected_ptr, size, stream); + cub::DeviceSelect::Unique( + temp_storage, unique_temp_bytes, sort_output_ptr, dptr, num_selected_ptr, size, stream); // retrieve num selected unique values size_t num_selected_out = 0; - CUDA_CALL(cudaMemcpyAsync(&num_selected_out, num_selected_ptr, num_selected_bytes, - cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync( + &num_selected_out, num_selected_ptr, num_selected_bytes, cudaMemcpyDeviceToHost, stream)); CUDA_CALL(cudaStreamSynchronize(stream)); return num_selected_out; } -template<> -void UniqueImpl(NDArray *workspace, mshadow::Stream *s, const NDArray &out) { +template <> +void UniqueImpl(NDArray* workspace, mshadow::Stream* s, const NDArray& out) { const size_t num_elements = out.shape().Size(); CHECK_EQ(out.storage_type(), kRowSparseStorage) << "row_sparse NDArray is expected"; MSHADOW_IDX_TYPE_SWITCH(out.dtype(), IType, { - IType *dptr = out.data().dptr(); + IType* dptr = out.data().dptr(); size_t num_selected_out = UniqueImplGPU(workspace, s, dptr, num_elements, out.ctx()); // set the shape of data/aux_data according to the number of unique values out.set_aux_shape(rowsparse::kIdx, mshadow::Shape1(num_selected_out)); diff --git a/src/kvstore/kvstore_utils.h b/src/kvstore/kvstore_utils.h index 2527f7ed0ce2..53b77ee41f4f 100644 --- a/src/kvstore/kvstore_utils.h +++ b/src/kvstore/kvstore_utils.h @@ -33,7 +33,6 @@ namespace mxnet { namespace kvstore { - /*! * \brief compute unique and sorted values in a row_sparse ndarray. * \param workspace Temp workspace for computation. Its a pointer to a @@ -43,8 +42,8 @@ namespace kvstore { * \param out Input and output ndarray. The ndarray stores the * unique elements in out.data(). 
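UniqueImplGPU above relies on the standard cub calling convention: each DeviceRadixSort/DeviceSelect call is issued twice, first with a null temp pointer to learn how many scratch bytes it needs, then again once a workspace has been allocated. A minimal CUDA sketch of that two-pass idiom for DeviceSelect::Unique alone (already-sorted input; error checking omitted; requires nvcc and a GPU to run):

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t h_in[] = {1, 1, 2, 5, 5, 5, 9};
  const int n = 7;
  int64_t* d_in = nullptr;
  int64_t* d_out = nullptr;
  size_t* d_num_selected = nullptr;
  cudaMalloc(reinterpret_cast<void**>(&d_in), n * sizeof(int64_t));
  cudaMalloc(reinterpret_cast<void**>(&d_out), n * sizeof(int64_t));
  cudaMalloc(reinterpret_cast<void**>(&d_num_selected), sizeof(size_t));
  cudaMemcpy(d_in, h_in, n * sizeof(int64_t), cudaMemcpyHostToDevice);

  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // pass 1: null temp pointer -> cub only reports the scratch size it needs
  cub::DeviceSelect::Unique(d_temp, temp_bytes, d_in, d_out, d_num_selected, n);
  cudaMalloc(&d_temp, temp_bytes);
  // pass 2: same call with real scratch space performs the selection
  cub::DeviceSelect::Unique(d_temp, temp_bytes, d_in, d_out, d_num_selected, n);

  size_t h_num = 0;
  cudaMemcpy(&h_num, d_num_selected, sizeof(size_t), cudaMemcpyDeviceToHost);
  printf("unique values: %zu\n", h_num);  // prints 4
  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_num_selected);
  cudaFree(d_temp);
  return 0;
}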
*/ -template -void UniqueImpl(NDArray* workspace, mshadow::Stream *s, const NDArray& out); +template +void UniqueImpl(NDArray* workspace, mshadow::Stream* s, const NDArray& out); } // namespace kvstore } // namespace mxnet diff --git a/src/kvstore/p3store_dist.h b/src/kvstore/p3store_dist.h index 3c99515e27b1..d167a51248a9 100644 --- a/src/kvstore/p3store_dist.h +++ b/src/kvstore/p3store_dist.h @@ -39,8 +39,7 @@ namespace kvstore { */ class P3StoreDist : public KVStoreDist { public: - explicit P3StoreDist(bool use_device_comm) - : KVStoreDist(use_device_comm) { + explicit P3StoreDist(bool use_device_comm) : KVStoreDist(use_device_comm) { slice_threshold_ = dmlc::GetEnv("MXNET_KVSTORE_SLICE_THRESHOLD", 40 * 1000); } @@ -57,12 +56,12 @@ class P3StoreDist : public KVStoreDist { } void set_updater(const Updater& updater) final { - LOG(FATAL) << "NotImplementedError: Update on P3StoreDist is not supported. " - << "Please set MXNET_UPDATE_ON_KVSTORE to false."; + LOG(FATAL) << "NotImplementedError: Update on P3StoreDist is not supported. " + << "Please set MXNET_UPDATE_ON_KVSTORE to false."; } - void SetGradientCompression(const std::vector> - & kwargs) final { + void SetGradientCompression( + const std::vector>& kwargs) final { LOG(FATAL) << "NotImplementedError: Gradient compression not supported in P3StoreDist."; } @@ -75,141 +74,160 @@ class P3StoreDist : public KVStoreDist { EncodeDefaultKey(key, value.shape().Size(), mshadow::mshadow_sizeof(value.dtype())); } - void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, - int priority) final { + void PushCompressed(int key, const NDArray& comm_buf, const PSKV& pskv, int priority) final { LOG(FATAL) << "NotImplementedError: PushCompressed not implemented in P3StoreDist."; } - void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, - int priority) override { - auto push_to_servers = [this, key, pskv, send_buf, priority] - (RunContext rctx, Engine::CallbackOnComplete cb) { - const int dtype = send_buf.dtype(); - // convert to ps keys - const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); - char* data = static_cast(send_buf.data().dptr_); - // do push. false means no delete - ps::SArray vals(data, size, false); - int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + void PushDefault(int key, const NDArray& send_buf, const PSKV& pskv, int priority) override { + auto push_to_servers = [this, key, pskv, send_buf, priority](RunContext rctx, + Engine::CallbackOnComplete cb) { + const int dtype = send_buf.dtype(); + // convert to ps keys + const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype); + char* data = static_cast(send_buf.data().dptr_); + // do push. 
false means no delete + ps::SArray vals(data, size, false); + int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); - size_t off = 0; - auto counter = new std::atomic(pskv.keys.size()); - for (size_t idx = 0; idx < pskv.keys.size(); idx++) { - auto ks = pskv.keys.segment(idx, idx+1); - auto ls = pskv.lens.segment(idx, idx+1); - auto vs = vals.segment(off, off + pskv.lens[idx]); - CHECK_NOTNULL(ps_worker_)->ZPush( - ks, vs, ls, cmd, [counter, cb]() { - if (--(*counter) == 0) { - delete counter; - cb(); - } - }, priority); - off += pskv.lens[idx]; - } - }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - "P3StoreDistDefaultPush"); + size_t off = 0; + auto counter = new std::atomic(pskv.keys.size()); + for (size_t idx = 0; idx < pskv.keys.size(); idx++) { + auto ks = pskv.keys.segment(idx, idx + 1); + auto ls = pskv.lens.segment(idx, idx + 1); + auto vs = vals.segment(off, off + pskv.lens[idx]); + CHECK_NOTNULL(ps_worker_) + ->ZPush( + ks, + vs, + ls, + cmd, + [counter, cb]() { + if (--(*counter) == 0) { + delete counter; + cb(); + } + }, + priority); + off += pskv.lens[idx]; + } + }; + Engine::Get()->PushAsync(push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + "P3StoreDistDefaultPush"); } - void PushRowSparse(int key, const NDArray &send_buf, int priority) override { + void PushRowSparse(int key, const NDArray& send_buf, int priority) override { LOG(FATAL) << "NotImplementedError: PushRowSparse not implemented in P3StoreDist."; } - void PullDefault(int key, const NDArray &recv_buf, int priority) override { + void PullDefault(int key, const NDArray& recv_buf, int priority) override { CHECK(gradient_compression_->get_type() == CompressionType::kNone) - << "Gradient compression not supported in P3StoreDist."; - auto pull_from_servers = [this, key, recv_buf, priority]( - RunContext rctx, Engine::CallbackOnComplete cb) { + << "Gradient compression not supported in P3StoreDist."; + auto pull_from_servers = [this, key, recv_buf, priority](RunContext rctx, + Engine::CallbackOnComplete cb) { // convert to ps keys - size_t size = recv_buf.shape().Size(); - const int dtype = recv_buf.dtype(); + size_t size = recv_buf.shape().Size(); + const int dtype = recv_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); - PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); - char* data = static_cast (recv_buf.data().dptr_); + PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); + char* data = static_cast(recv_buf.data().dptr_); // issue pull const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); - size_t off = 0; - auto counter = new std::atomic(pskv.keys.size()); + size_t off = 0; + auto counter = new std::atomic(pskv.keys.size()); for (size_t idx = 0; idx < pskv.keys.size(); idx++) { - auto ks = pskv.keys.segment(idx, idx+1); + auto ks = pskv.keys.segment(idx, idx + 1); auto ls = new ps::SArray(1, pskv.lens[idx]); auto vs = new ps::SArray(data + off, pskv.lens[idx], false); - CHECK_NOTNULL(ps_worker_)->ZPull( - ks, vs, ls, cmd, [vs, ls, counter, cb]() { - delete vs; - delete ls; - if (--(*counter) == 0) { - delete counter; - cb(); - } - }, priority); + CHECK_NOTNULL(ps_worker_) + ->ZPull( + ks, + vs, + ls, + cmd, + [vs, ls, counter, cb]() { + delete vs; + delete ls; + if (--(*counter) == 0) { + delete counter; + cb(); + } + }, + priority); off += pskv.lens[idx]; } }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {}, - 
{recv_buf.var()}, - FnProperty::kNormal, - priority, - "P3StoreDistDefaultStoragePull"); + CHECK_NOTNULL(Engine::Get()) + ->PushAsync(pull_from_servers, + pinned_ctx_, + {}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + "P3StoreDistDefaultStoragePull"); } - void PullRowSparse_(const int key, const NDArray& recv_buf, - const NDArray& indices, int priority) override { + void PullRowSparse_(const int key, + const NDArray& recv_buf, + const NDArray& indices, + int priority) override { LOG(FATAL) << "NotImplementedError: PullRowSparse not implemented in P3StoreDist."; } - void PushPullDefault(int key, const NDArray &comm_buf, int priority) override { + void PushPullDefault(int key, const NDArray& comm_buf, int priority) override { CHECK(gradient_compression_->get_type() == CompressionType::kNone) - << "Compression not supported in P3StoreDist"; - auto pushpull = [this, key, comm_buf, priority]( - RunContext rctx, Engine::CallbackOnComplete cb) { - size_t size = comm_buf.shape().Size(); - const int dtype = comm_buf.dtype(); + << "Compression not supported in P3StoreDist"; + auto pushpull = [this, key, comm_buf, priority](RunContext rctx, + Engine::CallbackOnComplete cb) { + size_t size = comm_buf.shape().Size(); + const int dtype = comm_buf.dtype(); const int num_bytes = mshadow::mshadow_sizeof(dtype); - const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); + const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype); PSKV& pskv = EncodeDefaultKey(key, size, num_bytes); char* data = static_cast(comm_buf.data().dptr_); - size_t off = 0; + size_t off = 0; auto counter = new std::atomic(pskv.keys.size()); for (size_t idx = 0; idx < pskv.keys.size(); idx++) { - auto ks = pskv.keys.segment(idx, idx+1); + auto ks = pskv.keys.segment(idx, idx + 1); auto ls = new ps::SArray(1, pskv.lens[idx]); auto vs = new ps::SArray(data + off, pskv.lens[idx], false); - CHECK_NOTNULL(ps_worker_)->ZPushPull( - ks, *vs, vs, ls, cmd, [vs, ls, counter, cb]() { - delete vs; - delete ls; - if (--(*counter) == 0) { - delete counter; - cb(); - } - }, priority); + CHECK_NOTNULL(ps_worker_) + ->ZPushPull( + ks, + *vs, + vs, + ls, + cmd, + [vs, ls, counter, cb]() { + delete vs; + delete ls; + if (--(*counter) == 0) { + delete counter; + cb(); + } + }, + priority); off += pskv.lens[idx]; } }; - CHECK_NOTNULL(Engine::Get())->PushAsync( - pushpull, - pinned_ctx_, - {}, - {comm_buf.var()}, - FnProperty::kNormal, - priority, - "P3StoreDistDefaultStoragePushPull"); + CHECK_NOTNULL(Engine::Get()) + ->PushAsync(pushpull, + pinned_ctx_, + {}, + {comm_buf.var()}, + FnProperty::kNormal, + priority, + "P3StoreDistDefaultStoragePushPull"); } - inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems, + inline PSKV& EncodeDefaultKey(const int key, + const size_t num_arr_elems, const int num_bytes) override { mu_.lock(); PSKV& pskv = ps_kv_[key]; @@ -217,22 +235,21 @@ class P3StoreDist : public KVStoreDist { size_t pskv_size = num_arr_elems * num_bytes; if (!pskv.keys.empty()) { CHECK_EQ(static_cast(pskv.size), pskv_size) - << "The value size cannot be changed " << pskv_size << ". Key is " << key; + << "The value size cannot be changed " << pskv_size << ". 
Key is " << key; } else { - auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); const int num_servers = krs.size(); CHECK_GT(num_servers, 0); - int64_t num_params = num_arr_elems * num_bytes; - int64_t slice_bound = slice_threshold_ * num_bytes; + int64_t num_params = num_arr_elems * num_bytes; + int64_t slice_bound = slice_threshold_ * num_bytes; static size_t server = 0; while (num_params > 0) { - ps::Key ps_key = krs[server%num_servers].begin() - + (ps::Key)(key + server/num_servers); - CHECK_LT(ps_key, krs[server%num_servers].end()); + ps::Key ps_key = krs[server % num_servers].begin() + (ps::Key)(key + server / num_servers); + CHECK_LT(ps_key, krs[server % num_servers].end()); pskv.keys.push_back(ps_key); - const size_t part_size = static_cast((num_params > slice_bound) - ? slice_bound : num_params); + const size_t part_size = + static_cast((num_params > slice_bound) ? slice_bound : num_params); pskv.lens.push_back(part_size); pskv.size += part_size; @@ -252,5 +269,4 @@ class P3StoreDist : public KVStoreDist { } // namespace kvstore } // namespace mxnet - #endif // MXNET_KVSTORE_P3STORE_DIST_H_ diff --git a/src/lib_api.cc b/src/lib_api.cc index c0fea4ab3302..7bfa3a928b54 100644 --- a/src/lib_api.cc +++ b/src/lib_api.cc @@ -32,9 +32,9 @@ #include "mxnet/lib_api.h" mxnet::ext::MXerrorMsgs* mxnet::ext::MXerrorMsgs::get() { - static MXerrorMsgs inst; - return &inst; - } + static MXerrorMsgs inst; + return &inst; +} std::stringstream& mxnet::ext::MXerrorMsgs::add(const char* file, int line) { messages.emplace_back(); @@ -53,21 +53,34 @@ const std::string* mxnet::ext::MXerrorMsgs::get(int idx) { mxnet::ext::MXContext::MXContext() : dev_type("error"), dev_id(-1) {} mxnet::ext::MXContext::MXContext(std::string dev_type_, int dev_id_) - : dev_type(std::move(dev_type_)), dev_id(dev_id_) {} + : dev_type(std::move(dev_type_)), dev_id(dev_id_) {} mxnet::ext::MXContext::MXContext(const char* dev_type_, int dev_id_) - : dev_type(dev_type_), dev_id(dev_id_) {} + : dev_type(dev_type_), dev_id(dev_id_) {} -mxnet::ext::MXContext mxnet::ext::MXContext::CPU() { return MXContext("cpu", 0); } +mxnet::ext::MXContext mxnet::ext::MXContext::CPU() { + return MXContext("cpu", 0); +} -mxnet::ext::MXContext mxnet::ext::MXContext::GPU() { return MXContext("gpu", 0); } +mxnet::ext::MXContext mxnet::ext::MXContext::GPU() { + return MXContext("gpu", 0); +} -mxnet::ext::MXContext mxnet::ext::MXContext::CPU(int dev_id) { return MXContext("cpu", dev_id); } +mxnet::ext::MXContext mxnet::ext::MXContext::CPU(int dev_id) { + return MXContext("cpu", dev_id); +} -mxnet::ext::MXContext mxnet::ext::MXContext::GPU(int dev_id) { return MXContext("gpu", dev_id); } +mxnet::ext::MXContext mxnet::ext::MXContext::GPU(int dev_id) { + return MXContext("gpu", dev_id); +} -void mxnet::ext::MXSparse::set(void *data_ptr, const int64_t* dims, int ndims, void *idx, - int64_t num_idx, void *idx_ptr, int64_t num_idx_ptr) { +void mxnet::ext::MXSparse::set(void* data_ptr, + const int64_t* dims, + int ndims, + void* idx, + int64_t num_idx, + void* idx_ptr, + int64_t num_idx_ptr) { data = data_ptr; // If CSR, num of non-zero elemets is num_idx, // If row sparse, num of elements is num_idx * width. 
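// --- Editor's note (illustrative sketch, not part of the patch) ---------------
// MXSparse::set() above receives raw buffers plus their lengths. For CSR
// storage the idx buffer holds one column id per non-zero value and idx_ptr
// holds dims[0] + 1 row offsets; for row-sparse storage only the row-id buffer
// is passed and idx_ptr stays null, so the value count becomes num_idx times
// the row width. A minimal call for a hypothetical 2x3 CSR matrix with two
// non-zeros (all names and values below are made up for illustration):
//
//   float   vals[]   = {1.f, 2.f};    // non-zero values
//   int64_t dims[]   = {2, 3};        // dense shape
//   int64_t cols[]   = {0, 2};        // column index of each non-zero
//   int64_t indptr[] = {0, 1, 2};     // row offsets, length dims[0] + 1
//   mxnet::ext::MXSparse csr;
//   csr.set(vals, dims, /*ndims=*/2, cols, /*num_idx=*/2, indptr, /*num_idx_ptr=*/3);
// ------------------------------------------------------------------------------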
@@ -77,33 +90,54 @@ void mxnet::ext::MXSparse::set(void *data_ptr, const int64_t* dims, int ndims, v data_len *= dims[i]; } - indices = reinterpret_cast(idx); + indices = reinterpret_cast(idx); indices_len = num_idx; if (idx_ptr) { - indptr = reinterpret_cast(idx_ptr); + indptr = reinterpret_cast(idx_ptr); indptr_len = num_idx_ptr; } } -mxnet::ext::MXTensor::MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), - stype(kDefaultStorage) {} -mxnet::ext::MXTensor::MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape), - dtype(oth.dtype), verID(oth.verID), - ctx(oth.ctx), stype(oth.stype) { +mxnet::ext::MXTensor::MXTensor() + : data_ptr(nullptr), dtype(kUNSET), verID(0), stype(kDefaultStorage) {} +mxnet::ext::MXTensor::MXTensor(const MXTensor& oth) + : data_ptr(oth.data_ptr), + shape(oth.shape), + dtype(oth.dtype), + verID(oth.verID), + ctx(oth.ctx), + stype(oth.stype) { setDLTensor(); } -mxnet::ext::MXTensor::MXTensor(void *data_ptr, std::vector shape, MXDType dtype, - size_t vID, MXContext mx_ctx, MXStorageType stype) - : data_ptr(data_ptr), shape(std::move(shape)), dtype(dtype), verID(vID), ctx(std::move(mx_ctx)), - stype(stype) { +mxnet::ext::MXTensor::MXTensor(void* data_ptr, + std::vector shape, + MXDType dtype, + size_t vID, + MXContext mx_ctx, + MXStorageType stype) + : data_ptr(data_ptr), + shape(std::move(shape)), + dtype(dtype), + verID(vID), + ctx(std::move(mx_ctx)), + stype(stype) { setDLTensor(); } -void mxnet::ext::MXTensor::setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims, - size_t vID, MXContext mx_ctx, MXStorageType storage_type) { - data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = storage_type; +void mxnet::ext::MXTensor::setTensor(void* dptr, + MXDType type, + const int64_t* dims, + int ndims, + size_t vID, + MXContext mx_ctx, + MXStorageType storage_type) { + data_ptr = dptr; + dtype = type; + verID = vID; + ctx = mx_ctx; + stype = storage_type; shape.clear(); for (int j = 0; j < ndims; j++) { shape.push_back(dims[j]); @@ -112,12 +146,12 @@ void mxnet::ext::MXTensor::setTensor(void *dptr, MXDType type, const int64_t* di } void mxnet::ext::MXTensor::setDLTensor() { - dltensor.data = data_ptr; - dltensor.ndim = shape.size(); - dltensor.shape = const_cast(shape.data()); - dltensor.strides = nullptr; - dltensor.byte_offset = 0; - dltensor.dtype.lanes = 1; + dltensor.data = data_ptr; + dltensor.ndim = shape.size(); + dltensor.shape = const_cast(shape.data()); + dltensor.strides = nullptr; + dltensor.byte_offset = 0; + dltensor.dtype.lanes = 1; dltensor.ctx.device_id = ctx.dev_id; if (ctx.dev_type == "cpu") dltensor.ctx.device_type = kDLCPU; @@ -136,72 +170,76 @@ void mxnet::ext::MXTensor::setDLTensor() { else dltensor.ctx.device_type = kDLExtDev; switch (dtype) { - case kFloat32: - dltensor.dtype.code = kDLFloat; - dltensor.dtype.bits = 32; - break; - case kFloat64: - dltensor.dtype.code = kDLFloat; - dltensor.dtype.bits = 64; - break; - case kFloat16: - dltensor.dtype.code = kDLFloat; - dltensor.dtype.bits = 16; - break; - case kUint8: - dltensor.dtype.code = kDLUInt; - dltensor.dtype.bits = 8; - break; - case kInt32: - dltensor.dtype.code = kDLInt; - dltensor.dtype.bits = 32; - break; - case kInt8: - dltensor.dtype.code = kDLInt; - dltensor.dtype.bits = 8; - break; - case kInt64: - dltensor.dtype.code = kDLInt; - dltensor.dtype.bits = 64; - break; - default: - dltensor.dtype.code = 0; - dltensor.dtype.bits = 0; - throw std::runtime_error("Error! 
Invalid dtype flag: " - + std::to_string(static_cast(dtype)) - + " when constructing MXTensor"); + case kFloat32: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 32; + break; + case kFloat64: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 64; + break; + case kFloat16: + dltensor.dtype.code = kDLFloat; + dltensor.dtype.bits = 16; + break; + case kUint8: + dltensor.dtype.code = kDLUInt; + dltensor.dtype.bits = 8; + break; + case kInt32: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 32; + break; + case kInt8: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 8; + break; + case kInt64: + dltensor.dtype.code = kDLInt; + dltensor.dtype.bits = 64; + break; + default: + dltensor.dtype.code = 0; + dltensor.dtype.bits = 0; + throw std::runtime_error( + "Error! Invalid dtype flag: " + std::to_string(static_cast(dtype)) + + " when constructing MXTensor"); } } int64_t mxnet::ext::MXTensor::size() const { int64_t size = 1; - for (auto &s : shape) + for (auto& s : shape) size *= s; return size; } -bool mxnet::ext::MXTensor::isSame(const MXTensor &oth) const { - return data_ptr == oth.data_ptr && - dtype == oth.dtype && - verID == oth.verID && - ctx.dev_type == oth.ctx.dev_type && - ctx.dev_id == oth.ctx.dev_id && - shape == oth.shape && - stype == oth.stype; +bool mxnet::ext::MXTensor::isSame(const MXTensor& oth) const { + return data_ptr == oth.data_ptr && dtype == oth.dtype && verID == oth.verID && + ctx.dev_type == oth.ctx.dev_type && ctx.dev_id == oth.ctx.dev_id && shape == oth.shape && + stype == oth.stype; } mxnet::ext::PassResource::PassResource(std::unordered_map* new_args, std::unordered_map* new_aux, - nd_malloc_t nd_malloc, const void* nd_alloc) - : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {} + nd_malloc_t nd_malloc, + const void* nd_alloc) + : new_args_(new_args), new_aux_(new_aux), nd_malloc_(nd_malloc), nd_alloc_(nd_alloc) {} mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_arg(const std::string& name, const std::vector& shapes, - const mxnet::ext::MXContext &ctx, + const mxnet::ext::MXContext& ctx, mxnet::ext::MXDType dtype) const { void* data; - nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, - dtype, name.c_str(), 1, &data); + nd_malloc_(nd_alloc_, + shapes.data(), + shapes.size(), + ctx.dev_type.c_str(), + ctx.dev_id, + dtype, + name.c_str(), + 1, + &data); MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); (*new_args_)[name] = tensor; return &(new_args_->at(name)); @@ -209,24 +247,41 @@ mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_arg(const std::string& nam mxnet::ext::MXTensor* mxnet::ext::PassResource::alloc_aux(const std::string& name, const std::vector& shapes, - const mxnet::ext::MXContext &ctx, + const mxnet::ext::MXContext& ctx, mxnet::ext::MXDType dtype) const { void* data; - nd_malloc_(nd_alloc_, shapes.data(), shapes.size(), ctx.dev_type.c_str(), ctx.dev_id, - dtype, name.c_str(), 0, &data); + nd_malloc_(nd_alloc_, + shapes.data(), + shapes.size(), + ctx.dev_type.c_str(), + ctx.dev_id, + dtype, + name.c_str(), + 0, + &data); MXTensor tensor(data, shapes, dtype, 0, ctx, kDefaultStorage); (*new_aux_)[name] = tensor; return &(new_aux_->at(name)); } -mxnet::ext::OpResource::OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp, - xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream, - sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp, - void* rng_cpu_states, void* rng_gpu_states) - : cpu_malloc(cpu_malloc_fp), 
gpu_malloc(gpu_malloc_fp), - cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream), - sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp), - rand_cpu_states(rng_cpu_states), rand_gpu_states(rng_gpu_states) {} +mxnet::ext::OpResource::OpResource(xpu_malloc_t cpu_malloc_fp, + void* cpu_alloc_fp, + xpu_malloc_t gpu_malloc_fp, + void* gpu_alloc_fp, + void* stream, + sparse_malloc_t sparse_malloc_fp, + void* sparse_alloc_fp, + void* rng_cpu_states, + void* rng_gpu_states) + : cpu_malloc(cpu_malloc_fp), + gpu_malloc(gpu_malloc_fp), + cpu_alloc(cpu_alloc_fp), + gpu_alloc(gpu_alloc_fp), + cuda_stream(stream), + sparse_malloc(sparse_malloc_fp), + sparse_alloc(sparse_alloc_fp), + rand_cpu_states(rng_cpu_states), + rand_gpu_states(rng_gpu_states) {} void* mxnet::ext::OpResource::alloc_cpu(int size) const { return cpu_malloc(cpu_alloc, size); @@ -236,10 +291,17 @@ void* mxnet::ext::OpResource::alloc_gpu(int size) const { return gpu_malloc(gpu_alloc, size); } -void mxnet::ext::OpResource::alloc_sparse(mxnet::ext::MXSparse* sparse, int index, - int indices_len, int indptr_len) const { - sparse_malloc(sparse_alloc, index, indices_len, indptr_len, - &(sparse->data), &(sparse->indices), &(sparse->indptr)); +void mxnet::ext::OpResource::alloc_sparse(mxnet::ext::MXSparse* sparse, + int index, + int indices_len, + int indptr_len) const { + sparse_malloc(sparse_alloc, + index, + indices_len, + indptr_len, + &(sparse->data), + &(sparse->indices), + &(sparse->indptr)); } mxnet::ext::mx_cpu_rand_t* mxnet::ext::OpResource::get_cpu_rand_states() const { @@ -249,50 +311,57 @@ mxnet::ext::mx_cpu_rand_t* mxnet::ext::OpResource::get_cpu_rand_states() const { std::string mxnet::ext::getShapeAt(const std::string& shape, unsigned index) { int idx = 1; // start at 1 to skip the first square bracket [ // find the beginning of the output shape for the particular output index - for (unsigned x=0; x < index; x++) - idx = shape.find('[', idx+1); + for (unsigned x = 0; x < index; x++) + idx = shape.find('[', idx + 1); int stop = shape.find(']', idx); // find stop index for this output shape // add this shape to the list - return shape.substr(idx, stop-idx+1); + return shape.substr(idx, stop - idx + 1); } std::string mxnet::ext::getDtypeAt(const std::string& dtype, unsigned index) { // find the beginning of the output dtype for the particular output index int idx = 0; - for (unsigned x=0; x < index; x++) - idx = dtype.find(',', idx+1); - int stop = dtype.find(',', idx+1); // find stop index for this output dtype - if (stop == -1) stop = dtype.find(']', idx+1); - return dtype.substr(idx+1, stop-idx-1); + for (unsigned x = 0; x < index; x++) + idx = dtype.find(',', idx + 1); + int stop = dtype.find(',', idx + 1); // find stop index for this output dtype + if (stop == -1) + stop = dtype.find(']', idx + 1); + return dtype.substr(idx + 1, stop - idx - 1); } mxnet::ext::JsonVal::JsonVal() : type(ERR), num(-1), str("") {} mxnet::ext::JsonVal::JsonVal(mxnet::ext::JsonType t) : type(t), num(-1), str("") {} mxnet::ext::JsonVal::JsonVal(std::string s) : type(STR), num(-1), str(std::move(s)) {} mxnet::ext::JsonVal::JsonVal(int n) : type(NUM), num(n), str(std::to_string(n)) {} -mxnet::ext::JsonVal::JsonVal(JsonType t, int n, std::string s) : type(t), num(n), - str(std::move(s)) {} +mxnet::ext::JsonVal::JsonVal(JsonType t, int n, std::string s) + : type(t), num(n), str(std::move(s)) {} -bool mxnet::ext::JsonVal::operator<(const mxnet::ext::JsonVal &o) const { +bool mxnet::ext::JsonVal::operator<(const 
mxnet::ext::JsonVal& o) const { // for string JSON objects compare the string - if (type == STR) return type == o.type && str < o.str; + if (type == STR) + return type == o.type && str < o.str; // for number JSON objects compare the number - if (type == NUM) return type == o.type && num < o.num; + if (type == NUM) + return type == o.type && num < o.num; // for list JSON objects, compare the size of list, and then each object in the list if (type == LIST) { - if (list.size() != o.list.size()) return false; - for (unsigned int i=0; i< list.size(); i++) + if (list.size() != o.list.size()) + return false; + for (unsigned int i = 0; i < list.size(); i++) if (list[i] < o.list[i]) return false; // if we find an object that doesnt match return - return true; // all objects in lists matched + return true; // all objects in lists matched } // for map JSON objects, compare the size of map, and then each key/value in the maps if (type == MAP) { - if (map.size() != o.map.size()) return false; - for (auto &item : map) { + if (map.size() != o.map.size()) + return false; + for (auto& item : map) { // if one map is missing a key in another return - if (o.map.find(item.first) == o.map.end()) return false; - if (item.second < o.map.at(item.first)) return false; + if (o.map.find(item.first) == o.map.end()) + return false; + if (item.second < o.map.at(item.first)) + return false; } return true; } @@ -302,35 +371,35 @@ bool mxnet::ext::JsonVal::operator<(const mxnet::ext::JsonVal &o) const { std::string mxnet::ext::JsonVal::dump() const { std::string ret; switch (type) { - case ERR: - ret = "json(Error)"; - break; - case STR: - ret = "\"" + str + "\""; - break; - case NUM: - ret = str; - break; - case LIST: - ret = "["; - for (unsigned i=0; i < list.size(); i++) { - auto &item = list[i]; - ret += item.dump(); - if (i < list.size()-1) - ret += ","; - } - ret += "]"; - break; - case MAP: - ret = "{"; - unsigned cnt = 0; - for (auto &item : map) { - ret += item.first.dump() + " : " + item.second.dump(); - if (cnt++ < map.size()-1) - ret += ","; - } - ret += "}"; - break; + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "\"" + str + "\""; + break; + case NUM: + ret = str; + break; + case LIST: + ret = "["; + for (unsigned i = 0; i < list.size(); i++) { + auto& item = list[i]; + ret += item.dump(); + if (i < list.size() - 1) + ret += ","; + } + ret += "]"; + break; + case MAP: + ret = "{"; + unsigned cnt = 0; + for (auto& item : map) { + ret += item.first.dump() + " : " + item.second.dump(); + if (cnt++ < map.size() - 1) + ret += ","; + } + ret += "}"; + break; } return ret; } @@ -343,8 +412,8 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json) { mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_string(const std::string& json, unsigned int* idx) { JsonVal ret(STR); while (*idx < json.size()) { - if (json[*idx] == '"' && (ret.str.size() == 0 || - (ret.str.size() > 0 && ret.str.back() != '\\'))) { + if (json[*idx] == '"' && + (ret.str.size() == 0 || (ret.str.size() > 0 && ret.str.back() != '\\'))) { ++(*idx); return ret; } else { @@ -398,7 +467,7 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_map(const std::string& json, unsi key = item; } else { ret.map[key] = item; - key.type = ERR; + key.type = ERR; } } } @@ -406,7 +475,7 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse_map(const std::string& json, unsi return mxnet::ext::JsonVal(); } -mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned int *idx) { +mxnet::ext::JsonVal 
mxnet::ext::JsonVal::parse(const std::string& json, unsigned int* idx) { JsonVal ret; while (*idx < json.size()) { if (json[*idx] == '"') { @@ -420,8 +489,11 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned } else if (json[*idx] == '{') { ++(*idx); ret = JsonVal::parse_map(json, idx); - } else if (json[*idx] == ']' || json[*idx] == '}') {return ret;} - if (ret.type != ERR) return ret; + } else if (json[*idx] == ']' || json[*idx] == '}') { + return ret; + } + if (ret.type != ERR) + return ret; ++(*idx); } return ret; @@ -430,44 +502,50 @@ mxnet::ext::JsonVal mxnet::ext::JsonVal::parse(const std::string& json, unsigned std::string mxnet::ext::JsonVal::toString() const { std::string ret; switch (type) { - case ERR: - ret = "json(Error)"; - break; - case STR: - ret = "json(STR:" + str + ")"; - break; - case NUM: - ret = "json(INT:" + str + ")"; - break; - case LIST: - ret = "json(LIST:["; - for (auto &item : list) - ret += item.toString() + ","; - ret += "])"; - break; - case MAP: - ret = "json(MAP:{"; - for (auto &item : map) - ret += item.first.toString() + " : " + item.second.toString() + ","; - ret += "})"; - break; + case ERR: + ret = "json(Error)"; + break; + case STR: + ret = "json(STR:" + str + ")"; + break; + case NUM: + ret = "json(INT:" + str + ")"; + break; + case LIST: + ret = "json(LIST:["; + for (auto& item : list) + ret += item.toString() + ","; + ret += "])"; + break; + case MAP: + ret = "json(MAP:{"; + for (auto& item : map) + ret += item.first.toString() + " : " + item.second.toString() + ","; + ret += "})"; + break; } return ret; } -mxnet::ext::Node::Node() {tensor = nullptr;} +mxnet::ext::Node::Node() { + tensor = nullptr; +} -void mxnet::ext::Node::_setPassResource(mxnet::ext::PassResource* res_) {res = res_;} +void mxnet::ext::Node::_setPassResource(mxnet::ext::PassResource* res_) { + res = res_; +} void mxnet::ext::Node::alloc_arg(const std::vector& shapes, - const mxnet::ext::MXContext &ctx, mxnet::ext::MXDType dtype) { + const mxnet::ext::MXContext& ctx, + mxnet::ext::MXDType dtype) { if (!res) throw std::runtime_error("Node not initialized. Cannot use alloc_arg outside of graph passes."); tensor = res->alloc_arg(name, shapes, ctx, dtype); } void mxnet::ext::Node::alloc_aux(const std::vector& shapes, - const mxnet::ext::MXContext &ctx, mxnet::ext::MXDType dtype) { + const mxnet::ext::MXContext& ctx, + mxnet::ext::MXDType dtype) { if (!res) throw std::runtime_error("Node not initialized. 
Cannot use alloc_aux outside of graph passes."); tensor = res->alloc_aux(name, shapes, ctx, dtype); @@ -476,7 +554,7 @@ void mxnet::ext::Node::alloc_aux(const std::vector& shapes, mxnet::ext::Graph::Graph() : res(nullptr) {} mxnet::ext::Graph::~Graph() { - for (auto &node : nodes) + for (auto& node : nodes) delete node; } @@ -488,7 +566,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromString(const std::string& json) { mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { // get nodes list JsonVal nodes = val.map[JsonVal("nodes")]; - Graph *g = new Graph(); + Graph* g = new Graph(); std::map nodeMap; // loop over nodes @@ -498,7 +576,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { JsonVal node = nodes.list[i]; // set the op info - n->op = node.map[JsonVal("op")].str; + n->op = node.map[JsonVal("op")].str; n->name = node.map[JsonVal("name")].str; // if op is null it is an input to the graph @@ -514,7 +592,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { // set subgraphs, parsing each into a graph if (node.map.count(JsonVal("subgraphs")) > 0) { JsonVal subgraphs = node.map[JsonVal("subgraphs")]; - for (auto &subgraph : subgraphs.list) { + for (auto& subgraph : subgraphs.list) { n->subgraphs.push_back(fromJson(subgraph)); } } @@ -523,7 +601,7 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { JsonVal node_inputs = node.map[JsonVal("inputs")]; n->inputs.resize(node_inputs.list.size()); for (int j = 0; j < node_inputs.list.size(); j++) { - JsonVal input = node_inputs.list[j]; + JsonVal input = node_inputs.list[j]; NodeEntry& entry = n->inputs[j]; // get pointer to other node entry.node = nodeMap[input.list[0].num]; @@ -539,17 +617,15 @@ mxnet::ext::Graph* mxnet::ext::Graph::fromJson(mxnet::ext::JsonVal val) { JsonVal& heads = val.map[JsonVal("heads")]; g->outputs.resize(heads.list.size()); for (int i = 0; i < heads.list.size(); i++) { - JsonVal head = heads.list[i]; - g->outputs[i].node = nodeMap[head.list[0].num]; + JsonVal head = heads.list[i]; + g->outputs[i].node = nodeMap[head.list[0].num]; g->outputs[i].entry = head.list[1].num; } // add all attributes to the graph for (auto& kv : val.map) { - if (kv.first.str.compare("nodes") != 0 && - kv.first.str.compare("heads") != 0 && - kv.first.str.compare("node_row_ptr") != 0 && - kv.first.str.compare("arg_nodes") != 0) { + if (kv.first.str.compare("nodes") != 0 && kv.first.str.compare("heads") != 0 && + kv.first.str.compare("node_row_ptr") != 0 && kv.first.str.compare("arg_nodes") != 0) { g->attrs[kv.first.str] = kv.second; } } @@ -571,25 +647,25 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { std::vector sorted = topological_sort(); // nodes are in reverse topological order in the vector (back is first) // so loop from end to front over the vector 'sorted' - for (int i = sorted.size()-1; i >= 0; i--) { - nodeMap[sorted[i]] = sorted.size()-1-i; + for (int i = sorted.size() - 1; i >= 0; i--) { + nodeMap[sorted[i]] = sorted.size() - 1 - i; } // create node_row_ptr entry val.map[JsonVal("node_row_ptr")] = JsonVal(LIST); - JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")]; + JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")]; for (int i = 0; i < nodes.size(); i++) node_row_ptr.list.emplace_back(i); // add all input nodes val.map[JsonVal("arg_nodes")] = JsonVal(LIST); - JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; - for (auto &input : inputs) + JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; + for (auto& input : 
inputs) arg_nodes.list.emplace_back(nodeMap[input]); // add all output nodes val.map[JsonVal("heads")] = JsonVal(LIST); - JsonVal& heads = val.map[JsonVal("heads")]; + JsonVal& heads = val.map[JsonVal("heads")]; for (int i = 0; i < outputs.size(); i++) { heads.list.emplace_back(LIST); JsonVal& out = heads.list[i]; @@ -600,15 +676,15 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { // add all graph nodes val.map[JsonVal("nodes")] = JsonVal(LIST); - JsonVal& nodes_ = val.map[JsonVal("nodes")]; - for (int i = sorted.size()-1; i >= 0; i--) { + JsonVal& nodes_ = val.map[JsonVal("nodes")]; + for (int i = sorted.size() - 1; i >= 0; i--) { // each node is a map nodes_.list.emplace_back(MAP); - Node* n = sorted[i]; - JsonVal& n_ = nodes_.list[nodes_.list.size()-1]; + Node* n = sorted[i]; + JsonVal& n_ = nodes_.list[nodes_.list.size() - 1]; - n_.map[JsonVal("op")] = JsonVal(n->op); - n_.map[JsonVal("name")] = JsonVal(n->name); + n_.map[JsonVal("op")] = JsonVal(n->op); + n_.map[JsonVal("name")] = JsonVal(n->name); n_.map[JsonVal("inputs")] = JsonVal(LIST); // add inputs for this node @@ -616,7 +692,7 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { for (int j = 0; j < n->inputs.size(); j++) { inputs_.list.emplace_back(LIST); NodeEntry& entry = n->inputs[j]; - JsonVal& in = inputs_.list[j]; + JsonVal& in = inputs_.list[j]; in.list.emplace_back(nodeMap[entry.node]); in.list.emplace_back(entry.entry); in.list.emplace_back(0); @@ -625,15 +701,15 @@ mxnet::ext::JsonVal mxnet::ext::Graph::toJson() const { // add subgraphs for this node, convert each back to JSON if (n->subgraphs.size() > 0) { n_.map[JsonVal("subgraphs")] = JsonVal(LIST); - JsonVal &subgraphs_ = n_.map[JsonVal("subgraphs")]; - for (Graph *subgraph : n->subgraphs) { + JsonVal& subgraphs_ = n_.map[JsonVal("subgraphs")]; + for (Graph* subgraph : n->subgraphs) { subgraphs_.list.push_back(subgraph->toJson()); } } // add attributes for this node n_.map[JsonVal("attrs")] = JsonVal(MAP); - JsonVal& attrs_ = n_.map[JsonVal("attrs")]; + JsonVal& attrs_ = n_.map[JsonVal("attrs")]; for (auto& kv : n->attrs) { attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second); } @@ -646,8 +722,9 @@ std::string mxnet::ext::Graph::toString() const { return toJson().dump(); } - /* \brief visits a node "n" */ -void mxnet::ext::Graph::_dfs_util(Node* n, std::unordered_set* to_visit, +/* \brief visits a node "n" */ +void mxnet::ext::Graph::_dfs_util(Node* n, + std::unordered_set* to_visit, std::function handler) const { to_visit->erase(n); // remove node now that we're visiting it for (NodeEntry& e : n->outputs) { @@ -687,31 +764,30 @@ std::vector mxnet::ext::Graph::topological_sort() const { /* \brief print out graph details */ void mxnet::ext::Graph::print(int indent) const { std::string space = ""; - for (int i = 0; i < indent; i++) space+=" "; + for (int i = 0; i < indent; i++) + space += " "; std::cout << space << "########### Graph #############" << std::endl; std::cout << space << "attributes: " << std::endl; - for (auto &kv : attrs) + for (auto& kv : attrs) std::cout << space << "\t" << kv.first << " : " << kv.second.str << std::endl; std::cout << space << "inputs: " << inputs.size() << std::endl; std::cout << space << "outputs: " << outputs.size() << std::endl; std::cout << space << "nodes: " << nodes.size() << std::endl; std::vector sorted = topological_sort(); // loop over each node and print out its inputs/outputs - for (int i = sorted.size()-1; i >= 0; i--) { + for (int i = sorted.size() - 1; i >= 0; i--) { std::cout << space << "Node: " << 
sorted[i]->name << std::endl; - for (auto &input : sorted[i]->inputs) { - std::cout << space << "\tInput: " << input.node->name << " " - << input.entry << std::endl; + for (auto& input : sorted[i]->inputs) { + std::cout << space << "\tInput: " << input.node->name << " " << input.entry << std::endl; } - for (auto &output : sorted[i]->outputs) { - std::cout << space << "\tOutput: " << output.node->name << " " - << output.entry << std::endl; + for (auto& output : sorted[i]->outputs) { + std::cout << space << "\tOutput: " << output.node->name << " " << output.entry << std::endl; } if (sorted[i]->subgraphs.size() > 0) { - for (auto &subgraph : sorted[i]->subgraphs) { + for (auto& subgraph : sorted[i]->subgraphs) { std::cout << space << "\tSubgraph:" << std::endl; - subgraph->print(indent+2); + subgraph->print(indent + 2); } } } @@ -723,7 +799,7 @@ mxnet::ext::Node* mxnet::ext::Graph::addNode(const std::string& name, const std: Node* n = new Node(); nodes.push_back(n); n->name = name; - n->op = op; + n->op = op; if (res) n->_setPassResource(res); return n; @@ -775,8 +851,13 @@ void mxnet::ext::Graph::_setParams(std::unordered_map 0) @@ -849,8 +930,8 @@ void mxnet::ext::CustomOp::mapToVector() { void mxnet::ext::CustomOp::raiseDuplicateContextError() { std::string op_name_str(name); throw std::runtime_error( - "Error! Error! Cannot register multiple functions under same context for operator '" - + op_name_str + "'"); + "Error! Error! Cannot register multiple functions under same context for operator '" + + op_name_str + "'"); } mxnet::ext::CustomStatefulOp::CustomStatefulOp() : ignore_warn(false), created(false) {} @@ -861,16 +942,14 @@ mxnet::ext::CustomStatefulOpWrapper::~CustomStatefulOpWrapper() { } mxnet::ext::CustomPass::CustomPass() : name("ERROR") {} -mxnet::ext::CustomPass::CustomPass(const char* pass_name) - : name(pass_name) {} +mxnet::ext::CustomPass::CustomPass(const char* pass_name) : name(pass_name) {} mxnet::ext::CustomPass& mxnet::ext::CustomPass::setBody(graphPass_t fn) { pass = fn; return *this; } mxnet::ext::CustomPartitioner::CustomPartitioner() : name("ERROR") {} -mxnet::ext::CustomPartitioner::CustomPartitioner(const char* backend_name) : - name(backend_name) {} +mxnet::ext::CustomPartitioner::CustomPartitioner(const char* backend_name) : name(backend_name) {} mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::addStrategy(const char* prop_name, const char* sg_name) { @@ -879,20 +958,23 @@ mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::addStrategy(const return *this; } -mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setSupportedOps(const char* prop_name, - mxnet::ext::supportedOps_t fn) { +mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setSupportedOps( + const char* prop_name, + mxnet::ext::supportedOps_t fn) { supported_map[std::string(prop_name)] = fn; return *this; } mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setCreateSelector( - const char* prop_name, mxnet::ext::createSelector_t fn) { + const char* prop_name, + mxnet::ext::createSelector_t fn) { selector_map[std::string(prop_name)] = fn; return *this; } mxnet::ext::CustomPartitioner& mxnet::ext::CustomPartitioner::setReviewSubgraph( - const char* prop_name, mxnet::ext::reviewSubgraph_t fn) { + const char* prop_name, + mxnet::ext::reviewSubgraph_t fn) { review_map[std::string(prop_name)] = fn; return *this; } @@ -932,31 +1014,40 @@ MX_INT_RET _opRegSize() { } /*! 
\brief returns operator registration at specified index */ -MX_VOID_RET _opRegGet(int idx, const char** name, int *isSGop, - const char*** forward_ctx, mxnet::ext::fcomp_t** forward_fp, - int* forward_count, const char*** backward_ctx, - mxnet::ext::fcomp_t** backward_fp, int* backward_count, - const char*** create_op_ctx, mxnet::ext::createOpState_t** create_op_fp, - int* create_op_count, mxnet::ext::parseAttrs_t* parse, - mxnet::ext::inferType_t* type, mxnet::ext::inferSType_t* stype, - mxnet::ext::inferShape_t* shape, mxnet::ext::mutateInputs_t* mutate) { - mxnet::ext::CustomOp &op = mxnet::ext::Registry::get()->get(idx); - *name = op.name; - *parse = op.parse_attrs; - *type = op.infer_type; - *stype = op.infer_storage_type; - *shape = op.infer_shape; - *mutate = op.mutate_inputs; - *isSGop = op.isSGop; +MX_VOID_RET _opRegGet(int idx, + const char** name, + int* isSGop, + const char*** forward_ctx, + mxnet::ext::fcomp_t** forward_fp, + int* forward_count, + const char*** backward_ctx, + mxnet::ext::fcomp_t** backward_fp, + int* backward_count, + const char*** create_op_ctx, + mxnet::ext::createOpState_t** create_op_fp, + int* create_op_count, + mxnet::ext::parseAttrs_t* parse, + mxnet::ext::inferType_t* type, + mxnet::ext::inferSType_t* stype, + mxnet::ext::inferShape_t* shape, + mxnet::ext::mutateInputs_t* mutate) { + mxnet::ext::CustomOp& op = mxnet::ext::Registry::get()->get(idx); + *name = op.name; + *parse = op.parse_attrs; + *type = op.infer_type; + *stype = op.infer_storage_type; + *shape = op.infer_shape; + *mutate = op.mutate_inputs; + *isSGop = op.isSGop; op.mapToVector(); - *forward_ctx = op.forward_ctx_cstr.data(); - *forward_fp = op.forward_fp.data(); - *forward_count = op.forward_fp.size(); - *backward_ctx = op.backward_ctx_cstr.data(); - *backward_fp = op.backward_fp.data(); - *backward_count = op.backward_fp.size(); - *create_op_ctx = op.create_op_ctx_cstr.data(); - *create_op_fp = op.create_op_fp.data(); + *forward_ctx = op.forward_ctx_cstr.data(); + *forward_fp = op.forward_fp.data(); + *forward_count = op.forward_fp.size(); + *backward_ctx = op.backward_ctx_cstr.data(); + *backward_fp = op.backward_fp.data(); + *backward_count = op.backward_fp.size(); + *create_op_ctx = op.create_op_ctx_cstr.data(); + *create_op_fp = op.create_op_fp.data(); *create_op_count = op.create_op_fp.size(); } @@ -966,9 +1057,12 @@ MX_VOID_RET _opCallFree(void* ptr) { } /*! \brief returns status of calling parse attributes function for operator from library */ -MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* const* keys, - const char* const* vals, int num, - int* num_in, int* num_out) { +MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, + const char* const* keys, + const char* const* vals, + int num, + int* num_in, + int* num_out) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -978,11 +1072,18 @@ MX_INT_RET _opCallParseAttrs(mxnet::ext::parseAttrs_t parseAttrs, const char* co } /*! 
\brief returns status of calling inferShape function for operator from library */ -MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* const* keys, - const char* const* vals, int num, - unsigned int** inshapes, int* indims, int num_in, - unsigned int*** mod_inshapes, int** mod_indims, - unsigned int*** outshapes, int** outdims, int num_out) { +MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, + const char* const* keys, + const char* const* vals, + int num, + unsigned int** inshapes, + int* indims, + int num_in, + unsigned int*** mod_inshapes, + int** mod_indims, + unsigned int*** outshapes, + int** outdims, + int num_out) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1001,29 +1102,30 @@ MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* co std::vector > out_shapes(num_out); int retval = inferShape(attrs, &in_shapes, &out_shapes); - if (!retval) return retval; + if (!retval) + return retval; // allocate space for modified input dims, shape - *mod_indims = static_cast(malloc (num_in * sizeof(int))); - *mod_inshapes = static_cast(malloc (num_in * sizeof(unsigned*))); + *mod_indims = static_cast(malloc(num_in * sizeof(int))); + *mod_inshapes = static_cast(malloc(num_in * sizeof(unsigned*))); // copy modified input shapes for (int i = 0; i < num_in; i++) { - (*mod_indims)[i] = in_shapes[i].size(); - (*mod_inshapes)[i] = static_cast(malloc ((*mod_indims)[i] * sizeof(unsigned))); + (*mod_indims)[i] = in_shapes[i].size(); + (*mod_inshapes)[i] = static_cast(malloc((*mod_indims)[i] * sizeof(unsigned))); for (int j = 0; j < (*mod_indims)[i]; j++) { (*mod_inshapes)[i][j] = in_shapes[i][j]; } } // allocate space for output dims, shape - *outdims = static_cast(malloc (num_out * sizeof(int))); - *outshapes = static_cast(malloc (num_out * sizeof(unsigned*))); + *outdims = static_cast(malloc(num_out * sizeof(int))); + *outshapes = static_cast(malloc(num_out * sizeof(unsigned*))); // copy output shapes for (int i = 0; i < num_out; i++) { - (*outdims)[i] = out_shapes[i].size(); - (*outshapes)[i] = static_cast(malloc ((*outdims)[i] * sizeof(unsigned))); + (*outdims)[i] = out_shapes[i].size(); + (*outshapes)[i] = static_cast(malloc((*outdims)[i] * sizeof(unsigned))); for (int j = 0; j < (*outdims)[i]; j++) { (*outshapes)[i][j] = out_shapes[i][j]; } @@ -1032,9 +1134,14 @@ MX_INT_RET _opCallInferShape(mxnet::ext::inferShape_t inferShape, const char* co } /*! \brief returns status of calling inferType function for operator from library */ -MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const* keys, - const char* const* vals, int num, - int* intypes, int num_in, int* outtypes, int num_out) { +MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, + const char* const* keys, + const char* const* vals, + int num, + int* intypes, + int num_in, + int* outtypes, + int num_out) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1067,9 +1174,14 @@ MX_INT_RET _opCallInferType(mxnet::ext::inferType_t inferType, const char* const } /*! 
\brief returns status of calling inferSType function for operator from library */ -MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* const* keys, - const char* const* vals, int num, - int* instypes, int num_in, int* outstypes, int num_out) { +MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, + const char* const* keys, + const char* const* vals, + int num, + int* instypes, + int num_in, + int* outstypes, + int num_out) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1103,22 +1215,45 @@ MX_INT_RET _opCallInferSType(mxnet::ext::inferSType_t inferSType, const char* co } /*! \brief returns status of calling Forward/Backward function for operator from library */ -MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys, +MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, + const char* const* keys, const char* const* vals, - int num, const int64_t** inshapes, int* indims, void** indata, - int* intypes, size_t* inIDs, const char** indev_type, int* indev_id, - int num_in, const int64_t** outshapes, int* outdims, void** outdata, - int* outtypes, size_t* outIDs, const char** outdev_type, - int* outdev_id, int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, + int num, + const int64_t** inshapes, + int* indims, + void** indata, + int* intypes, + size_t* inIDs, + const char** indev_type, + int* indev_id, + int num_in, + const int64_t** outshapes, + int* outdims, + void** outdata, + int* outtypes, + size_t* outIDs, + const char** outdev_type, + int* outdev_id, + int num_out, + mxnet::ext::xpu_malloc_t cpu_malloc, void* cpu_alloc, - mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc, + mxnet::ext::xpu_malloc_t gpu_malloc, + void* gpu_alloc, void* cuda_stream, - mxnet::ext::sparse_malloc_t sparse_malloc, void* sparse_alloc, - int* instypes, int* outstypes, void** in_indices, void** out_indices, - void** in_indptr, void** out_indptr, - int64_t* in_indices_shapes, int64_t* out_indices_shapes, - int64_t* in_indptr_shapes, int64_t* out_indptr_shapes, - void* rng_cpu_states, void* rng_gpu_states) { + mxnet::ext::sparse_malloc_t sparse_malloc, + void* sparse_alloc, + int* instypes, + int* outstypes, + void** in_indices, + void** out_indices, + void** in_indptr, + void** out_indptr, + int64_t* in_indices_shapes, + int64_t* out_indices_shapes, + int64_t* in_indptr_shapes, + int64_t* out_indptr_shapes, + void* rng_cpu_states, + void* rng_gpu_states) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1133,8 +1268,12 @@ MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys, for (int i = 0; i < num_in; i++) { // Dense representation. if (instypes[i] == 0) { - inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]), + inputs[i].setTensor(indata[i], + (mxnet::ext::MXDType)intypes[i], + inshapes[i], + indims[i], + inIDs[i], + mxnet::ext::MXContext(indev_type[i], indev_id[i]), mxnet::ext::kDefaultStorage); } else { // Sparse representation. 
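// --- Editor's note (illustrative sketch, not part of the patch) ---------------
// Storage-type codes used by this glue layer: 0 = dense, 1 = row-sparse,
// anything else = CSR. Condensed, and slightly simplified (the explicit void*
// cast on the sparse handle is dropped), the per-input wrapping performed in
// the surrounding hunks amounts to:
//
//   if (instypes[i] == 0) {                                  // dense tensor
//     inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i],
//                         inIDs[i], MXContext(indev_type[i], indev_id[i]),
//                         kDefaultStorage);
//   } else if (instypes[i] == 1) {                           // row-sparse
//     in_sparse[i].set(indata[i], inshapes[i], indims[i],
//                      in_indices[i], in_indices_shapes[i]);
//     inputs[i].setTensor(&in_sparse[i], (MXDType)intypes[i], inshapes[i],
//                         indims[i], inIDs[i],
//                         MXContext(indev_type[i], indev_id[i]), kRowSparseStorage);
//   } else {                                                 // CSR
//     in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
//                      in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]);
//     inputs[i].setTensor(&in_sparse[i], (MXDType)intypes[i], inshapes[i],
//                         indims[i], inIDs[i],
//                         MXContext(indev_type[i], indev_id[i]), kCSRStorage);
//   }
// ------------------------------------------------------------------------------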
@@ -1144,12 +1283,21 @@ MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys, in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]); } else { type = mxnet::ext::kCSRStorage; - in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], - in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); + in_sparse[i].set(indata[i], + inshapes[i], + indims[i], + in_indices[i], + in_indices_shapes[i], + in_indptr[i], + in_indptr_shapes[i]); } - inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i], - inshapes[i], indims[i], inIDs[i], - mxnet::ext::MXContext(indev_type[i], indev_id[i]), type); + inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), + (mxnet::ext::MXDType)intypes[i], + inshapes[i], + indims[i], + inIDs[i], + mxnet::ext::MXContext(indev_type[i], indev_id[i]), + type); } } @@ -1160,38 +1308,59 @@ MX_INT_RET _opCallFCompute(mxnet::ext::fcomp_t fcomp, const char* const* keys, for (int i = 0; i < num_out; i++) { // Dense representation. if (outstypes[i] == 0) { - outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + outputs[i].setTensor(outdata[i], + (mxnet::ext::MXDType)outtypes[i], + outshapes[i], + outdims[i], + outIDs[i], + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), mxnet::ext::kDefaultStorage); } else { // Sparse representation. mxnet::ext::MXStorageType type; if (outstypes[i] == 1) { type = mxnet::ext::kRowSparseStorage; - out_sparse[i].set(outdata[i], outshapes[i], outdims[i], - out_indices[i], out_indices_shapes[i]); + out_sparse[i].set( + outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]); } else { type = mxnet::ext::kCSRStorage; - out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], - out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); + out_sparse[i].set(outdata[i], + outshapes[i], + outdims[i], + out_indices[i], + out_indices_shapes[i], + out_indptr[i], + out_indptr_shapes[i]); } outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (mxnet::ext::MXDType)outtypes[i], - outshapes[i], outdims[i], outIDs[i], - mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type); + outshapes[i], + outdims[i], + outIDs[i], + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + type); } } - mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, - cuda_stream, sparse_malloc, sparse_alloc, - rng_cpu_states, rng_gpu_states); + mxnet::ext::OpResource res(cpu_malloc, + cpu_alloc, + gpu_malloc, + gpu_alloc, + cuda_stream, + sparse_malloc, + sparse_alloc, + rng_cpu_states, + rng_gpu_states); return fcomp(attrs, &inputs, &outputs, res); } /*! 
\brief returns status of calling mutateInputs function for operator from library */ -MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* const* keys, - const char* const* vals, int num, - int** mutate_indices, int* indices_size) { +MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, + const char* const* keys, + const char* const* vals, + int num, + int** mutate_indices, + int* indices_size) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1206,8 +1375,8 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co return retval; // output the input indices - *indices_size = mut_ind.size(); - *mutate_indices = static_cast(malloc (*indices_size * sizeof(int))); + *indices_size = mut_ind.size(); + *mutate_indices = static_cast(malloc(*indices_size * sizeof(int))); for (int i = 0; i < *indices_size; i++) { (*mutate_indices)[i] = mut_ind[i]; } @@ -1216,10 +1385,17 @@ MX_INT_RET _opCallMutateInputs(mxnet::ext::mutateInputs_t mutate, const char* co } /*! \brief returns status of calling createStatefulOp function for operator from library */ -MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const char* const* keys, - const char* const* vals, int num, const char* dev_type, - int dev_id, unsigned int** inshapes, int* indims, - int num_in, const int* intypes, void** state_op) { +MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, + const char* const* keys, + const char* const* vals, + int num, + const char* dev_type, + int dev_id, + unsigned int** inshapes, + int* indims, + int num_in, + const int* intypes, + void** state_op) { // create map of attributes from list std::unordered_map attrs; for (int i = 0; i < num; i++) { @@ -1245,34 +1421,54 @@ MX_INT_RET _opCallCreateOpState(mxnet::ext::createOpState_t create_op, const cha // void pointer to hold custom state op instance created in custom library // eventually state_op pointer is populated by instance from custom library mxnet::ext::CustomStatefulOp** op_ptr = - reinterpret_cast(state_op); + reinterpret_cast(state_op); return create_op(attrs, ctx, in_shapes, in_types, op_ptr); } /*! \brief calls StatefulOp destructor for operator from library */ MX_VOID_RET _opCallDestroyOpState(void* state_op) { - mxnet::ext::CustomStatefulOp* op_ptr = - reinterpret_cast(state_op); + mxnet::ext::CustomStatefulOp* op_ptr = reinterpret_cast(state_op); delete op_ptr; } /*! 
\brief returns status of calling Stateful Forward/Backward for operator from library */ -MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t** inshapes, - int* indims, void** indata, int* intypes, size_t* inIDs, - const char** indev_type, int* indev_id, int num_in, - const int64_t** outshapes, int* outdims, void** outdata, - int* outtypes, size_t* outIDs, const char** outdev_type, - int* outdev_id, int num_out, +MX_INT_RET _opCallFStatefulCompute(int is_forward, + void* state_op, + const int64_t** inshapes, + int* indims, + void** indata, + int* intypes, + size_t* inIDs, + const char** indev_type, + int* indev_id, + int num_in, + const int64_t** outshapes, + int* outdims, + void** outdata, + int* outtypes, + size_t* outIDs, + const char** outdev_type, + int* outdev_id, + int num_out, mxnet::ext::xpu_malloc_t cpu_malloc, - void* cpu_alloc, mxnet::ext::xpu_malloc_t gpu_malloc, + void* cpu_alloc, + mxnet::ext::xpu_malloc_t gpu_malloc, void* gpu_alloc, - void* stream, mxnet::ext::sparse_malloc_t sparse_malloc, - void* sparse_alloc, int* instypes, int* outstypes, - void** in_indices, void** out_indices, void** in_indptr, - void** out_indptr, int64_t* in_indices_shapes, - int64_t* out_indices_shapes, int64_t* in_indptr_shapes, + void* stream, + mxnet::ext::sparse_malloc_t sparse_malloc, + void* sparse_alloc, + int* instypes, + int* outstypes, + void** in_indices, + void** out_indices, + void** in_indptr, + void** out_indptr, + int64_t* in_indices_shapes, + int64_t* out_indices_shapes, + int64_t* in_indptr_shapes, int64_t* out_indptr_shapes, - void* rng_cpu_states, void* rng_gpu_states) { + void* rng_cpu_states, + void* rng_gpu_states) { // create a vector of tensors for inputs std::vector inputs(num_in); // create a vector for sparse inputs @@ -1281,8 +1477,12 @@ MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t for (int i = 0; i < num_in; i++) { if (instypes[i] == 0) { // Dense representation. - inputs[i].setTensor(indata[i], (mxnet::ext::MXDType)intypes[i], inshapes[i], indims[i], - inIDs[i], mxnet::ext::MXContext(indev_type[i], indev_id[i]), + inputs[i].setTensor(indata[i], + (mxnet::ext::MXDType)intypes[i], + inshapes[i], + indims[i], + inIDs[i], + mxnet::ext::MXContext(indev_type[i], indev_id[i]), mxnet::ext::kDefaultStorage); } else { // Sparse representation. @@ -1292,12 +1492,21 @@ MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], in_indices_shapes[i]); } else { type = mxnet::ext::kCSRStorage; - in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], - in_indices_shapes[i], in_indptr[i], in_indptr_shapes[i]); + in_sparse[i].set(indata[i], + inshapes[i], + indims[i], + in_indices[i], + in_indices_shapes[i], + in_indptr[i], + in_indptr_shapes[i]); } - inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), (mxnet::ext::MXDType)intypes[i], - inshapes[i], indims[i], inIDs[i], - mxnet::ext::MXContext(indev_type[i], indev_id[i]), type); + inputs[i].setTensor(reinterpret_cast(&in_sparse[i]), + (mxnet::ext::MXDType)intypes[i], + inshapes[i], + indims[i], + inIDs[i], + mxnet::ext::MXContext(indev_type[i], indev_id[i]), + type); } } @@ -1309,33 +1518,51 @@ MX_INT_RET _opCallFStatefulCompute(int is_forward, void* state_op, const int64_t for (int i = 0; i < num_out; i++) { if (outstypes[i] == 0) { // Dense representation. 
- outputs[i].setTensor(outdata[i], (mxnet::ext::MXDType)outtypes[i], outshapes[i], outdims[i], - outIDs[i], mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + outputs[i].setTensor(outdata[i], + (mxnet::ext::MXDType)outtypes[i], + outshapes[i], + outdims[i], + outIDs[i], + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), mxnet::ext::kDefaultStorage); } else { // Sparse representation. mxnet::ext::MXStorageType type; if (outstypes[i] == 1) { type = mxnet::ext::kRowSparseStorage; - out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], - out_indices_shapes[i]); + out_sparse[i].set( + outdata[i], outshapes[i], outdims[i], out_indices[i], out_indices_shapes[i]); } else { type = mxnet::ext::kCSRStorage; - out_sparse[i].set(outdata[i], outshapes[i], outdims[i], out_indices[i], - out_indices_shapes[i], out_indptr[i], out_indptr_shapes[i]); + out_sparse[i].set(outdata[i], + outshapes[i], + outdims[i], + out_indices[i], + out_indices_shapes[i], + out_indptr[i], + out_indptr_shapes[i]); } outputs[i].setTensor(reinterpret_cast(&out_sparse[i]), (mxnet::ext::MXDType)outtypes[i], - outshapes[i], outdims[i], outIDs[i], - mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), type); + outshapes[i], + outdims[i], + outIDs[i], + mxnet::ext::MXContext(outdev_type[i], outdev_id[i]), + type); } } - mxnet::ext::OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, - stream, sparse_malloc, sparse_alloc, rng_cpu_states, rng_gpu_states); - - mxnet::ext::CustomStatefulOp* op_ptr = - reinterpret_cast(state_op); + mxnet::ext::OpResource res(cpu_malloc, + cpu_alloc, + gpu_malloc, + gpu_alloc, + stream, + sparse_malloc, + sparse_alloc, + rng_cpu_states, + rng_gpu_states); + + mxnet::ext::CustomStatefulOp* op_ptr = reinterpret_cast(state_op); if (is_forward) { return op_ptr->Forward(&inputs, &outputs, res); } @@ -1351,30 +1578,37 @@ MX_INT_RET _partRegSize() { * at specified index */ MX_INT_RET _partRegGetCount(int idx, const char** name) { mxnet::ext::CustomPartitioner part = - mxnet::ext::Registry::get()->get(idx); + mxnet::ext::Registry::get()->get(idx); *name = part.name; return part.strategies.size(); } /*! \brief returns partitioner registration at specified index */ -MX_VOID_RET _partRegGet(int part_idx, int stg_idx, const char** strategy, +MX_VOID_RET _partRegGet(int part_idx, + int stg_idx, + const char** strategy, mxnet::ext::supportedOps_t* supportedOps, mxnet::ext::createSelector_t* createSelector, - mxnet::ext::reviewSubgraph_t* reviewSubgraph, const char** op_name) { + mxnet::ext::reviewSubgraph_t* reviewSubgraph, + const char** op_name) { mxnet::ext::CustomPartitioner part = - mxnet::ext::Registry::get()->get(part_idx); - *strategy = part.strategies[stg_idx]; - *op_name = part.op_names[stg_idx]; - *supportedOps = part.getSupportedOps(stg_idx); + mxnet::ext::Registry::get()->get(part_idx); + *strategy = part.strategies[stg_idx]; + *op_name = part.op_names[stg_idx]; + *supportedOps = part.getSupportedOps(stg_idx); *createSelector = part.getCreateSelector(stg_idx); *reviewSubgraph = part.getReviewSubgraph(stg_idx); } /*! 
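// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): _opCallFStatefulCompute above
// packs raw C-ABI pointers into mxnet::ext::MXTensor inputs/outputs plus an
// OpResource, then invokes the library's stateful op via
// op_ptr->Forward(&inputs, &outputs, res). A minimal CustomStatefulOp
// consistent with that call; the accessor names (data<T>(), size()) follow
// the extension API as used here and should be treated as a sketch.
class MyStatefulRelu : public mxnet::ext::CustomStatefulOp {
 public:
  mxnet::ext::MXReturnValue Forward(std::vector<mxnet::ext::MXTensor>* inputs,
                                    std::vector<mxnet::ext::MXTensor>* outputs,
                                    const mxnet::ext::OpResource& res) override {
    // Assumes a single dense float32 input and an output of the same size.
    const float* in = inputs->at(0).data<float>();
    float* out      = outputs->at(0).data<float>();
    for (int64_t i = 0; i < inputs->at(0).size(); ++i)
      out[i] = in[i] > 0.f ? in[i] : 0.f;
    return MX_SUCCESS;
  }
};
// ---------------------------------------------------------------------------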
\brief returns status of calling supported ops function from library */ -MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const char *json, - int num_ids, int *ids, const char* const* opt_keys, - const char* const* opt_vals, int num_opts) { - mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); +MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, + const char* json, + int num_ids, + int* ids, + const char* const* opt_keys, + const char* const* opt_vals, + int num_opts) { + mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json); // create map of options from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1384,7 +1618,8 @@ MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const std::vector _ids(num_ids, -2); // call user's supportedOps function mxnet::ext::MXReturnValue retval = supportedOps(graph, &_ids, opts); - if (!retval) return retval; + if (!retval) + return retval; // copy bools in ids to ints for (int i = 0; i < num_ids; i++) @@ -1394,10 +1629,13 @@ MX_INT_RET _partCallSupportedOps(mxnet::ext::supportedOps_t supportedOps, const } /*! \brief returns status of calling create selector function from library */ -MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, const char *json, - void** selector, const char* const* opt_keys, - const char* const* opt_vals, int num_opts) { - mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); +MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, + const char* json, + void** selector, + const char* const* opt_keys, + const char* const* opt_vals, + int num_opts) { + mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json); // create map of options from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1406,7 +1644,7 @@ MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, // void pointer to hold selector instance created in custom library // eventually pointer is populated by instance from custom library mxnet::ext::CustomOpSelector** sel_ptr = - reinterpret_cast(selector); + reinterpret_cast(selector); // call user's createSelector function return createSelector(graph, sel_ptr, opts); @@ -1414,34 +1652,31 @@ MX_INT_RET _partCallCreateSelector(mxnet::ext::createSelector_t createSelector, /*! \brief returns status of calling select function from library */ MX_VOID_RET _partCallSelect(void* sel_inst, int nodeID, int* selected) { - mxnet::ext::CustomOpSelector* sel_ptr = - reinterpret_cast(sel_inst); - *selected = sel_ptr->Select(nodeID); + mxnet::ext::CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->Select(nodeID); } /*! \brief returns status of calling select input function from library */ -MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, - int input_nodeID, int* selected) { - mxnet::ext::CustomOpSelector* sel_ptr = - reinterpret_cast(sel_inst); - *selected = sel_ptr->SelectInput(nodeID, input_nodeID); +MX_VOID_RET _partCallSelectInput(void* sel_inst, int nodeID, int input_nodeID, int* selected) { + mxnet::ext::CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->SelectInput(nodeID, input_nodeID); } /*! 
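// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): _partCallSupportedOps above
// deserializes the graph JSON, builds the options map, and calls the
// library's callback as supportedOps(graph, &ids, opts). A minimal callback
// of that shape; the Graph/Node accessors and the meaning of the id values
// are assumptions based on how this glue initializes the vector (entries
// default to -2; a non-negative entry is taken here to mark a supported node).
mxnet::ext::MXReturnValue mySupportedOps(
    mxnet::ext::Graph* graph,
    std::vector<int>* ids,
    const std::unordered_map<std::string, std::string>& options) {
  for (int i = 0; i < graph->size(); ++i) {
    mxnet::ext::Node* node = graph->getNode(i);
    // Claim only Convolution nodes for this (hypothetical) strategy.
    if (node->op == "Convolution")
      ids->at(i) = 1;
  }
  return MX_SUCCESS;
}
// ---------------------------------------------------------------------------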
\brief returns status of calling select output function from library */ -MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, - int output_nodeID, int* selected) { - mxnet::ext::CustomOpSelector* sel_ptr = - reinterpret_cast(sel_inst); - *selected = sel_ptr->SelectOutput(nodeID, output_nodeID); +MX_VOID_RET _partCallSelectOutput(void* sel_inst, int nodeID, int output_nodeID, int* selected) { + mxnet::ext::CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + *selected = sel_ptr->SelectOutput(nodeID, output_nodeID); } /*! \brief returns status of calling filter function from library */ -MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates, - int** keep, int* num_keep) { - mxnet::ext::CustomOpSelector* sel_ptr = - reinterpret_cast(sel_inst); +MX_VOID_RET _partCallFilter(void* sel_inst, + int* candidates, + int num_candidates, + int** keep, + int* num_keep) { + mxnet::ext::CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); std::vector candidates_(num_candidates); - for (int i=0; i < num_candidates; i++) { + for (int i = 0; i < num_candidates; i++) { candidates_[i] = candidates[i]; } std::vector keep_; @@ -1449,35 +1684,48 @@ MX_VOID_RET _partCallFilter(void* sel_inst, int* candidates, int num_candidates, sel_ptr->Filter(candidates_, &keep_); *num_keep = keep_.size(); - *keep = static_cast(malloc(keep_.size() * sizeof(int))); - for (unsigned i=0; i < keep_.size(); i++) + *keep = static_cast(malloc(keep_.size() * sizeof(int))); + for (unsigned i = 0; i < keep_.size(); i++) (*keep)[i] = keep_[i]; } /*! \brief returns status of calling reset selector function from library */ MX_VOID_RET _partCallReset(void* sel_inst) { - mxnet::ext::CustomOpSelector* sel_ptr = - reinterpret_cast(sel_inst); - sel_ptr->Reset(); + mxnet::ext::CustomOpSelector* sel_ptr = reinterpret_cast(sel_inst); + sel_ptr->Reset(); } /*! 
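// ---------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): the _partCallSelect*,
// _partCallFilter, and _partCallReset wrappers above forward straight to a
// CustomOpSelector instance created by createSelector. A minimal selector
// consistent with those calls (Select/SelectInput/SelectOutput report whether
// a node joins the current subgraph, Filter may prune the accepted set,
// Reset clears per-subgraph state); names and return conventions are a sketch.
class MySelector : public mxnet::ext::CustomOpSelector {
 public:
  bool Select(int nodeID) override { return true; }  // take every node
  bool SelectInput(int nodeID, int input_nodeID) override { return true; }
  bool SelectOutput(int nodeID, int output_nodeID) override { return true; }
  void Filter(std::vector<int>& candidates, std::vector<int>* keep) override {
    *keep = candidates;  // keep all selected candidates unchanged
  }
  void Reset() override {}  // nothing to reset between subgraphs
};
// ---------------------------------------------------------------------------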
\brief returns status of calling review subgraph function from library */ -MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, const char *json, - int subgraph_id, int *accept, const char* const* opt_keys, - const char* const* opt_vals, int num_opts, - char*** attr_keys, char*** attr_vals, int *num_attrs, - const char* const* arg_names, int num_args, - void* const* arg_data, const int64_t* const* arg_shapes, - const int* arg_dims, const int* arg_types, - const size_t* arg_IDs, const char* const* arg_dev_type, +MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, + const char* json, + int subgraph_id, + int* accept, + const char* const* opt_keys, + const char* const* opt_vals, + int num_opts, + char*** attr_keys, + char*** attr_vals, + int* num_attrs, + const char* const* arg_names, + int num_args, + void* const* arg_data, + const int64_t* const* arg_shapes, + const int* arg_dims, + const int* arg_types, + const size_t* arg_IDs, + const char* const* arg_dev_type, const int* arg_dev_id, - const char* const* aux_names, int num_aux, - void* const* aux_data, const int64_t* const* aux_shapes, - const int* aux_dims, const int* aux_types, - const size_t* aux_IDs, const char* const* aux_dev_type, + const char* const* aux_names, + int num_aux, + void* const* aux_data, + const int64_t* const* aux_shapes, + const int* aux_dims, + const int* aux_types, + const size_t* aux_IDs, + const char* const* aux_dev_type, const int* aux_dev_id) { - mxnet::ext::Graph *subgraph = mxnet::ext::Graph::fromString(json); - bool accept_bool = false; + mxnet::ext::Graph* subgraph = mxnet::ext::Graph::fromString(json); + bool accept_bool = false; // create map of attributes from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1491,8 +1739,11 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, for (int j = 0; j < arg_dims[i]; j++) shapes.push_back(arg_shapes[i][j]); - mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i], - arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i])); + mxnet::ext::MXTensor tensor(arg_data[i], + shapes, + (mxnet::ext::MXDType)arg_types[i], + arg_IDs[i], + mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i])); args[arg_names[i]] = tensor; } // create a map of named tensors for aux @@ -1503,34 +1754,38 @@ MX_INT_RET _partCallReviewSubgraph(mxnet::ext::reviewSubgraph_t reviewSubgraph, for (int j = 0; j < aux_dims[i]; j++) shapes.push_back(aux_shapes[i][j]); - mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i], - aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i], - aux_dev_id[i])); + mxnet::ext::MXTensor tensor(aux_data[i], + shapes, + (mxnet::ext::MXDType)aux_types[i], + aux_IDs[i], + mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i])); aux[aux_names[i]] = tensor; } subgraph->_setParams(&args, &aux); std::unordered_map attrs; - mxnet::ext::MXReturnValue retval = reviewSubgraph(subgraph, subgraph_id, &accept_bool, - opts, &attrs); - if (!retval) return retval; + mxnet::ext::MXReturnValue retval = + reviewSubgraph(subgraph, subgraph_id, &accept_bool, opts, &attrs); + if (!retval) + return retval; *accept = accept_bool; if (attrs.size() > 0) { *num_attrs = attrs.size(); // allocate space for attributes - *attr_keys = static_cast(malloc (*num_attrs * sizeof(char*))); - *attr_vals = static_cast(malloc (*num_attrs * sizeof(char*))); + *attr_keys = static_cast(malloc(*num_attrs * sizeof(char*))); + *attr_vals = 
static_cast(malloc(*num_attrs * sizeof(char*))); // copy attributes int i = 0; for (auto kv : attrs) { - (*attr_keys)[i] = static_cast(malloc ((kv.first.size()+1) * sizeof(char))); // NOLINT - (*attr_vals)[i] = static_cast(malloc ((kv.second.size()+1) * sizeof(char))); // NOLINT - snprintf((*attr_keys)[i], kv.first.size()+1, "%s", kv.first.c_str()); - snprintf((*attr_vals)[i], kv.second.size()+1, "%s", kv.second.c_str()); + (*attr_keys)[i] = static_cast(malloc((kv.first.size() + 1) * sizeof(char))); // NOLINT + (*attr_vals)[i] = + static_cast(malloc((kv.second.size() + 1) * sizeof(char))); // NOLINT + snprintf((*attr_keys)[i], kv.first.size() + 1, "%s", kv.first.c_str()); + snprintf((*attr_vals)[i], kv.second.size() + 1, "%s", kv.second.c_str()); i++; } } @@ -1544,29 +1799,41 @@ MX_INT_RET _passRegSize() { } /*! \brief returns pass registration at specified index */ -MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, - const char** pass_name) { - mxnet::ext::CustomPass pass = - mxnet::ext::Registry::get()->get(pass_idx); - *graphPass = pass.pass; - *pass_name = pass.name; +MX_VOID_RET _passRegGet(int pass_idx, mxnet::ext::graphPass_t* graphPass, const char** pass_name) { + mxnet::ext::CustomPass pass = mxnet::ext::Registry::get()->get(pass_idx); + *graphPass = pass.pass; + *pass_name = pass.name; } /*! \brief returns status of calling graph pass function from library */ -MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *json, - char** out_graph, const char* const* opt_keys, - const char* const* opt_vals, int num_opts, - const char* pass_name, const char* const* arg_names, int num_args, - void* const* arg_data, const int64_t* const* arg_shapes, - const int* arg_dims, const int* arg_types, - const size_t* arg_IDs, const char* const* arg_dev_type, - const int* arg_dev_id, const char* const* aux_names, int num_aux, - void* const* aux_data, const int64_t* const* aux_shapes, - const int* aux_dims, const int* aux_types, - const size_t* aux_IDs, const char* const* aux_dev_type, - const int* aux_dev_id, mxnet::ext::nd_malloc_t nd_malloc, +MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, + const char* json, + char** out_graph, + const char* const* opt_keys, + const char* const* opt_vals, + int num_opts, + const char* pass_name, + const char* const* arg_names, + int num_args, + void* const* arg_data, + const int64_t* const* arg_shapes, + const int* arg_dims, + const int* arg_types, + const size_t* arg_IDs, + const char* const* arg_dev_type, + const int* arg_dev_id, + const char* const* aux_names, + int num_aux, + void* const* aux_data, + const int64_t* const* aux_shapes, + const int* aux_dims, + const int* aux_types, + const size_t* aux_IDs, + const char* const* aux_dev_type, + const int* aux_dev_id, + mxnet::ext::nd_malloc_t nd_malloc, const void* nd_alloc) { - mxnet::ext::Graph *graph = mxnet::ext::Graph::fromString(json); + mxnet::ext::Graph* graph = mxnet::ext::Graph::fromString(json); // create map of attributes from list std::unordered_map opts; for (int i = 0; i < num_opts; i++) @@ -1580,9 +1847,11 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso for (int j = 0; j < arg_dims[i]; j++) shapes.push_back(arg_shapes[i][j]); - mxnet::ext::MXTensor tensor(arg_data[i], shapes, (mxnet::ext::MXDType)arg_types[i], - arg_IDs[i], mxnet::ext::MXContext(arg_dev_type[i], - arg_dev_id[i])); + mxnet::ext::MXTensor tensor(arg_data[i], + shapes, + (mxnet::ext::MXDType)arg_types[i], + arg_IDs[i], + 
mxnet::ext::MXContext(arg_dev_type[i], arg_dev_id[i])); args[arg_names[i]] = tensor; } // create a map of named tensors for aux @@ -1593,9 +1862,11 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso for (int j = 0; j < aux_dims[i]; j++) shapes.push_back(aux_shapes[i][j]); - mxnet::ext::MXTensor tensor(aux_data[i], shapes, (mxnet::ext::MXDType)aux_types[i], - aux_IDs[i], mxnet::ext::MXContext(aux_dev_type[i], - aux_dev_id[i])); + mxnet::ext::MXTensor tensor(aux_data[i], + shapes, + (mxnet::ext::MXDType)aux_types[i], + aux_IDs[i], + mxnet::ext::MXContext(aux_dev_type[i], aux_dev_id[i])); aux[aux_names[i]] = tensor; } @@ -1604,11 +1875,12 @@ MX_INT_RET _passCallGraphPass(mxnet::ext::graphPass_t graphPass, const char *jso graph->_setParams(&args, &aux); graph->_setPassResource(&res); mxnet::ext::MXReturnValue retval = graphPass(graph, opts); - if (!retval) return retval; + if (!retval) + return retval; std::string tmp = graph->toString(); - *out_graph = static_cast(malloc ((tmp.size()+1) * sizeof(char))); // NOLINT - snprintf((*out_graph), tmp.size()+1, "%s", tmp.c_str()); + *out_graph = static_cast(malloc((tmp.size() + 1) * sizeof(char))); // NOLINT + snprintf((*out_graph), tmp.size() + 1, "%s", tmp.c_str()); return retval; } @@ -1624,7 +1896,7 @@ __declspec(dllexport) mxnet::ext::MXReturnValue __cdecl #else mxnet::ext::MXReturnValue #endif -initialize(int version); + initialize(int version); MX_INT_RET _msgSize() { return mxnet::ext::MXerrorMsgs::get()->size(); diff --git a/src/libinfo.cc b/src/libinfo.cc index 27898adc96d5..292fbe615d09 100644 --- a/src/libinfo.cc +++ b/src/libinfo.cc @@ -33,8 +33,7 @@ namespace features { class FeatureSet { public: - FeatureSet() : - feature_bits() { + FeatureSet() : feature_bits() { // GPU feature_bits.set(CUDA, MXNET_USE_CUDA); feature_bits.set(CUDNN, MXNET_USE_CUDNN); @@ -115,50 +114,50 @@ bool is_enabled(const unsigned feat) { } LibInfo::LibInfo() { - for (size_t i = 0; i < MAX_FEATURES; ++i) { - m_lib_features[i].name = EnumNames::names[i].c_str(); - m_lib_features[i].enabled = is_enabled(i); - } + for (size_t i = 0; i < MAX_FEATURES; ++i) { + m_lib_features[i].name = EnumNames::names[i].c_str(); + m_lib_features[i].enabled = is_enabled(i); + } } -LibInfo *LibInfo::getInstance() { - if (!m_inst) - m_inst = std::make_unique(); - return m_inst.get(); +LibInfo* LibInfo::getInstance() { + if (!m_inst) + m_inst = std::make_unique(); + return m_inst.get(); } std::unique_ptr LibInfo::m_inst = nullptr; const std::vector EnumNames::names = { - "CUDA", - "CUDNN", - "NCCL", - "TENSORRT", - "CUTENSOR", - "CPU_SSE", - "CPU_SSE2", - "CPU_SSE3", - "CPU_SSE4_1", - "CPU_SSE4_2", - "CPU_SSE4A", - "CPU_AVX", - "CPU_AVX2", - "OPENMP", - "SSE", - "F16C", - "JEMALLOC", - "BLAS_OPEN", - "BLAS_ATLAS", - "BLAS_MKL", - "BLAS_APPLE", - "LAPACK", - "ONEDNN", - "OPENCV", - "DIST_KVSTORE", - "INT64_TENSOR_SIZE", - "SIGNAL_HANDLER", - "DEBUG", - "TVM_OP", + "CUDA", + "CUDNN", + "NCCL", + "TENSORRT", + "CUTENSOR", + "CPU_SSE", + "CPU_SSE2", + "CPU_SSE3", + "CPU_SSE4_1", + "CPU_SSE4_2", + "CPU_SSE4A", + "CPU_AVX", + "CPU_AVX2", + "OPENMP", + "SSE", + "F16C", + "JEMALLOC", + "BLAS_OPEN", + "BLAS_ATLAS", + "BLAS_MKL", + "BLAS_APPLE", + "LAPACK", + "ONEDNN", + "OPENCV", + "DIST_KVSTORE", + "INT64_TENSOR_SIZE", + "SIGNAL_HANDLER", + "DEBUG", + "TVM_OP", }; } // namespace features diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index beb40c73b8ad..fc20fb1c8841 100644 --- a/src/ndarray/ndarray_function-inl.h +++ 
b/src/ndarray/ndarray_function-inl.h @@ -31,37 +31,36 @@ // macro to help specialize evaluation function #ifndef DECL_TERNARY -#define DECL_TERNARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &mhs, \ - const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, mhs, rhs, ret, ctx); \ +#define DECL_TERNARY(XPU, OP, FUN) \ + template <> \ + void Eval( \ + const TBlob& lhs, const TBlob& mhs, const TBlob& rhs, TBlob* ret, RunContext ctx) { \ + FUN(lhs, mhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY #define DECL_BINARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + template <> \ + void Eval(const TBlob& lhs, const TBlob& rhs, TBlob* ret, RunContext ctx) { \ FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY_LAUNCH #define DECL_BINARY_LAUNCH(XPU, OP) \ - template <> \ - void BinaryOpKernelImpl(mshadow::Stream *s, \ - const TBlob& lhs, const TBlob& rhs, TBlob *out) { \ - BinaryOpKernelLaunch(s, lhs, rhs, out); \ + template <> \ + void BinaryOpKernelImpl( \ + mshadow::Stream * s, const TBlob& lhs, const TBlob& rhs, TBlob* out) { \ + BinaryOpKernelLaunch(s, lhs, rhs, out); \ } #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ - template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, \ - TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ + template <> \ + void Eval(const TBlob& lhs, const real_t& rhs, TBlob* ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -75,381 +74,348 @@ namespace mxnet { namespace ndarray { // true implementation -template -void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +template +void EvalBinary_(const TBlob& lhs, const TBlob& rhs, TBlob* ret, RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - CHECK_EQ(ret->type_flag_, lhs.type_flag_) - << "Only support input/output with the same data type"; - CHECK_EQ(ret->type_flag_, rhs.type_flag_) - << "Only support input/output with the same data type"; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(ret->type_flag_, lhs.type_flag_) << "Only support input/output with the same data type"; + CHECK_EQ(ret->type_flag_, rhs.type_flag_) << "Only support input/output with the same data type"; MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - ret->FlatTo2D(s) - = F(lhs.FlatTo2D(s), - rhs.FlatTo2D(s)); + ret->FlatTo2D(s) = + F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); }); } -template -void EvalOneHot_(const TBlob &index, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +template +void EvalOneHot_(const TBlob& index, const TBlob& rhs, TBlob* ret, RunContext ctx) { LOG(INFO) << "The operator onehot_encode is deprecated; use one_hot instead."; using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); // TODO(eric): support mixed type encoding, i.e. int index and float rhs. 
CHECK_EQ(ret->type_flag_, mshadow::default_type_flag) - << "one_hot_encode only support float32 as input/output"; + << "one_hot_encode only support float32 as input/output"; CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag) - << "one_hot_encode only support float32 as input/output"; + << "one_hot_encode only support float32 as input/output"; CHECK_EQ(index.type_flag_, mshadow::default_type_flag) - << "one_hot_encode only support float32 as input/output"; - ret->get(s) = - one_hot_encode(index.get(s), - rhs.shape_[1]); + << "one_hot_encode only support float32 as input/output"; + ret->get(s) = one_hot_encode(index.get(s), rhs.shape_[1]); } -template -void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +template +void EvalMatChooseRowElem_(const TBlob& lhs, const TBlob& rhs, TBlob* ret, RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); // TODO(eric): support mixed type choose, i.e. int index and float rhs. CHECK_EQ(ret->type_flag_, mshadow::default_type_flag) - << "mat_choose_row_element only support float32 as input/output"; + << "mat_choose_row_element only support float32 as input/output"; CHECK_EQ(rhs.type_flag_, mshadow::default_type_flag) - << "mat_choose_row_element only support float32 as input/output"; + << "mat_choose_row_element only support float32 as input/output"; CHECK_EQ(lhs.type_flag_, mshadow::default_type_flag) - << "mat_choose_row_element only support float32 as input/output"; - ret->get(s) - = mat_choose_row_element(lhs.get(s), - rhs.get(s)); + << "mat_choose_row_element only support float32 as input/output"; + ret->get(s) = + mat_choose_row_element(lhs.get(s), rhs.get(s)); } -template -void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +template +void EvalMatFillRowElem_(const TBlob& lhs, + const TBlob& mhs, + const TBlob& rhs, + TBlob* ret, + RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - ret->get(s) - = mat_fill_row_element(lhs.get(s), - mhs.get(s), - rhs.get(s)); + mshadow::Stream* s = ctx.get_stream(); + ret->get(s) = mat_fill_row_element( + lhs.get(s), mhs.get(s), rhs.get(s)); } -template -void EvalScalar_(const TBlob &lhs, const real_t &rhs, - TBlob *ret, RunContext ctx) { +template +void EvalScalar_(const TBlob& lhs, const real_t& rhs, TBlob* ret, RunContext ctx) { using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - CHECK_EQ(ret->type_flag_, lhs.type_flag_) - << "Only support input/output with the same data type"; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(ret->type_flag_, lhs.type_flag_) << "Only support input/output with the same data type"; if (reverse) { MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - ret->FlatTo2D(s) - = F(scalar(DType(rhs)), lhs.FlatTo2D(s)); + ret->FlatTo2D(s) = + F(scalar(DType(rhs)), lhs.FlatTo2D(s)); }); } else { MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - ret->FlatTo2D(s) - = F(lhs.FlatTo2D(s), scalar(DType(rhs))); + ret->FlatTo2D(s) = + F(lhs.FlatTo2D(s), scalar(DType(rhs))); }); } } -template<> -void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx) { +template <> +void EvalClip(const TBlob& src, + const real_t& a_min, + const real_t& a_max, + TBlob* ret, + RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - CHECK_EQ(ret->type_flag_, src.type_flag_) - << "Only 
support input/output with the same data type"; + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(ret->type_flag_, src.type_flag_) << "Only support input/output with the same data type"; MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { - ret->FlatTo2D(s) - = F( - F(src.FlatTo2D(s), scalar(DType(a_min))), - scalar(DType(a_max))); + ret->FlatTo2D(s) = F( + F(src.FlatTo2D(s), scalar(DType(a_min))), + scalar(DType(a_max))); }); } -template<> -void EvalRandom(const real_t &a, - const real_t &b, - const Resource &resource, - TBlob *ret, +template <> +void EvalRandom(const real_t& a, + const real_t& b, + const Resource& resource, + TBlob* ret, RunContext ctx) { typedef DEVICE xpu; - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleUniform(&tmp, float(a), float(b)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleUniform(&tmp, double(a), double(b)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } -template<> -void EvalRandom( - const real_t &mu, - const real_t &sigma, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& mu, + const real_t& sigma, + const Resource& resource, + TBlob* ret, + RunContext ctx) { typedef DEVICE xpu; - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGaussian(&tmp, float(mu), float(sigma)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGaussian(&tmp, double(mu), double(sigma)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } #if defined(__CUDACC__) -template<> -void EvalRandom( - const real_t &alpha, - const real_t &beta, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& alpha, + const real_t& beta, + const Resource& resource, + TBlob* ret, + RunContext ctx) { EvalRandom(alpha, beta, resource, ret, ctx); } -template<> +template <> void EvalRandom( - const real_t &lambda, - const real_t &dummy, // this is to satisfy the SampleOp lambda signature - const Resource &resource, - TBlob *ret, + const real_t& lambda, + const real_t& dummy, // this is to satisfy the SampleOp lambda signature + const Resource& resource, + TBlob* ret, RunContext ctx) { EvalRandom(lambda, dummy, resource, ret, ctx); } -template<> +template <> void EvalRandom( - const real_t &lambda, - const real_t &dummy, // this is to satisfy the SampleOp lambda signature - const Resource &resource, - TBlob *ret, + const real_t& lambda, + const real_t& dummy, // this is to satisfy the 
SampleOp lambda signature + const Resource& resource, + TBlob* ret, RunContext ctx) { EvalRandom(lambda, dummy, resource, ret, ctx); } -template<> -void EvalRandom( - const real_t &k, - const real_t &p, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& k, + const real_t& p, + const Resource& resource, + TBlob* ret, + RunContext ctx) { EvalRandom(k, p, resource, ret, ctx); } -template<> -void EvalRandom( - const real_t &mu, - const real_t &alpha, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& mu, + const real_t& alpha, + const Resource& resource, + TBlob* ret, + RunContext ctx) { EvalRandom(mu, alpha, resource, ret, ctx); } #else -template<> -void EvalRandom( - const real_t &alpha, - const real_t &beta, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& alpha, + const real_t& beta, + const Resource& resource, + TBlob* ret, + RunContext ctx) { typedef cpu xpu; // No support for gpu for this distribution. - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGamma(&tmp, float(alpha), float(beta)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGamma(&tmp, double(alpha), double(beta)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } -template<> +template <> void EvalRandom( - const real_t &lambda, - const real_t &dummy, // this is to satisfy the SampleOp lambda signature - const Resource &resource, - TBlob *ret, + const real_t& lambda, + const real_t& dummy, // this is to satisfy the SampleOp lambda signature + const Resource& resource, + TBlob* ret, RunContext ctx) { typedef cpu xpu; // No support for gpu for this distribution. - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleExponential(&tmp, float(lambda)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleExponential(&tmp, double(lambda)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } -template<> +template <> void EvalRandom( - const real_t &lambda, - const real_t &dummy, // this is to satisfy the SampleOp lambda signature - const Resource &resource, - TBlob *ret, + const real_t& lambda, + const real_t& dummy, // this is to satisfy the SampleOp lambda signature + const Resource& resource, + TBlob* ret, RunContext ctx) { typedef cpu xpu; // No support for gpu for this distribution. 
- mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SamplePoisson(&tmp, float(lambda)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SamplePoisson(&tmp, double(lambda)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } -template<> -void EvalRandom( - const real_t &k, - const real_t &p, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& k, + const real_t& p, + const Resource& resource, + TBlob* ret, + RunContext ctx) { typedef cpu xpu; // No support for gpu for this distribution. - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleNegativeBinomial(&tmp, float(k), float(p)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleNegativeBinomial(&tmp, double(k), double(p)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } -template<> -void EvalRandom( - const real_t &mu, - const real_t &alpha, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +template <> +void EvalRandom(const real_t& mu, + const real_t& alpha, + const Resource& resource, + TBlob* ret, + RunContext ctx) { typedef cpu xpu; // No support for gpu for this distribution. 
- mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); switch (ret->type_flag_) { - case mshadow::kFloat32: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat32: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGeneralizedNegativeBinomial(&tmp, float(mu), float(alpha)); // NOLINT(*) break; } - case mshadow::kFloat64: - { - mshadow::Random *prnd = resource.get_random(s); + case mshadow::kFloat64: { + mshadow::Random* prnd = resource.get_random(s); mshadow::Tensor tmp = ret->FlatTo2D(s); prnd->SampleGeneralizedNegativeBinomial(&tmp, double(mu), double(alpha)); // NOLINT(*) break; } - default: - LOG(FATAL) << "Random only support float32 and float64"; + default: + LOG(FATAL) << "Random only support float32 and float64"; } } #endif // #ifndef __CUDACC__ -template<> -void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { - mshadow::Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, { - ret->FlatTo2D(s) = DType(rhs); - }); +template <> +void Eval(const real_t& rhs, TBlob* ret, RunContext ctx) { + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL( + ret->type_flag_, DType, { ret->FlatTo2D(s) = DType(rhs); }); } -template<> -void ElementwiseSum(const std::vector source, - TBlob *dst, - RunContext ctx) { +template <> +void ElementwiseSum(const std::vector source, TBlob* dst, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow; using namespace mshadow::expr; - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); for (size_t i = 1; i < source.size(); ++i) { CHECK_EQ(source[i].type_flag_, dst->type_flag_) - << "Only support input/output with the same data type"; + << "Only support input/output with the same data type"; } MSHADOW_TYPE_SWITCH(dst->type_flag_, DType, { Tensor out = dst->FlatTo2D(s); @@ -458,14 +424,14 @@ void ElementwiseSum(const std::vector source, case 2: { Tensor in_0 = source[0].FlatTo2D(s); Tensor in_1 = source[1].FlatTo2D(s); - out = in_0 + in_1; + out = in_0 + in_1; break; } case 3: { Tensor in_0 = source[0].FlatTo2D(s); Tensor in_1 = source[1].FlatTo2D(s); Tensor in_2 = source[2].FlatTo2D(s); - out = in_0 + in_1 + in_2; + out = in_0 + in_1 + in_2; break; } case 4: { @@ -473,12 +439,12 @@ void ElementwiseSum(const std::vector source, Tensor in_1 = source[1].FlatTo2D(s); Tensor in_2 = source[2].FlatTo2D(s); Tensor in_3 = source[3].FlatTo2D(s); - out = in_0 + in_1 + in_2 + in_3; + out = in_0 + in_1 + in_2 + in_3; break; } default: { Tensor in_0 = source[0].FlatTo2D(s); - out = F(in_0); + out = F(in_0); for (size_t i = 1; i < source.size(); ++i) { out += source[i].FlatTo2D(s); } @@ -491,23 +457,19 @@ void ElementwiseSum(const std::vector source, template <> void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx) { typedef DEVICE xpu; - mshadow::Stream* s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); mshadow::Tensor out = ret->get(s); - mshadow::Tensor in = src.get(s); - out = mshadow::expr::broadcast_with_axis(in, 0, size); + mshadow::Tensor in = src.get(s); + out = mshadow::expr::broadcast_with_axis(in, 0, size); } -template -void BinaryOpKernelLaunch(mshadow::Stream* s, const TBlob& lhs, const TBlob& rhs, TBlob *out) { +template +void BinaryOpKernelLaunch(mshadow::Stream* s, const TBlob& lhs, const TBlob& rhs, TBlob* out) { using namespace op::mxnet_op; using namespace mshadow; MSHADOW_TYPE_SWITCH(out->type_flag_, DType, { - Kernel, xpu >:: - Launch(s, - 
lhs.Size(), - out->dptr(), - lhs.dptr(), - rhs.dptr()); + Kernel, xpu>::Launch( + s, lhs.Size(), out->dptr(), lhs.dptr(), rhs.dptr()); }); } // declarations diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index e0a445814314..c1a9a54ac05b 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -32,31 +32,32 @@ namespace mxnet { namespace ndarray { -template<> -void Copy(const TBlob &from, TBlob *to, - Context from_ctx, Context to_ctx, +template <> +void Copy(const TBlob& from, + TBlob* to, + Context from_ctx, + Context to_ctx, RunContext ctx) { MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { if (!features::is_enabled(features::INT64_TENSOR_SIZE)) { - CHECK_LT(from.Size(), (int64_t{1} << 31) - 1) << - "Size of tensor you are trying to allocate is larger than " - "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; + CHECK_LT(from.Size(), (int64_t{1} << 31) - 1) + << "Size of tensor you are trying to allocate is larger than " + "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1"; } const index_t size = static_cast(from.Size()); CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType) - << " bytes, to: " << to->Size() * sizeof(DType) << " bytes."; + << " bytes, to: " << to->Size() * sizeof(DType) << " bytes."; common::ParallelCopy(to->dptr(), from.dptr(), size); } else { MSHADOW_TYPE_SWITCH_EXT_WITH_BOOL(from.type_flag_, SrcDType, { - to->FlatTo1D() = - mshadow::expr::tcast(from.FlatTo1D()); + to->FlatTo1D() = mshadow::expr::tcast(from.FlatTo1D()); }) } }) } -template +template void ElementwiseSumRspImpl(mshadow::Stream* s, const std::vector& nds, const std::vector& uniq_row_idx, @@ -64,15 +65,15 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, const int nthreads = 4) { #pragma omp parallel num_threads(nthreads) { - const size_t nnr = uniq_row_idx.size(); - const int num_threads = omp_get_num_threads(); - size_t row_block_len = (nnr + num_threads - 1) / num_threads; + const size_t nnr = uniq_row_idx.size(); + const int num_threads = omp_get_num_threads(); + size_t row_block_len = (nnr + num_threads - 1) / num_threads; const size_t row_block_start = omp_get_thread_num() * row_block_len; if (row_block_start < nnr) { - const size_t row_block_end = std::min(row_block_start+row_block_len, nnr); + const size_t row_block_end = std::min(row_block_start + row_block_len, nnr); const size_t row_length = out->data().shape_.ProdShape(1, out->data().shape_.ndim()); - auto out_values = out->data().get_with_shape( + auto out_values = out->data().get_with_shape( mshadow::Shape2(out->storage_shape()[0], row_length), s); auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D(); for (size_t i = row_block_start; i < row_block_end; ++i) { @@ -81,24 +82,24 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, for (const auto& nd : nds) { if (nd.storage_initialized()) { const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D(); - const auto nd_values = nd.data().get_with_shape( + const auto nd_values = nd.data().get_with_shape( mshadow::Shape2(nd.storage_shape()[0], row_length), s); - const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); + const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); const IType* nd_indices_start = &nd_indices[0]; - const IType* nd_indices_end = nd_indices_start + nd_num_rows; - const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end, - out_indices[row_block_start]); + const IType* 
nd_indices_end = nd_indices_start + nd_num_rows; + const IType* row_idx_ptr = + std::lower_bound(nd_indices_start, nd_indices_end, out_indices[row_block_start]); // skip this nd if all of its row indices are smaller than out_indices[row_block_start] // or current row block is not covered by [*row_idx_ptr, nd_indices_end). - if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) { + if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end - 1]) { continue; } for (size_t irow = row_block_start; irow < row_block_end && row_idx_ptr != nd_indices_end;) { if (out_indices[irow] == *row_idx_ptr) { auto out_value_cur_row = out_values[irow]; - const auto offset = row_idx_ptr - nd_indices_start; - auto nd_value_cur_row = nd_values[offset]; + const auto offset = row_idx_ptr - nd_indices_start; + auto nd_value_cur_row = nd_values[offset]; for (index_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) { out_value_cur_row[j] += nd_value_cur_row[j]; } @@ -120,9 +121,8 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, * \brief Given a vector of ndarrays, generate a index vector containing * all the unique row indices of the ndarrays. */ -template -void GetUniqueRspRowIdx(const std::vector& nds, - std::vector* uniq_row_idx) { +template +void GetUniqueRspRowIdx(const std::vector& nds, std::vector* uniq_row_idx) { using namespace rowsparse; size_t total_num_rows = 0; for (const auto& nd : nds) { @@ -134,14 +134,14 @@ void GetUniqueRspRowIdx(const std::vector& nds, uniq_row_idx->resize(total_num_rows); int nthreads = omp_get_max_threads(); - int offset = 0; + int offset = 0; for (const auto& nd : nds) { if (nd.storage_initialized()) { const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); - const int num_rows = nd.aux_shape(kIdx).Size(); + const int num_rows = nd.aux_shape(kIdx).Size(); #pragma omp parallel for num_threads(nthreads) for (int i = 0; i < num_rows; ++i) { - (*uniq_row_idx)[offset+i] = nd_row_idx[i]; + (*uniq_row_idx)[offset + i] = nd_row_idx[i]; } offset += num_rows; } @@ -156,11 +156,11 @@ void ElementwiseSumRsp(mshadow::Stream* s, const Resource& rsc, const std::vector& nds, NDArray* out) { - if (nds.empty()) return; + if (nds.empty()) + return; using namespace rowsparse; CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "Expected row sparse storage type (" - << out->storage_type() << " given)"; + << "Expected row sparse storage type (" << out->storage_type() << " given)"; MSHADOW_TYPE_SWITCH(out->dtype(), DType, { MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { @@ -183,21 +183,30 @@ void ElementwiseSumDnsCsrDnsImpl(mshadow::Stream* s, using namespace mxnet::op::mxnet_op; const TBlob& out_data = out->data(); MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type - Kernel::Launch( - s, out_data.Size(), out_data.dptr(), kWriteTo, nds[0].data().dptr(), - nds[2].data().dptr()); - const TBlob& csr_data = nds[1].data(); - const TBlob& csr_indices = nds[1].aux_data(csr::kIdx); - const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr); + Kernel::Launch(s, + out_data.Size(), + out_data.dptr(), + kWriteTo, + nds[0].data().dptr(), + nds[2].data().dptr()); + const TBlob& csr_data = nds[1].data(); + const TBlob& csr_indices = nds[1].aux_data(csr::kIdx); + const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr); const nnvm::dim_t num_rows = nds[1].shape()[0]; const nnvm::dim_t num_cols = nds[1].shape()[1]; - MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indices type + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indices type 
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // indptr type if (nds[1].storage_initialized()) { Kernel, cpu>::Launch( - s, num_rows, out_data.dptr(), out_data.dptr(), - csr_data.dptr(), csr_indices.dptr(), - csr_indptr.dptr(), num_rows, num_cols); + s, + num_rows, + out_data.dptr(), + out_data.dptr(), + csr_data.dptr(), + csr_indices.dptr(), + csr_indptr.dptr(), + num_rows, + num_cols); } }); }); @@ -218,13 +227,13 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, Kernel::Launch(s, out_data.Size(), out_data.dptr()); } for (size_t i = 0; i < nds.size(); ++i) { - const NDArray& nd = nds[i]; + const NDArray& nd = nds[i]; const TBlob& nd_data = nd.data(); if (i == 0) { if (nd.storage_type() == kDefaultStorage) { Kernel, cpu>::Launch( - s, out_data.Size(), out_data.dptr(), nd_data.dptr()); + s, out_data.Size(), out_data.dptr(), nd_data.dptr()); continue; } else { Kernel::Launch(s, out_data.Size(), out_data.dptr()); @@ -233,39 +242,53 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, switch (nd.storage_type()) { case kDefaultStorage: { - Kernel, cpu>::Launch( - s, out_data.Size(), out_data.dptr(), out_data.dptr(), - nd_data.dptr()); + Kernel, cpu>::Launch(s, + out_data.Size(), + out_data.dptr(), + out_data.dptr(), + nd_data.dptr()); break; } case kCSRStorage: { - const TBlob& nd_indices = nd.aux_data(csr::kIdx); - const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); + const TBlob& nd_indices = nd.aux_data(csr::kIdx); + const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); const nnvm::dim_t num_rows = nd.shape()[0]; const nnvm::dim_t num_cols = nd.shape()[1]; - MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type + MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, { // indptr type if (nd.storage_initialized()) { Kernel, cpu>::Launch( - s, num_rows, out_data.dptr(), out_data.dptr(), - nd_data.dptr(), nd_indices.dptr(), - nd_indptr.dptr(), num_rows, num_cols); + s, + num_rows, + out_data.dptr(), + out_data.dptr(), + nd_data.dptr(), + nd_indices.dptr(), + nd_indptr.dptr(), + num_rows, + num_cols); } }); }); break; } case kRowSparseStorage: { - const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx); + const TBlob& nd_indices = nd.aux_data(rowsparse::kIdx); const nnvm::dim_t num_rows = nd.shape()[0]; const nnvm::dim_t num_cols = nd.shape()[1]; MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type if (nd.storage_initialized()) { const nnvm::dim_t nz_rows = nd_indices.Size(); Kernel, cpu>::Launch( - s, nz_rows * num_cols, out_data.dptr(), - out_data.dptr(), nd_data.dptr(), nd_indices.dptr(), - num_rows, nz_rows, num_cols); + s, + nz_rows * num_cols, + out_data.dptr(), + out_data.dptr(), + nd_data.dptr(), + nd_indices.dptr(), + num_rows, + nz_rows, + num_cols); } }); break; @@ -281,12 +304,13 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, * \brief Parallel cpu impl of elemwise sum for sparse tensors. * Currently only support row sparse sum. 
*/ -template<> +template <> void ElementwiseSum(mshadow::Stream* s, const Resource& rsc, const std::vector& nds, NDArray* out) { - if (nds.empty()) return; + if (nds.empty()) + return; if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) { ElementwiseSumRsp(s, rsc, nds, out); } else if (nds.size() == 3U && nds[0].storage_type() == kDefaultStorage && @@ -302,11 +326,9 @@ void ElementwiseSum(mshadow::Stream* s, } } - -template<> -void Eval(mshadow::Stream *s, - const real_t val, const NDArray& dst) { - NDArray temp = dst; +template <> +void Eval(mshadow::Stream* s, const real_t val, const NDArray& dst) { + NDArray temp = dst; const NDArrayStorageType stype = temp.storage_type(); if (stype == kRowSparseStorage) { SetValueRspImpl(s, val, &temp); diff --git a/src/ndarray/ndarray_function.cu b/src/ndarray/ndarray_function.cu index e00b4c3f948e..f6189f939131 100644 --- a/src/ndarray/ndarray_function.cu +++ b/src/ndarray/ndarray_function.cu @@ -38,56 +38,56 @@ namespace mxnet { namespace ndarray { -template<> -void Copy(const TBlob &from, TBlob *to, - Context from_ctx, Context to_ctx, +template <> +void Copy(const TBlob& from, + TBlob* to, + Context from_ctx, + Context to_ctx, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) - << "Source and target must have the same data type when copying across devices."; + << "Source and target must have the same data type when copying across devices."; MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { - mshadow::Copy(to->FlatTo1D(), - from.FlatTo1D(), - ctx.get_stream()); + mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); }); } -template<> -void Copy(const TBlob &from, TBlob *to, - Context from_ctx, Context to_ctx, +template <> +void Copy(const TBlob& from, + TBlob* to, + Context from_ctx, + Context to_ctx, RunContext ctx) { CHECK_EQ(to->type_flag_, from.type_flag_) - << "Source and target must have the same data type when copying across devices."; + << "Source and target must have the same data type when copying across devices."; MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { - mshadow::Copy(to->FlatTo1D(), - from.FlatTo1D(), - ctx.get_stream()); + mshadow::Copy(to->FlatTo1D(), from.FlatTo1D(), ctx.get_stream()); }); } -template<> -void Copy(const TBlob &from, TBlob *to, - Context from_ctx, Context to_ctx, +template <> +void Copy(const TBlob& from, + TBlob* to, + Context from_ctx, + Context to_ctx, RunContext ctx) { if (from_ctx.dev_id == to_ctx.dev_id) { mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, { if (to->type_flag_ == from.type_flag_) { - mshadow::Copy(to->FlatTo1D(s), - from.FlatTo1D(s), - s); + mshadow::Copy(to->FlatTo1D(s), from.FlatTo1D(s), s); } else { MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, { to->FlatTo1D(s) = - mshadow::expr::tcast(from.FlatTo1D(s)); + mshadow::expr::tcast(from.FlatTo1D(s)); }) } }) } else { CHECK(from.CheckContiguous() && to->CheckContiguous()) - << "copy across only support contiguous memory"; + << "copy across only support contiguous memory"; CHECK_EQ(to->type_flag_, from.type_flag_) - << "Source and target must have the same data type when copying across devices."; - mshadow::Stream *s = ctx.get_stream(); + << "Source and target must have the same data type when copying across devices."; + mshadow::Stream* s = ctx.get_stream(); CHECK(s != nullptr) << "need stream in GPU context"; cudaMemcpyPeerAsync(to->dptr_, to_ctx.dev_id, @@ -109,7 +109,7 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, using namespace rowsparse; 
using nnvm::dim_t; CHECK_EQ(out->storage_type(), kRowSparseStorage) - << "Expected rowsparse storage_type (" << out->storage_type() << " given)"; + << "Expected rowsparse storage_type (" << out->storage_type() << " given)"; int init = 0; for (const auto& nd : nds) { if (nd.storage_initialized()) { @@ -121,68 +121,59 @@ void ElementwiseSumRspImpl(mshadow::Stream* s, FillZerosRspImpl(s, *out); return; } - const dim_t num_rows = out->shape()[0]; + const dim_t num_rows = out->shape()[0]; const dim_t row_length = out->shape().ProdShape(1, out->shape().ndim()); - MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { // row_idx type // Allocate temporary storage for row_flg array and cub's prefix sum operation - IType* row_flg = nullptr; - void* d_temp_storage = nullptr; + IType* row_flg = nullptr; + void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; - cudaStream_t stream = mshadow::Stream::GetStream(s); - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - row_flg, - row_flg, - num_rows, - stream); - mshadow::Tensor workspace = rsc - .get_space_typed(mshadow::Shape1(num_rows * sizeof(IType) + - temp_storage_bytes), s); - row_flg = reinterpret_cast(workspace.dptr_); - d_temp_storage = workspace.dptr_ + num_rows*sizeof(IType); + cudaStream_t stream = mshadow::Stream::GetStream(s); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, row_flg, row_flg, num_rows, stream); + mshadow::Tensor workspace = rsc.get_space_typed( + mshadow::Shape1(num_rows * sizeof(IType) + temp_storage_bytes), s); + row_flg = reinterpret_cast(workspace.dptr_); + d_temp_storage = workspace.dptr_ + num_rows * sizeof(IType); // Mark row_flg array with 0 for zero rows and 1 for non-zero rows dim_t num_threads = num_rows; mxnet_op::Kernel::Launch(s, num_threads, row_flg); for (const auto& nd : nds) { if (nd.storage_initialized()) { const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); - const dim_t nd_nnr = nd.storage_shape()[0]; - num_threads = nd_nnr; - mxnet_op::Kernel::Launch(s, num_threads, - row_flg, nd_row_idx, nd_nnr); + const dim_t nd_nnr = nd.storage_shape()[0]; + num_threads = nd_nnr; + mxnet_op::Kernel::Launch( + s, num_threads, row_flg, nd_row_idx, nd_nnr); } } // Compute inclusive prefix sum over row_flg - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - row_flg, - row_flg, - num_rows, - stream); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, row_flg, row_flg, num_rows, stream); // Get total number of output non-zero rows from GPU and allocate out data and row_idx dim_t nnr_out = 0; - CUDA_CALL(cudaMemcpyAsync(&nnr_out, &row_flg[num_rows-1], sizeof(dim_t), - cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync( + &nnr_out, &row_flg[num_rows - 1], sizeof(dim_t), cudaMemcpyDeviceToHost, stream)); CUDA_CALL(cudaStreamSynchronize(stream)); out->CheckAndAlloc({mshadow::Shape1(nnr_out)}); IType* out_row_idx = out->aux_data(kIdx).dptr(); - DType* out_data = out->data().dptr(); + DType* out_data = out->data().dptr(); // Fill row_idx array of output using row_flg num_threads = num_rows; - mxnet_op::Kernel::Launch(s, num_threads, - out_row_idx, row_flg, num_rows); + mxnet_op::Kernel::Launch( + s, num_threads, out_row_idx, row_flg, num_rows); // Perform elementwise addition, writing to output data num_threads = nnr_out * row_length; mxnet_op::Kernel::Launch(s, num_threads, out_data); for (const auto& nd : nds) { if 
(nd.storage_initialized()) { const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); - const DType* nd_data = nd.data().dptr(); - const dim_t nd_nnr = nd.storage_shape()[0]; - num_threads = nd_nnr * row_length; - mxnet_op::Kernel::Launch(s, num_threads, - out_data, row_flg, nd_row_idx, nd_data, nd_nnr, row_length); + const DType* nd_data = nd.data().dptr(); + const dim_t nd_nnr = nd.storage_shape()[0]; + num_threads = nd_nnr * row_length; + mxnet_op::Kernel::Launch( + s, num_threads, out_data, row_flg, nd_row_idx, nd_data, nd_nnr, row_length); } } }); @@ -197,21 +188,30 @@ void ElementwiseSumDnsCsrDnsImpl(mshadow::Stream* s, using namespace mxnet::op::mxnet_op; const TBlob& out_data = out->data(); MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type - Kernel::Launch( - s, out_data.Size(), out_data.dptr(), kWriteTo, nds[0].data().dptr(), - nds[2].data().dptr()); - const TBlob& csr_data = nds[1].data(); - const TBlob& csr_indices = nds[1].aux_data(csr::kIdx); - const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr); + Kernel::Launch(s, + out_data.Size(), + out_data.dptr(), + kWriteTo, + nds[0].data().dptr(), + nds[2].data().dptr()); + const TBlob& csr_data = nds[1].data(); + const TBlob& csr_indices = nds[1].aux_data(csr::kIdx); + const TBlob& csr_indptr = nds[1].aux_data(csr::kIndPtr); const nnvm::dim_t num_rows = nds[1].shape()[0]; const nnvm::dim_t num_cols = nds[1].shape()[1]; - MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indices type + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indices type MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // indptr type if (nds[1].storage_initialized()) { Kernel, gpu>::Launch( - s, kWarpSize * num_rows, out_data.dptr(), out_data.dptr(), - csr_data.dptr(), csr_indices.dptr(), - csr_indptr.dptr(), num_rows, num_cols); + s, + kWarpSize * num_rows, + out_data.dptr(), + out_data.dptr(), + csr_data.dptr(), + csr_indices.dptr(), + csr_indptr.dptr(), + num_rows, + num_cols); } }); }); @@ -219,23 +219,23 @@ void ElementwiseSumDnsCsrDnsImpl(mshadow::Stream* s, } void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, - const Resource& rsc, - const std::vector& nds, - NDArray* out) { + const Resource& rsc, + const std::vector& nds, + NDArray* out) { using namespace mxnet::op; using namespace mxnet::op::mxnet_op; const TBlob& out_data = out->data(); MSHADOW_TYPE_SWITCH(out->dtype(), DType, { // data type for (size_t i = 0; i < nds.size(); ++i) { - const NDArray& nd = nds[i]; + const NDArray& nd = nds[i]; const nnvm::dim_t num_rows = nd.shape()[0]; const nnvm::dim_t num_cols = nd.shape()[1]; - const TBlob& nd_data = nd.data(); + const TBlob& nd_data = nd.data(); if (i == 0) { if (nd.storage_type() == kDefaultStorage) { Kernel, gpu>::Launch( - s, out_data.Size(), out_data.dptr(), nd_data.dptr()); + s, out_data.Size(), out_data.dptr(), nd_data.dptr()); continue; } else { Kernel::Launch(s, out_data.Size(), out_data.dptr()); @@ -244,21 +244,29 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, switch (nd.storage_type()) { case kDefaultStorage: { - Kernel, gpu>::Launch( - s, out_data.Size(), out_data.dptr(), out_data.dptr(), - nd_data.dptr()); + Kernel, gpu>::Launch(s, + out_data.Size(), + out_data.dptr(), + out_data.dptr(), + nd_data.dptr()); break; } case kCSRStorage: { const TBlob& nd_indices = nd.aux_data(csr::kIdx); - const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); - MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type + const TBlob& nd_indptr = nd.aux_data(csr::kIndPtr); + 
MSHADOW_IDX_TYPE_SWITCH(nd_indices.type_flag_, IType, { // indices type MSHADOW_IDX_TYPE_SWITCH(nd_indptr.type_flag_, CType, { // indptr type if (nd.storage_initialized()) { Kernel, gpu>::Launch( - s, kWarpSize * num_rows, out_data.dptr(), out_data.dptr(), - nd_data.dptr(), nd_indices.dptr(), - nd_indptr.dptr(), num_rows, num_cols); + s, + kWarpSize * num_rows, + out_data.dptr(), + out_data.dptr(), + nd_data.dptr(), + nd_indices.dptr(), + nd_indptr.dptr(), + num_rows, + num_cols); } }); }); @@ -270,9 +278,15 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, if (nd.storage_initialized()) { const nnvm::dim_t nz_rows = nd_indices.Size(); Kernel, gpu>::Launch( - s, nz_rows * num_cols, out_data.dptr(), - out_data.dptr(), nd_data.dptr(), nd_indices.dptr(), - num_rows, nz_rows, num_cols); + s, + nz_rows * num_cols, + out_data.dptr(), + out_data.dptr(), + nd_data.dptr(), + nd_indices.dptr(), + num_rows, + nz_rows, + num_cols); } }); break; @@ -288,12 +302,13 @@ void ElementwiseSumContainsDnsImpl(mshadow::Stream* s, * \brief Parallel gpu impl of elemwise sum for sparse tensors. * Currently only support row sparse sum. */ -template<> +template <> void ElementwiseSum(mshadow::Stream* s, const Resource& rsc, const std::vector& nds, NDArray* out) { - if (nds.empty()) return; + if (nds.empty()) + return; if (common::ContainsOnlyStorage(nds, kRowSparseStorage)) { ElementwiseSumRspImpl(s, rsc, nds, out); } else if (nds.size() == 3U && nds[0].storage_type() == kDefaultStorage && @@ -305,14 +320,13 @@ void ElementwiseSum(mshadow::Stream* s, ElementwiseSumContainsDnsImpl(s, rsc, nds, out); } else { LOG(FATAL) << "ElementwiseSum has not been implemented for storage_type = << " - << nds[0].storage_type(); + << nds[0].storage_type(); } } -template<> -void Eval(mshadow::Stream *s, - const real_t val, const NDArray& dst) { - NDArray temp = dst; +template <> +void Eval(mshadow::Stream* s, const real_t val, const NDArray& dst) { + NDArray temp = dst; const NDArrayStorageType stype = temp.storage_type(); if (stype == kRowSparseStorage) { SetValueRspImpl(s, val, &temp); diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 505bd205a8d5..003eda951905 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -38,7 +38,7 @@ namespace mxnet { /*! 
\brief namespace to support all possible Ndarray operator */ namespace ndarray { struct BinaryBase { - inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) { + inline static mxnet::TShape GetShape(const mxnet::TShape& lshape, const mxnet::TShape& rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; CHECK(!mxnet::op::shape_is_none(lshape)) << "source operand have zero dimension shape"; return lshape; @@ -68,7 +68,7 @@ struct Mod : public BinaryBase { struct ClipMin : public BinaryBase { struct mshadow_op { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { if (a < b) { return b; @@ -81,7 +81,7 @@ struct ClipMin : public BinaryBase { struct ClipMax : public BinaryBase { struct mshadow_op { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { if (a > b) { return b; @@ -92,9 +92,8 @@ struct ClipMax : public BinaryBase { }; }; - struct OneHotEncode { - inline static mxnet::TShape GetShape(const mxnet::TShape &index, const mxnet::TShape &proptype) { + inline static mxnet::TShape GetShape(const mxnet::TShape& index, const mxnet::TShape& proptype) { CHECK(index.ndim() == 1 && proptype.ndim() == 2) << "OneHotEncode only support 1d index."; CHECK_EQ(index[0], proptype[0]) << "OneHotEncode shape inconsistent"; return proptype; @@ -102,7 +101,7 @@ struct OneHotEncode { }; struct MatChooseRowElem { - inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, const mxnet::TShape &rshape) { + inline static mxnet::TShape GetShape(const mxnet::TShape& lshape, const mxnet::TShape& rshape) { CHECK(lshape.ndim() == 2 && rshape.ndim() == 1) << "choose_row_element only support 2D Matrix and 1D index"; CHECK_EQ(lshape[0], rshape[0]) << "choose_row_element index and matrix shape mismatch"; @@ -111,9 +110,9 @@ struct MatChooseRowElem { }; struct MatFillRowElem { - inline static mxnet::TShape GetShape(const mxnet::TShape &lshape, - const mxnet::TShape &mshape, - const mxnet::TShape &rshape) { + inline static mxnet::TShape GetShape(const mxnet::TShape& lshape, + const mxnet::TShape& mshape, + const mxnet::TShape& rshape) { CHECK(lshape.ndim() == 2 && mshape.ndim() == 1 && rshape.ndim() == 1) << "fill_row_element only support 2D Matrix, 1D value and 1D index"; CHECK((lshape[0] == mshape[0]) && (mshape[0] == rshape[0])) @@ -137,46 +136,46 @@ struct NegBinomialDistribution {}; struct GenNegBinomialDistribution {}; -template -void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx); +template +void EvalClip(const TBlob& src, + const real_t& a_min, + const real_t& a_max, + TBlob* ret, + RunContext ctx); -template -void Eval(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const TBlob& lhs, const TBlob& mhs, const TBlob& rhs, TBlob* ret, RunContext ctx); -template -void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const TBlob& lhs, const TBlob& rhs, TBlob* ret, RunContext ctx); -template -void Eval(const TBlob &src, TBlob *ret, RunContext ctx); +template +void Eval(const TBlob& src, TBlob* ret, RunContext ctx); -template -void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const TBlob& lhs, const real_t& rhs, TBlob* ret, RunContext ctx); -template -void Eval(const real_t &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const real_t& rhs, TBlob* ret, RunContext ctx); -template -void EvalRandom(const real_t &a, - const 
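The ClipMin/ClipMax structs touched above follow the common pattern of this header: a static GetShape that validates the operand shapes, plus a small Map functor applied elementwise by the kernel machinery. A self-contained sketch of that pattern, with plain loops standing in for the mshadow/Kernel launches (the names below are hypothetical stand-ins, not MXNet APIs):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Shape check in the spirit of BinaryBase::GetShape: both operands must agree,
// and the common size is what the elementwise kernel iterates over.
inline size_t BinaryCheckAndSize(const std::vector<int64_t>& lshape,
                                 const std::vector<int64_t>& rshape) {
  assert(lshape == rshape && "operands shape mismatch");
  size_t size = 1;
  for (int64_t d : lshape) size *= static_cast<size_t>(d);
  return size;
}

// Elementwise functor in the spirit of ClipMin::mshadow_op::Map.
struct ClipMinOp {
  template <typename DType>
  static DType Map(DType a, DType b) {
    return a < b ? b : a;  // clamp a from below by b
  }
};

// Plain loop standing in for a Kernel<OP, xpu>::Launch call.
template <typename OP, typename DType>
void BinaryEval(const DType* lhs, const DType* rhs, DType* out, size_t n) {
  for (size_t i = 0; i < n; ++i) out[i] = OP::Map(lhs[i], rhs[i]);
}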
real_t &b, - const Resource &resource, - TBlob *ret, RunContext ctx); +template +void EvalRandom(const real_t& a, + const real_t& b, + const Resource& resource, + TBlob* ret, + RunContext ctx); // copy function when only cpu is involved -template -void Copy(const TBlob &from, TBlob *to, - Context from_ctx, Context to_ctx, - RunContext ctx); +template +void Copy(const TBlob& from, TBlob* to, Context from_ctx, Context to_ctx, RunContext ctx); -template -void ElementwiseSum(const std::vector source, - TBlob *out, - RunContext ctx); +template +void ElementwiseSum(const std::vector source, TBlob* out, RunContext ctx); /*! * \brief Interface for parallel impl of elemwise sum for sparse matrices */ -template +template void ElementwiseSum(mshadow::Stream* s, const Resource& rsc, const std::vector& nds, @@ -188,9 +187,8 @@ void ElementwiseSum(mshadow::Stream* s, * \param val - The value to be set * \param dst - NDArray which is to be set to val */ -template -void SetValueRspImpl(mshadow::Stream *s, - const real_t val, NDArray *dst) { +template +void SetValueRspImpl(mshadow::Stream* s, const real_t val, NDArray* dst) { CHECK_EQ(dst->storage_type(), kRowSparseStorage); using namespace mxnet::op; nnvm::dim_t nnr = dst->shape()[0]; @@ -202,17 +200,15 @@ void SetValueRspImpl(mshadow::Stream *s, Fill(s, dst->data(), kWriteTo, val); } -template -void Eval(mshadow::Stream *s, - const real_t val, const NDArray& dst); +template +void Eval(mshadow::Stream* s, const real_t val, const NDArray& dst); // broadcasting template void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx); template -void BinaryOpKernelImpl(mshadow::Stream *s, const TBlob& lhs, - const TBlob& rhs, TBlob *out); +void BinaryOpKernelImpl(mshadow::Stream* s, const TBlob& lhs, const TBlob& rhs, TBlob* out); } // namespace ndarray } // namespace mxnet diff --git a/src/nnvm/amp_infer_unknown.cc b/src/nnvm/amp_infer_unknown.cc index c457905a3b68..3353a28f4b43 100644 --- a/src/nnvm/amp_infer_unknown.cc +++ b/src/nnvm/amp_infer_unknown.cc @@ -36,24 +36,25 @@ #include "../operator/tensor/amp_cast.h" namespace mxnet { -using nnvm::Graph; -using nnvm::ObjectPtr; -using nnvm::NodeEntry; using dmlc::any; using mxnet::op::AMPCastParam; +using nnvm::Graph; +using nnvm::NodeEntry; +using nnvm::ObjectPtr; // If a var node is not visited, visit it and set inferred_dtype_result as result_dtype, // If already visited compare the result_dtype with existing inferred_dtype_result static void CheckAndUpdateInferredDtypes( - const nnvm::DTypeVector &inferred_dtypes, const nnvm::IndexedGraph &idx, - const NodeEntry &node_entry, + const nnvm::DTypeVector& inferred_dtypes, + const nnvm::IndexedGraph& idx, + const NodeEntry& node_entry, mshadow::TypeFlag result_dtype, - std::unordered_map *visited_vars, - nnvm::DTypeVector *inferred_dtype_result) { - const ObjectPtr &input_node = node_entry.node; + std::unordered_map* visited_vars, + nnvm::DTypeVector* inferred_dtype_result) { + const ObjectPtr& input_node = node_entry.node; if (!visited_vars->count(input_node->attrs.name)) { if ((*inferred_dtype_result)[idx.entry_id(node_entry)] == -1) { - (*visited_vars)[input_node->attrs.name] = result_dtype; + (*visited_vars)[input_node->attrs.name] = result_dtype; (*inferred_dtype_result)[idx.entry_id(node_entry)] = result_dtype; } } else { @@ -68,15 +69,14 @@ static void CheckAndUpdateInferredDtypes( // Graph pass to infer unknown nodes which are input nodes // as LP16 if possible -Graph AMPInferUnknown(Graph &&src) { - const nnvm::DTypeVector &inferred_dtypes 
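CheckAndUpdateInferredDtypes above implements a first-writer-wins rule with conflict detection: the first amp cast that consumes a still-unknown variable records its dtype, and a later cast that proposes a different dtype makes the variable fall back to its originally inferred dtype. A small map-based sketch of that rule (simplified to string keys; entry ids and the NNVM types are omitted, and the names are illustrative):

#include <string>
#include <unordered_map>

// `result` is assumed to be pre-populated with the dtypes inferred so far,
// using -1 for "unknown", mirroring the DTypeVector in the pass above.
// Returns the dtype that ends up recorded for `var`.
int CheckAndUpdateDtype(std::unordered_map<std::string, int>* visited_vars,
                        std::unordered_map<std::string, int>* result,
                        const std::string& var,
                        int original_dtype,
                        int proposed_dtype) {
  auto it = visited_vars->find(var);
  if (it == visited_vars->end()) {
    // First cast consuming this variable: adopt its dtype, but only if
    // nothing was inferred for the variable beforehand.
    if ((*result)[var] == -1) {
      (*visited_vars)[var] = proposed_dtype;
      (*result)[var] = proposed_dtype;
    }
  } else if (it->second != proposed_dtype) {
    // A second cast disagrees: fall back to the originally inferred dtype.
    (*result)[var] = original_dtype;
  }
  return (*result)[var];
}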
= - src.GetAttr("inferred_dtypes"); - const int target_dtype = src.GetAttr("target_dtype"); +Graph AMPInferUnknown(Graph&& src) { + const nnvm::DTypeVector& inferred_dtypes = src.GetAttr("inferred_dtypes"); + const int target_dtype = src.GetAttr("target_dtype"); CHECK(target_dtype == mshadow::kFloat16 || target_dtype == mshadow::kBfloat16) << "Only float16 and bfloat16 target_dtypes are supported yet"; nnvm::DTypeVector inferred_dtype_result(inferred_dtypes); - const nnvm::IndexedGraph &idx = src.indexed_graph(); + const nnvm::IndexedGraph& idx = src.indexed_graph(); std::unordered_map visited_vars; @@ -84,7 +84,7 @@ Graph AMPInferUnknown(Graph &&src) { // and check if inputs to these nodes are variables. // If input nodes are variables, set dtype for these inputs // and check for conflicts if an input node goes to two cast nodes - DFSVisit(src.outputs, [&](const ObjectPtr &node) { + DFSVisit(src.outputs, [&](const ObjectPtr& node) { if (!node->is_variable()) { std::string op_name = node->op()->name; @@ -93,17 +93,18 @@ Graph AMPInferUnknown(Graph &&src) { // to visited_vars, if a var is being visited second time // and already has dtype set, make sure the dtype inferred again // is same, otherwise reset dtype to original dtype - for (const NodeEntry &node_entry : node->inputs) { - const ObjectPtr &input_node = node_entry.node; + for (const NodeEntry& node_entry : node->inputs) { + const ObjectPtr& input_node = node_entry.node; if (input_node->is_variable() && (node->attrs.dict.find("dtype") != node->attrs.dict.end())) { - const AMPCastParam ¶m = - nnvm::get(node->attrs.parsed); - CHECK(param.dtype != -1) - << "amp_cast node shouldn't have unknown dtype"; - CheckAndUpdateInferredDtypes(inferred_dtypes, idx, node_entry, + const AMPCastParam& param = nnvm::get(node->attrs.parsed); + CHECK(param.dtype != -1) << "amp_cast node shouldn't have unknown dtype"; + CheckAndUpdateInferredDtypes(inferred_dtypes, + idx, + node_entry, static_cast(param.dtype), - &visited_vars, &inferred_dtype_result); + &visited_vars, + &inferred_dtype_result); } } } else if (op_name == "amp_multicast") { @@ -120,11 +121,14 @@ Graph AMPInferUnknown(Graph &&src) { } } if (max_dtype == target_dtype) { - for (const NodeEntry &node_entry : node->inputs) { - const ObjectPtr &input_node = node_entry.node; + for (const NodeEntry& node_entry : node->inputs) { + const ObjectPtr& input_node = node_entry.node; if (input_node->is_variable()) { - CheckAndUpdateInferredDtypes(inferred_dtypes, idx, node_entry, - max_dtype, &visited_vars, + CheckAndUpdateInferredDtypes(inferred_dtypes, + idx, + node_entry, + max_dtype, + &visited_vars, &inferred_dtype_result); } } diff --git a/src/nnvm/error.h b/src/nnvm/error.h index 863513964d91..c98016d0f71c 100644 --- a/src/nnvm/error.h +++ b/src/nnvm/error.h @@ -28,11 +28,12 @@ namespace pass { class InvalidGraphError : public std::exception { public: - explicit InvalidGraphError(const std::string& msg = "invalid graph error"): msg_(msg) { } + explicit InvalidGraphError(const std::string& msg = "invalid graph error") : msg_(msg) {} ~InvalidGraphError() throw() {} virtual const char* what() const throw() { return msg_.c_str(); } + private: std::string msg_; }; diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc index 447f658890d8..2609f78d7579 100644 --- a/src/nnvm/gradient.cc +++ b/src/nnvm/gradient.cc @@ -53,18 +53,16 @@ struct GradEntry { std::vector grads; }; - /*! * \brief Build the backward graph from the mirror map. 
This function will be * invoked twice if backward mirroring has been enabled. */ -Graph BuildGradientGraph( - const Graph& src, - const std::vector& xs, - const std::vector& topo_order, - std::unordered_map > output_grads, - std::function mirror_fun, - const std::unordered_map& mirror_map); +Graph BuildGradientGraph(const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map); /*! * \brief Auxiliary function that maps the forward node of the source graph to @@ -74,25 +72,19 @@ inline const ObjectPtr& MapFwdNodeToMirrorPath( const ObjectPtr& n, const std::unordered_map& mirror_map) { auto iter = mirror_map.find(n.get()); - if (iter == mirror_map.end() || - iter->second == nullptr) { + if (iter == mirror_map.end() || iter->second == nullptr) { return n; } return iter->second; } - Graph Gradient(Graph src) { - CHECK_NE(src.attrs.count("grad_ys"), 0U) - << "Gradient require grad_ys to be presented."; - CHECK_NE(src.attrs.count("grad_xs"), 0U) - << "Gradient require grad_xs to be presented."; + CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; + CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) << "Gradient require grad_ys_out_grad to be presented."; - const std::vector& xs = - src.GetAttr >("grad_xs"); - const std::vector& ys = - src.GetAttr >("grad_ys"); + const std::vector& xs = src.GetAttr >("grad_xs"); + const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); CHECK_EQ(ys.size(), ys_out_grad.size()); @@ -102,13 +94,12 @@ Graph Gradient(Graph src) { std::vector topo_order; std::unordered_map > output_grads; - DFSVisit(ys, - [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + DFSVisit(ys, [&](const ObjectPtr& node) { + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); for (size_t i = 0; i < ys.size(); ++i) { output_grads[ys[i].node.get()][ys[i].index].grads = {ys_out_grad[i]}; @@ -117,12 +108,11 @@ Graph Gradient(Graph src) { // check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " - << (i + 1) << "-th variable " + << "Cannot differentiate with respect to the " << (i + 1) << "-th variable " << "because it is unreachable from the outputs."; } - using MirrorFun = std::function; + using MirrorFun = std::function; MirrorFun mirror_fun = nullptr; if (src.attrs.count("mirror_fun") != 0) { mirror_fun = src.GetAttr("mirror_fun"); @@ -130,22 +120,18 @@ Graph Gradient(Graph src) { std::unordered_map mirror_map; // complete the backward graph of the src, but without backward mirroring - nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, - output_grads, - nullptr, mirror_map); + nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map); if (mirror_fun == nullptr) { return gsrc; // Gradient pass without mirroring ends here. 
} - const IndexedGraph& idx = src.indexed_graph(), - & gidx = gsrc.indexed_graph(); + const IndexedGraph &idx = src.indexed_graph(), &gidx = gsrc.indexed_graph(); // =========================================================================== // ----- Gradient Pass w/ Backward Mirroring ----- // =========================================================================== // Record, for each node entry ∈ gsrc, the nodes that reference it as inputs. // It is important to note that since the node entry reference mapping is // constructed from gradient graph, it can only be indexed using gidx entry ID. - std::vector > node_entry_ref_map( - gidx.num_node_entries()); + std::vector > node_entry_ref_map(gidx.num_node_entries()); static const auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); for (uint32_t gnid = 0; gnid < gidx.num_nodes(); ++gnid) { const IndexedGraph::Node& inode = gidx[gnid]; @@ -156,8 +142,7 @@ Graph Gradient(Graph src) { if (fignore_inputs.count(inode.source->op()) != 0) { std::vector ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs); - if (std::find(ignore_inputs.begin(), ignore_inputs.end(), i) - != ignore_inputs.end()) { + if (std::find(ignore_inputs.begin(), ignore_inputs.end(), i) != ignore_inputs.end()) { continue; } } @@ -208,9 +193,8 @@ Graph Gradient(Graph src) { // B // After invoking this function. `subgraph` will become {A, B}. // Note that this function will be invoked multiple times. - auto subworklist_backprop = [&subworklist, &subgraph, - &subgraph_topo_order, - &mirror_fun, &worklist]() { + auto subworklist_backprop = + [&subworklist, &subgraph, &subgraph_topo_order, &mirror_fun, &worklist]() { std::deque subworklist_topo_order; for (; !subworklist.empty(); subworklist.pop()) { const Node* const subworkitem = subworklist.front(); @@ -266,8 +250,7 @@ Graph Gradient(Graph src) { do { has_subgraph_converged = true; for (const Node* const subgraph_node : subgraph_topo_order) { - for (const NodeEntry& subgraph_node_entry : - subgraph_node->inputs) { + for (const NodeEntry& subgraph_node_entry : subgraph_node->inputs) { const std::unordered_set ref_nodes = node_entry_ref_map[gidx.entry_id(subgraph_node_entry)]; @@ -289,7 +272,7 @@ Graph Gradient(Graph src) { ref_node_heads.push(ref_node); for (; !ref_node_heads.empty(); ref_node_heads.pop()) { const Node* const ref_node_head = ref_node_heads.front(); - bool is_ref_node_head_output = false; + bool is_ref_node_head_output = false; for (const NodeEntry& y : ys) { if (ref_node_head == y.node.get()) { is_ref_node_head_output = true; @@ -309,7 +292,7 @@ Graph Gradient(Graph src) { } } } // for (oid ∈ [0, ref_node_head->num_outputs())) - } // for (ref_node_head ∈ ref_node_heads) + } // for (ref_node_head ∈ ref_node_heads) // Do the backpropagation again. The topological order of the // subworklist can be directly appended to the end of the existing // order. 
E,g, in our previous example, we expect to have @@ -320,7 +303,7 @@ Graph Gradient(Graph src) { break; } // if (ref_node != subgraph_node && idx.exist(ref_node) && // subgraph.find(ref_node) == subgraph.end() - } // for (ref_node ∈ ref_nodes) + } // for (ref_node ∈ ref_nodes) if (!has_subgraph_converged) { break; } @@ -365,15 +348,13 @@ Graph Gradient(Graph src) { bool is_frontier = true; for (const NodeEntry& e : subgraph_node->inputs) { auto iter = mirror_map.find(e.node.get()); - if (mirror_fun(*(e.node)) && - !(iter != mirror_map.end() && iter->second == nullptr)) { + if (mirror_fun(*(e.node)) && !(iter != mirror_map.end() && iter->second == nullptr)) { is_frontier = false; } } for (const ObjectPtr& n : subgraph_node->control_deps) { auto iter = mirror_map.find(n.get()); - if (mirror_fun(*n) && - !(iter != mirror_map.end() && iter->second == nullptr)) { + if (mirror_fun(*n) && !(iter != mirror_map.end() && iter->second == nullptr)) { is_frontier = false; } } @@ -404,7 +385,7 @@ Graph Gradient(Graph src) { has_forward_candidates_converged = true; for (const Node* const candidate : forward_candidates) { for (const NodeEntry& candidate_input : candidate->inputs) { - uint32_t geid = gidx.entry_id(candidate_input); + uint32_t geid = gidx.entry_id(candidate_input); const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; for (const Node* const ref_node : ref_nodes) { if (ref_node != frontier_node.first && @@ -427,23 +408,19 @@ Graph Gradient(Graph src) { // Record the node entries that are newly allocated and those that are // released. A node entry can be released if all its referencing nodes // are part of the subgraph frontier. Otherwise, it is in the allocated set. - std::unordered_set newly_allocated_node_entries, - released_node_entries; + std::unordered_set newly_allocated_node_entries, released_node_entries; for (const Node* const candidate : forward_candidates) { - uint32_t nid = idx.node_id(candidate), - gnid = gidx.node_id(candidate); + uint32_t nid = idx.node_id(candidate), gnid = gidx.node_id(candidate); for (uint32_t oid = 0; oid < candidate->num_outputs(); ++oid) { - uint32_t eid = idx.entry_id(nid, oid), - geid = gidx.entry_id(gnid, oid); + uint32_t eid = idx.entry_id(nid, oid), geid = gidx.entry_id(gnid, oid); if (node_entry_ref_map[geid].size() != 0) { newly_allocated_node_entries.insert(eid); } } for (const NodeEntry& candidate_input : candidate->inputs) { - uint32_t eid = idx.entry_id(candidate_input), - geid = gidx.entry_id(candidate_input); + uint32_t eid = idx.entry_id(candidate_input), geid = gidx.entry_id(candidate_input); const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; - bool can_be_released = true; + bool can_be_released = true; for (const Node* const ref_node : ref_nodes) { if (subgraph_frontier.find(ref_node) == subgraph_frontier.end()) { newly_allocated_node_entries.insert(eid); @@ -454,7 +431,7 @@ Graph Gradient(Graph src) { released_node_entries.insert(eid); } } // for (candidate_input ∈ candidate->input) - } // for (candidate ∈ forward_candidates) + } // for (candidate ∈ forward_candidates) // Now, compare the total amount of newly allocated storage versus the // released storage, if the latter is greater or equal to the former, @@ -462,8 +439,7 @@ Graph Gradient(Graph src) { // forward candidate nodes are marked as on the mirror path. 
size_t newly_allocated_storage = 0, released_storage = 0; for (const uint32_t eid : newly_allocated_node_entries) { - newly_allocated_storage += src_shapes[eid].Size() * - MXGetDTypeSize(src_dtypes[eid]); + newly_allocated_storage += src_shapes[eid].Size() * MXGetDTypeSize(src_dtypes[eid]); } for (const uint32_t eid : released_node_entries) { released_storage += src_shapes[eid].Size() * MXGetDTypeSize(src_dtypes[eid]); @@ -477,7 +453,7 @@ Graph Gradient(Graph src) { has_subgraph_converged = false; break; } // if (released_storage >= newly_allocated_storage) - } // for (frontier_node ∈ subgraph_frontier) + } // for (frontier_node ∈ subgraph_frontier) } while (!has_subgraph_converged); // Finally, mark all the remaining nodes of the subgraph as on the mirror path. @@ -486,7 +462,7 @@ Graph Gradient(Graph src) { continue; } ObjectPtr subgraph_node_mirror = Node::Create(); - *subgraph_node_mirror = *subgraph_node; + *subgraph_node_mirror = *subgraph_node; subgraph_node_mirror->attrs.name += "_mirror"; for (NodeEntry& e : subgraph_node_mirror->inputs) { e.node = MapFwdNodeToMirrorPath(e.node, mirror_map); @@ -497,26 +473,23 @@ Graph Gradient(Graph src) { mirror_map[subgraph_node] = subgraph_node_mirror; } } // for (workitem ∈ worklist) - DFSVisit(ys, - [&](const ObjectPtr& node) { - if (mirror_map.at(node.get()) != nullptr) { - node->attrs.dict["__mirror_stage__"] = "1"; - } else { - node->attrs.dict["__mirror_stage__"] = "0"; - } - }); - return BuildGradientGraph(src, xs, topo_order, - output_grads, - mirror_fun, mirror_map); + DFSVisit(ys, [&](const ObjectPtr& node) { + if (mirror_map.at(node.get()) != nullptr) { + node->attrs.dict["__mirror_stage__"] = "1"; + } else { + node->attrs.dict["__mirror_stage__"] = "0"; + } + }); + return BuildGradientGraph(src, xs, topo_order, output_grads, mirror_fun, mirror_map); } - /*! * \brief Auxiliary function that checks whether all the gradients are zero or not. 
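The storage comparison just above drives the mirroring decision: entry sizes are measured as shape size times dtype size, and the candidate change is accepted only when the bytes it releases cover the bytes it would newly keep alive. A tiny sketch of that predicate (the struct and function names are made up for illustration):

#include <cstddef>
#include <vector>

struct EntryBytes {
  size_t num_elements;  // product of the entry's shape, i.e. shape.Size()
  size_t dtype_size;    // bytes per element, e.g. 4 for float32
};

// The comparison used above: the candidate rewiring is only worthwhile when the
// storage it releases is at least as large as the storage it newly keeps alive.
bool ReleasedCoversNewlyAllocated(const std::vector<EntryBytes>& newly_allocated,
                                  const std::vector<EntryBytes>& released) {
  size_t allocated_bytes = 0;
  size_t released_bytes = 0;
  for (const auto& e : newly_allocated) allocated_bytes += e.num_elements * e.dtype_size;
  for (const auto& e : released) released_bytes += e.num_elements * e.dtype_size;
  return released_bytes >= allocated_bytes;
}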
*/ inline bool CheckGradAllZero(const std::vector& grads, const std::vector& zero_ops) { - if (!grads.size() || !zero_ops.size()) return false; + if (!grads.size() || !zero_ops.size()) + return false; for (const auto& g : grads) { bool found = false; for (const auto& op : zero_ops) { @@ -525,43 +498,41 @@ inline bool CheckGradAllZero(const std::vector& grads, break; } } - if (!found) return false; + if (!found) + return false; } return true; } - -Graph BuildGradientGraph( - const Graph& src, - const std::vector& xs, - const std::vector& topo_order, - std::unordered_map > output_grads, - std::function mirror_fun, - const std::unordered_map& mirror_map) { +Graph BuildGradientGraph(const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map) { static auto& grad_fun_map = Op::GetAttr("FGradient"); // gradient aggregation function - using AggFun = std::function&&)>; - AggFun agg_fun = [](std::vector&& v)->NodeEntry { - if (v.size() == 1) { - return std::move(v[0]); - } else if (v.size() == 0) { - ObjectPtr zero_grad_node = Node::Create(); - zero_grad_node->attrs.op = Op::Get("zeros"); - zero_grad_node->attrs.name = "zero_grad"; - zero_grad_node->attrs.op->attr_parser(&(zero_grad_node->attrs)); - return NodeEntry{zero_grad_node, 0, 0}; - } else { - ObjectPtr grad_sum_node = Node::Create(); - grad_sum_node->attrs.op = Op::Get("elemwise_sum"); - grad_sum_node->inputs = std::move(v); - grad_sum_node->attrs.name = "grad_sum"; - grad_sum_node->attrs.dict["num_args"] = - std::to_string(grad_sum_node->inputs.size()); - grad_sum_node->attrs.op->attr_parser(&(grad_sum_node->attrs)); - return NodeEntry{grad_sum_node, 0, 0}; - } - }; + using AggFun = std::function &&)>; + AggFun agg_fun = [](std::vector&& v) -> NodeEntry { + if (v.size() == 1) { + return std::move(v[0]); + } else if (v.size() == 0) { + ObjectPtr zero_grad_node = Node::Create(); + zero_grad_node->attrs.op = Op::Get("zeros"); + zero_grad_node->attrs.name = "zero_grad"; + zero_grad_node->attrs.op->attr_parser(&(zero_grad_node->attrs)); + return NodeEntry{zero_grad_node, 0, 0}; + } else { + ObjectPtr grad_sum_node = Node::Create(); + grad_sum_node->attrs.op = Op::Get("elemwise_sum"); + grad_sum_node->inputs = std::move(v); + grad_sum_node->attrs.name = "grad_sum"; + grad_sum_node->attrs.dict["num_args"] = std::to_string(grad_sum_node->inputs.size()); + grad_sum_node->attrs.op->attr_parser(&(grad_sum_node->attrs)); + return NodeEntry{grad_sum_node, 0, 0}; + } + }; if (src.attrs.count("grad_aggregate_fun") != 0) { agg_fun = src.GetAttr("grad_aggregate_fun"); } @@ -571,19 +542,21 @@ Graph BuildGradientGraph( if (src.attrs.count("zero_ops") != 0) { zero_ops = src.GetAttr >("zero_ops"); } - const Op* copy_op = (src.attrs.count("copy_op_str") != 0) ? - Op::Get(src.GetAttr("copy_op_str")) : nullptr; + const Op* copy_op = (src.attrs.count("copy_op_str") != 0) + ? 
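The default agg_fun reformatted above encodes the gradient-aggregation policy: no incoming gradients produce a zeros node, a single gradient is passed through unchanged, and multiple gradients are combined with an elemwise_sum node. A toy sketch of that three-way rule, using strings in place of NNVM node entries (names are illustrative):

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an NNVM NodeEntry; only a printable expression is kept.
struct GradExpr {
  std::string expr;
};

// Three-way aggregation rule mirroring the default agg_fun above.
GradExpr Aggregate(std::vector<GradExpr>&& grads) {
  if (grads.empty()) {
    return GradExpr{"zeros()"};  // no contributions: a zero-gradient node
  }
  if (grads.size() == 1) {
    return std::move(grads[0]);  // single contribution: pass it through unchanged
  }
  std::string sum = "elemwise_sum(";  // several contributions: sum them
  for (size_t i = 0; i < grads.size(); ++i) {
    if (i) sum += ", ";
    sum += grads[i].expr;
  }
  sum += ")";
  return GradExpr{std::move(sum)};
}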
Op::Get(src.GetAttr("copy_op_str")) + : nullptr; std::vector out_agg_grads; - for (auto topo_order_rit = topo_order.rbegin(); - topo_order_rit != topo_order.rend(); ++topo_order_rit) { + for (auto topo_order_rit = topo_order.rbegin(); topo_order_rit != topo_order.rend(); + ++topo_order_rit) { const ObjectPtr& src_fwd_node = *topo_order_rit; - if (src_fwd_node->is_variable()) continue; + if (src_fwd_node->is_variable()) + continue; // gather all the output gradient entries and apply the aggregation function out_agg_grads.clear(); auto& out_grad_vec = output_grads.at(src_fwd_node.get()); - for (auto & e : out_grad_vec) { + for (auto& e : out_grad_vec) { e.sum = agg_fun(std::move(e.grads)); out_agg_grads.push_back(e.sum); } @@ -606,15 +579,13 @@ Graph BuildGradientGraph( if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) { for (NodeEntry& input_grad : input_grads) { for (NodeEntry& grad_input : input_grad.node->inputs) { - const ObjectPtr& grad_input_node_mirrored = MapFwdNodeToMirrorPath( - grad_input.node, mirror_map); - grad_input = NodeEntry( - grad_input_node_mirrored, - grad_input.index, - grad_input.version); + const ObjectPtr& grad_input_node_mirrored = + MapFwdNodeToMirrorPath(grad_input.node, mirror_map); + grad_input = + NodeEntry(grad_input_node_mirrored, grad_input.index, grad_input.version); } // for (grad_input ∈ input_grad.node->inputs) - } // for (input_grad ∈ input_grads) - } // if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) + } // for (input_grad ∈ input_grads) + } // if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) } else if (CheckGradAllZero(out_agg_grads, zero_ops)) { for (size_t i = 0; i < src_fwd_node->num_inputs(); ++i) { std::ostringstream os; @@ -623,8 +594,8 @@ Graph BuildGradientGraph( } else { os << fwd_node->attrs.name << "_in" << i << "_backward"; } - auto p = Node::Create(); - p->attrs.op = zero_ops[0]; + auto p = Node::Create(); + p->attrs.op = zero_ops[0]; p->attrs.name = os.str(); p->inputs.push_back(fwd_node->inputs[i]); p->control_deps.emplace_back(fwd_node); @@ -634,8 +605,9 @@ Graph BuildGradientGraph( input_grads.emplace_back(p, 0, 0); } // for (i ∈ src_fwd_node->num_inputs()) } else { - std::string message = "Operator " + std::string(src_fwd_node->op()->name) - + "is non-differentiable because it didn't register FGradient attribute."; + std::string message = + "Operator " + std::string(src_fwd_node->op()->name) + + "is non-differentiable because it didn't register FGradient attribute."; throw nnvm::pass::InvalidGraphError(message); } for (const auto& e : input_grads) { @@ -643,15 +615,14 @@ Graph BuildGradientGraph( } auto input_grad_iter = input_grads.begin(); CHECK(src_fwd_node->inputs.size() <= input_grads.size()); - for (auto input_iter = src_fwd_node->inputs.begin(); - input_iter != src_fwd_node->inputs.end(); + for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end(); ++input_iter, ++input_grad_iter) { // propagate the input gradients to the output gradients of the input nodes - output_grads[input_iter->node.get()][input_iter->index] - .grads.emplace_back(std::move(*input_grad_iter)); + output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back( + std::move(*input_grad_iter)); } } // if (src_fwd_node->inputs.size() != 0) - } // for (topo_order_rit ∈ reverse(topo_order)) + } // for (topo_order_rit ∈ reverse(topo_order)) // take out the xs' grads Graph ret; ret.outputs.resize(xs.size()); @@ -672,14 +643,13 @@ Graph BuildGradientGraph( std::ostringstream os; os << 
entry.sum.node->attrs.name << "_" << kv->second.first << "_copy"; kv->second.first++; - copy_node->attrs.op = copy_op; + copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } - unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, - std::make_pair(1, counter)); + unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); } } else { ret.outputs[counter] = entry.sum; @@ -694,17 +664,16 @@ Graph BuildGradientGraph( return ret; } - // register pass NNVM_REGISTER_PASS(MXGradient) -.describe(R"(Return a gradient graph of src.attrs["ys"] wrt src.attrs["xs"])") -.set_body(Gradient) -.set_change_graph(true) -.depend_graph_attr("grad_ys") -.depend_graph_attr("grad_xs") -.depend_graph_attr("in_arg_shapes") -.depend_graph_attr("in_arg_dtypes") -.depend_graph_attr("grad_ys_out_grad"); + .describe(R"(Return a gradient graph of src.attrs["ys"] wrt src.attrs["xs"])") + .set_body(Gradient) + .set_change_graph(true) + .depend_graph_attr("grad_ys") + .depend_graph_attr("grad_xs") + .depend_graph_attr("in_arg_shapes") + .depend_graph_attr("in_arg_dtypes") + .depend_graph_attr("grad_ys_out_grad"); } // namespace diff --git a/src/nnvm/graph_algorithm.h b/src/nnvm/graph_algorithm.h index 031e62254db7..49c6420392af 100644 --- a/src/nnvm/graph_algorithm.h +++ b/src/nnvm/graph_algorithm.h @@ -23,7 +23,7 @@ * \brief This header contains graph algorithms on StaticGraph. * It is used compute informations such as whether two * operations can run in parallel, and helps allocation. -*/ + */ #ifndef MXNET_NNVM_GRAPH_ALGORITHM_H_ #define MXNET_NNVM_GRAPH_ALGORITHM_H_ @@ -42,10 +42,9 @@ namespace pass { * \param path the output path of nodes. * \return the total reward of best path. */ -inline uint32_t MXFindBestPath( - const IndexedGraph& graph, - const std::vector& node_reward, - std::vector* path) { +inline uint32_t MXFindBestPath(const IndexedGraph& graph, + const std::vector& node_reward, + std::vector* path) { const uint32_t num_nodes = static_cast(graph.num_nodes()); CHECK_EQ(num_nodes, node_reward.size()); @@ -58,21 +57,22 @@ inline uint32_t MXFindBestPath( const uint32_t nid = i - 1; best_reward[nid] += node_reward[nid]; if (best_reward[nid] > best_solution) { - best_solution = best_reward[nid]; + best_solution = best_reward[nid]; best_start_node = nid; } for (const auto& e : graph[nid].inputs) { const uint32_t prev = e.node_id; if (best_reward[nid] > best_reward[prev]) { best_reward[prev] = best_reward[nid]; - next_node[prev] = nid; + next_node[prev] = nid; } } } path->clear(); uint32_t reward = 0; for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { - path->push_back(nid); reward += node_reward[nid]; + path->push_back(nid); + reward += node_reward[nid]; } CHECK_EQ(reward, best_solution); return best_solution; @@ -89,11 +89,10 @@ inline uint32_t MXFindBestPath( * \param color the color index of each of the node. * \return the total number of colors. 
*/ -inline uint32_t MXColorNodeGroup( - const IndexedGraph &graph, - std::vector node_importance, - uint32_t max_ncolor, - std::vector *color) { +inline uint32_t MXColorNodeGroup(const IndexedGraph& graph, + std::vector node_importance, + uint32_t max_ncolor, + std::vector* color) { CHECK_NE(max_ncolor, 0U); CHECK_EQ(graph.num_nodes(), node_importance.size()); @@ -106,7 +105,8 @@ inline uint32_t MXColorNodeGroup( for (cindex = 0; cindex < max_ncolor - 1; ++cindex) { std::vector path; uint32_t reward = MXFindBestPath(graph, node_importance, &path); - if (reward == 0) break; + if (reward == 0) + break; for (uint32_t nid : path) { if (node_importance[nid] != 0) { CHECK_EQ(color->at(nid), max_ncolor); diff --git a/src/nnvm/graph_editor.cc b/src/nnvm/graph_editor.cc index 44a807eda174..eef96ddd1676 100644 --- a/src/nnvm/graph_editor.cc +++ b/src/nnvm/graph_editor.cc @@ -39,18 +39,18 @@ namespace mxnet { * Given a computation graph, this function finds the input nodes of the graph * and create symbols for the input nodes. It returns the input symbols. */ -std::vector GetInputSymbols(const nnvm::Symbol &sym) { +std::vector GetInputSymbols(const nnvm::Symbol& sym) { nnvm::Graph g; - std::vector input_syms; - g.outputs = sym.outputs; + std::vector input_syms; + g.outputs = sym.outputs; const nnvm::IndexedGraph& idx = g.indexed_graph(); // Go through all nodes and return the ones representing variables. for (size_t i = 0; i < idx.num_nodes(); i++) { - const nnvm::Node &n = *idx[i].source; - for (const nnvm::NodeEntry &e : n.inputs) { + const nnvm::Node& n = *idx[i].source; + for (const nnvm::NodeEntry& e : n.inputs) { auto p = e.node; if (p->is_variable()) { - nnvm::Symbol *s = new nnvm::Symbol(); + nnvm::Symbol* s = new nnvm::Symbol(); s->outputs.push_back(e); input_syms.push_back(s); } @@ -65,11 +65,12 @@ std::vector GetInputSymbols(const nnvm::Symbol &sym) { * subgraph. It returns the nodes that connect to the subgraph directly and * the names of the new variable nodes. */ -bool CutGraphInputs(const std::vector &input_entries, - bool skip_var, std::vector *orig_entries) { +bool CutGraphInputs(const std::vector& input_entries, + bool skip_var, + std::vector* orig_entries) { struct pred_entry { nnvm::NodeEntry e; - explicit pred_entry(nnvm::NodeEntry _e): e(std::move(_e)) {} + explicit pred_entry(nnvm::NodeEntry _e) : e(std::move(_e)) {} bool operator()(const nnvm::NodeEntry e1) { return e.node == e1.node && e.index == e1.index; } @@ -83,8 +84,7 @@ bool CutGraphInputs(const std::vector &input_entries, if (input_entry->node->is_variable() && skip_var) continue; - auto it = std::find_if(orig_entries->begin(), orig_entries->end(), - pred_entry(*input_entry)); + auto it = std::find_if(orig_entries->begin(), orig_entries->end(), pred_entry(*input_entry)); bool exist = (it != orig_entries->end()); orig_entries->push_back(*input_entry); nnvm::ObjectPtr n; diff --git a/src/nnvm/legacy_json_util.cc b/src/nnvm/legacy_json_util.cc index 64e1228b37f0..e0bc07dd64cb 100644 --- a/src/nnvm/legacy_json_util.cc +++ b/src/nnvm/legacy_json_util.cc @@ -36,94 +36,95 @@ #include "../c_api/c_api_common.h" namespace mxnet { +using nnvm::FListInputNames; using nnvm::Graph; -using nnvm::Op; using nnvm::Node; -using nnvm::ObjectPtr; using nnvm::NodeAttrs; using nnvm::NodeEntry; +using nnvm::ObjectPtr; +using nnvm::Op; using nnvm::Symbol; -using nnvm::FListInputNames; // First fix things that prevent attr_parser success. 
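MXColorNodeGroup above works greedily on top of MXFindBestPath: it repeatedly extracts the most important remaining path, assigns it the next color, and zeroes the importance of the nodes it consumed, until either the reward drops to zero or the color budget is exhausted. A compact sketch of that loop with the path search abstracted behind a callback (the names below are illustrative, not the pass API):

#include <cstdint>
#include <functional>
#include <vector>

// Stand-in for MXFindBestPath: returns a chain of node ids with maximal total
// importance under the current importance values.
using BestPathFn =
    std::function<std::vector<uint32_t>(const std::vector<uint32_t>& importance)>;

// Greedy grouping in the spirit of MXColorNodeGroup: repeatedly peel off the
// most important path, give it the next color, and zero out its importance.
uint32_t ColorByBestPaths(std::vector<uint32_t> importance,
                          uint32_t max_ncolor,
                          const BestPathFn& find_best_path,
                          std::vector<uint32_t>* color) {
  const uint32_t uncolored = max_ncolor;  // sentinel value, as in the original pass
  color->assign(importance.size(), uncolored);
  uint32_t cindex = 0;
  for (; cindex + 1 < max_ncolor; ++cindex) {
    std::vector<uint32_t> path = find_best_path(importance);
    uint32_t reward = 0;
    for (uint32_t nid : path) reward += importance[nid];
    if (reward == 0) break;  // nothing important left to separate out
    for (uint32_t nid : path) {
      if (importance[nid] != 0) {
        (*color)[nid] = cindex;
        importance[nid] = 0;  // consumed by this color
      }
    }
  }
  for (auto& c : *color) {
    if (c == uncolored) c = cindex;  // remaining nodes share the last color
  }
  return cindex + 1;  // total number of colors used
}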
Graph UpgradeJSON_FixParsing(Graph g) { nnvm::DFSVisit(g.outputs, [](const std::shared_ptr& n) { - static auto& flist_inputs = Op::GetAttr("FListInputNames"); - - // hold keys that should be converted to hidden keys - std::vector > hidden_keys; - - // remove attrs that prevent parsing - for (auto it = n->attrs.dict.begin(); it != n->attrs.dict.end();) { - bool erase = false; - // remove hidden keys - for (const auto& key : kHiddenKeys) { - size_t pos = it->first.rfind(key); - if (pos == 0 || (pos != std::string::npos && pos == it->first.length() - key.length())) { - hidden_keys.emplace_back(*it); - erase = true; - break; - } + static auto& flist_inputs = Op::GetAttr("FListInputNames"); + + // hold keys that should be converted to hidden keys + std::vector > hidden_keys; + + // remove attrs that prevent parsing + for (auto it = n->attrs.dict.begin(); it != n->attrs.dict.end();) { + bool erase = false; + // remove hidden keys + for (const auto& key : kHiddenKeys) { + size_t pos = it->first.rfind(key); + if (pos == 0 || (pos != std::string::npos && pos == it->first.length() - key.length())) { + hidden_keys.emplace_back(*it); + erase = true; + break; } - - auto tmp = it; - ++it; - if (erase) n->attrs.dict.erase(tmp); } - // parse - if (n->op() != nullptr && n->op()->attr_parser != nullptr) - n->op()->attr_parser(&(n->attrs)); + auto tmp = it; + ++it; + if (erase) + n->attrs.dict.erase(tmp); + } + + // parse + if (n->op() != nullptr && n->op()->attr_parser != nullptr) + n->op()->attr_parser(&(n->attrs)); - // add back removed hidden keys - for (const auto& kv : hidden_keys) { - bool flag = false; - for (const auto& key : kHiddenKeys) { - size_t pos = kv.first.rfind(key); - if (pos == 0 && key.length() == kv.first.length()) { - n->attrs.dict["__"+key+"__"] = kv.second; - flag = true; + // add back removed hidden keys + for (const auto& kv : hidden_keys) { + bool flag = false; + for (const auto& key : kHiddenKeys) { + size_t pos = kv.first.rfind(key); + if (pos == 0 && key.length() == kv.first.length()) { + n->attrs.dict["__" + key + "__"] = kv.second; + flag = true; + break; + } else if (pos != std::string::npos && pos > 1 && pos == kv.first.length() - key.length()) { + if (n->is_variable()) break; - } else if (pos != std::string::npos && pos > 1 - && pos == kv.first.length() - key.length()) { - if (n->is_variable()) break; - FListInputNames fn = flist_inputs.get(n->op(), nullptr); - if (fn == nullptr) break; - auto arg_names = fn(n->attrs); - auto name = kv.first.substr(0, pos-1); - auto it = std::find(arg_names.begin(), arg_names.end(), name); - if (it != arg_names.end()) { - int idx = it - arg_names.begin(); - if (n->inputs[idx].node->is_variable()) { - n->inputs[idx].node->attrs.dict["__"+key+"__"] = kv.second; - flag = true; - } - } + FListInputNames fn = flist_inputs.get(n->op(), nullptr); + if (fn == nullptr) break; + auto arg_names = fn(n->attrs); + auto name = kv.first.substr(0, pos - 1); + auto it = std::find(arg_names.begin(), arg_names.end(), name); + if (it != arg_names.end()) { + int idx = it - arg_names.begin(); + if (n->inputs[idx].node->is_variable()) { + n->inputs[idx].node->attrs.dict["__" + key + "__"] = kv.second; + flag = true; + } } + break; } - if (!flag) n->attrs.dict[kv.first] = kv.second; } - }); + if (!flag) + n->attrs.dict[kv.first] = kv.second; + } + }); return g; } Graph UpgradeJSON_Parse(Graph g) { nnvm::DFSVisit(g.outputs, [](const std::shared_ptr& n) { - if (n->op() != nullptr) { - if (n->op()->attr_parser != nullptr) - n->op()->attr_parser(&(n->attrs)); - } 
else { - // ugly workaround due to VariableParam is not exposed. - n->attrs.parsed = - nnvm::Symbol::CreateVariable(n->attrs.name).outputs[0].node->attrs.parsed; - } - }); + if (n->op() != nullptr) { + if (n->op()->attr_parser != nullptr) + n->op()->attr_parser(&(n->attrs)); + } else { + // ugly workaround due to VariableParam is not exposed. + n->attrs.parsed = nnvm::Symbol::CreateVariable(n->attrs.name).outputs[0].node->attrs.parsed; + } + }); return g; } -inline std::string DefaultVarName(const std::string &op_name, - const std::string &arg_name) { +inline std::string DefaultVarName(const std::string& op_name, const std::string& arg_name) { if (op_name.length() == 0) { return arg_name; } else { @@ -134,68 +135,70 @@ inline std::string DefaultVarName(const std::string &op_name, // aux variables are not stored in json before 0.9.0. Add them here. Graph UpgradeJSON_000800_000900(Graph g) { nnvm::DFSVisit(g.outputs, [](const std::shared_ptr& n) { - static auto& flist_inputs = Op::GetAttr("FListInputNames"); - if (n->inputs.size() < n->num_inputs()) { - FListInputNames fn = flist_inputs.get(n->op(), nullptr); - if (fn == nullptr) return; - - auto arg_names = fn(n->attrs); - for (size_t i = n->inputs.size(); i < n->num_inputs(); ++i) { - auto var = Symbol::CreateVariable( - DefaultVarName(n->attrs.name, arg_names[i])).outputs[0]; - var.node->attrs.dict = n->attrs.dict; - n->inputs.push_back(var); - } + static auto& flist_inputs = Op::GetAttr("FListInputNames"); + if (n->inputs.size() < n->num_inputs()) { + FListInputNames fn = flist_inputs.get(n->op(), nullptr); + if (fn == nullptr) + return; + + auto arg_names = fn(n->attrs); + for (size_t i = n->inputs.size(); i < n->num_inputs(); ++i) { + auto var = Symbol::CreateVariable(DefaultVarName(n->attrs.name, arg_names[i])).outputs[0]; + var.node->attrs.dict = n->attrs.dict; + n->inputs.push_back(var); } - }); + } + }); return g; } // Refactor initializer in v0.9.2 Graph UpgradeJSON_000903_000904(Graph g) { nnvm::DFSVisit(g.outputs, [](const std::shared_ptr& n) { - static auto& fset_attrs = + static auto& fset_attrs = Op::GetAttr("FSetInputVarAttrOnCompose"); - if (n->op() != nullptr) { - nnvm::FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr); - if (fn != nullptr) { - for (size_t i = 0; i < n->inputs.size(); ++i) { - if (n->inputs[i].node->is_variable()) { - fn(n->attrs, n->inputs[i].node, i); - } + if (n->op() != nullptr) { + nnvm::FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr); + if (fn != nullptr) { + for (size_t i = 0; i < n->inputs.size(); ++i) { + if (n->inputs[i].node->is_variable()) { + fn(n->attrs, n->inputs[i].node, i); } } } - }); + } + }); return g; } // ReduceAxisParam: int axis -> optional axis Graph UpgradeJSON_000904_000905(Graph g) { nnvm::DFSVisit(g.outputs, [](const std::shared_ptr& n) { - if (n->op() == nullptr) return; - if (n->op()->name != "argmin" && n->op()->name != "argmax") return; - if (n->attrs.dict.find("axis") == n->attrs.dict.end() || n->attrs.dict["axis"] != "-1") - return; - n->attrs.dict.erase("axis"); - n->op()->attr_parser(&(n->attrs)); - }); + if (n->op() == nullptr) + return; + if (n->op()->name != "argmin" && n->op()->name != "argmax") + return; + if (n->attrs.dict.find("axis") == n->attrs.dict.end() || n->attrs.dict["axis"] != "-1") + return; + n->attrs.dict.erase("axis"); + n->op()->attr_parser(&(n->attrs)); + }); return g; } static std::vector > > upgrader_list = { - {MXNET_VERSION, UpgradeJSON_FixParsing}, - {MXNET_MAKE_VERSION(100, 0, 0), UpgradeJSON_Parse}, - 
{MXNET_MAKE_VERSION(0, 9, 0), UpgradeJSON_000800_000900}, - {MXNET_MAKE_VERSION(0, 9, 4), UpgradeJSON_000903_000904}, - {MXNET_MAKE_VERSION(0, 9, 5), UpgradeJSON_000904_000905}, + {MXNET_VERSION, UpgradeJSON_FixParsing}, + {MXNET_MAKE_VERSION(100, 0, 0), UpgradeJSON_Parse}, + {MXNET_MAKE_VERSION(0, 9, 0), UpgradeJSON_000800_000900}, + {MXNET_MAKE_VERSION(0, 9, 4), UpgradeJSON_000903_000904}, + {MXNET_MAKE_VERSION(0, 9, 5), UpgradeJSON_000904_000905}, }; Graph LoadLegacyJSONPass(Graph g) { g.attrs["load_json_no_parse"] = std::make_shared(true); - Graph load = nnvm::ApplyPass(g, "LoadJSON"); - int version = MXNET_MAKE_VERSION(0, 8, 0); + Graph load = nnvm::ApplyPass(g, "LoadJSON"); + int version = MXNET_MAKE_VERSION(0, 8, 0); if (load.attrs.find("mxnet_version") != load.attrs.end()) { version = nnvm::get(*load.attrs["mxnet_version"]); } @@ -206,23 +209,24 @@ Graph LoadLegacyJSONPass(Graph g) { << ". May cause undefined behavior. " << "Please update MXNet if you encounter any issue"; } else if (version < MXNET_VERSION) { - LOG(INFO) << "Loading symbol saved by previous version v" - << version/10000 << "." << (version/100)%100 << "." << version%100 - << ". Attempting to upgrade..."; + LOG(INFO) << "Loading symbol saved by previous version v" << version / 10000 << "." + << (version / 100) % 100 << "." << version % 100 << ". Attempting to upgrade..."; upgrading = true; } for (auto& upgrader : upgrader_list) { - if (upgrader.first > version) load = upgrader.second(load); + if (upgrader.first > version) + load = upgrader.second(load); } - if (upgrading) LOG(INFO) << "Symbol successfully upgraded!"; + if (upgrading) + LOG(INFO) << "Symbol successfully upgraded!"; return load; } // register pass NNVM_REGISTER_PASS(LoadLegacyJSON) -.describe("Return a new Graph, loaded from src.attrs[\"json\"] and upgraded to current version") -.set_body(LoadLegacyJSONPass) -.set_change_graph(true) -.depend_graph_attr("json"); + .describe("Return a new Graph, loaded from src.attrs[\"json\"] and upgraded to current version") + .set_body(LoadLegacyJSONPass) + .set_change_graph(true) + .depend_graph_attr("json"); } // namespace mxnet diff --git a/src/nnvm/legacy_op_util.cc b/src/nnvm/legacy_op_util.cc index 851552a56016..53dd5725679d 100644 --- a/src/nnvm/legacy_op_util.cc +++ b/src/nnvm/legacy_op_util.cc @@ -34,11 +34,11 @@ namespace mxnet { namespace op { -using nnvm::Op; using nnvm::Node; -using nnvm::ObjectPtr; using nnvm::NodeAttrs; using nnvm::NodeEntry; +using nnvm::ObjectPtr; +using nnvm::Op; class ParsedOpProp { public: @@ -50,7 +50,7 @@ class ParsedOpProp { // initializer void Init(const NodeAttrs& attrs) { // For performance, do a reserve first and then copy attrs.dict - std::vector > kwargs; + std::vector> kwargs; kwargs.reserve(attrs.dict.size()); kwargs.insert(kwargs.end(), attrs.dict.begin(), attrs.dict.end()); try { @@ -66,18 +66,17 @@ class ParsedOpProp { os << ")"; throw dmlc::ParamError(os.str()); } - arguments = ptr->ListArguments(); + arguments = ptr->ListArguments(); aux_states = ptr->ListAuxiliaryStates(); - outputs = ptr->ListOutputs(); - inputs = arguments; - inputs.insert( - inputs.end(), aux_states.begin(), aux_states.end()); + outputs = ptr->ListOutputs(); + inputs = arguments; + inputs.insert(inputs.end(), aux_states.begin(), aux_states.end()); } }; class OperatorState { public: - OperatorState(Operator *opr, const OperatorProperty *prop) { + OperatorState(Operator* opr, const OperatorProperty* prop) { opr_ = opr; in_data_fwd_.resize(prop->ListArguments().size()); @@ -99,13 +98,14 @@ class 
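LoadLegacyJSONPass above reads the saved mxnet_version attribute (defaulting to 0.8.0 for models saved before the attribute existed) and then runs every registered upgrader whose version threshold is newer than the saved version, in list order. A minimal sketch of that version-gated upgrade chain on a toy model type (the struct and helper names are hypothetical):

#include <functional>
#include <string>
#include <utility>
#include <vector>

constexpr int MakeVersion(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}

// Toy stand-in for the graph being upgraded; the real pass rewrites an nnvm::Graph.
struct LegacyModel {
  int saved_version;  // 0 if the file carried no version attribute
  std::string json;
};

using Upgrader = std::function<LegacyModel(LegacyModel)>;

// Run every upgrader whose version threshold is newer than the saved version,
// in list order, as LoadLegacyJSONPass does with upgrader_list.
LegacyModel UpgradeLegacyModel(LegacyModel model,
                               const std::vector<std::pair<int, Upgrader>>& upgraders) {
  const int version =
      model.saved_version > 0 ? model.saved_version : MakeVersion(0, 8, 0);
  for (const auto& entry : upgraders) {
    if (entry.first > version) {
      model = entry.second(std::move(model));
    }
  }
  return model;
}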
OperatorState { for (size_t i = 0; i < out_data_.size(); ++i) { out_data_ptr[i] = &out_data_[i]; } - arg_data_ptr_ = prop->BackwardInputs( - out_grad_ptr, in_data_ptr, out_data_ptr); + arg_data_ptr_ = prop->BackwardInputs(out_grad_ptr, in_data_ptr, out_data_ptr); } - ~OperatorState() { delete opr_; } + ~OperatorState() { + delete opr_; + } - void Forward(const OpContext &ctx, + void Forward(const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { @@ -113,16 +113,19 @@ class OperatorState { CHECK_EQ(outputs.size(), out_data_.size()); // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones // referred by arg_data_ptr_ will be overriden - for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i]; - for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i]; + for (size_t i = 0; i < in_data_fwd_.size(); ++i) + in_data_fwd_[i] = inputs[i]; + for (size_t i = 0; i < in_data_fwd_.size(); ++i) + in_data_bwd_[i] = inputs[i]; for (size_t i = 0; i < aux_data_.size(); ++i) { aux_data_[i] = inputs[i + in_data_fwd_.size()]; } - for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i]; + for (size_t i = 0; i < out_data_.size(); ++i) + out_data_[i] = outputs[i]; opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_); } - void Backward(const OpContext &ctx, + void Backward(const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { @@ -136,12 +139,13 @@ class OperatorState { aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i]; } CHECK_EQ(outputs.size(), in_grad_.size()); - for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i]; + for (size_t i = 0; i < outputs.size(); ++i) + in_grad_[i] = outputs[i]; opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_); } private: - Operator *opr_; + Operator* opr_; // input data blobs for forward and backward // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor // performs storage fallback on a non-default input NDArray. 
The one in in_data_fwd_ is @@ -176,17 +180,15 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs) { return nnvm::get(attrs.parsed).ptr.get(); } -template +template bool OpPropInferAttr(const NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr, + std::vector* iattr, + std::vector* oattr, FInfer finfer) { auto& prop = nnvm::get(attrs.parsed); CHECK_EQ(prop.inputs.size(), iattr->size()) - << "op=" << attrs.op->name - << ", inputs.size=" << prop.inputs.size() - << ", iattr.size=" << iattr->size() - << ", arg.size=" << prop.arguments.size(); + << "op=" << attrs.op->name << ", inputs.size=" << prop.inputs.size() + << ", iattr.size=" << iattr->size() << ", arg.size=" << prop.arguments.size(); std::vector in_attr(prop.arguments.size()); std::vector aux_attr(prop.aux_states.size()); @@ -196,7 +198,8 @@ bool OpPropInferAttr(const NodeAttrs& attrs, for (size_t i = 0; i < prop.aux_states.size(); ++i) { aux_attr[i] = (*iattr)[i + prop.arguments.size()]; } - if (!finfer(prop.ptr.get(), &in_attr, oattr, &aux_attr)) return false; + if (!finfer(prop.ptr.get(), &in_attr, oattr, &aux_attr)) + return false; for (size_t i = 0; i < prop.arguments.size(); ++i) { (*iattr)[i] = in_attr[i]; @@ -208,26 +211,20 @@ bool OpPropInferAttr(const NodeAttrs& attrs, } bool OpPropInferShape(const NodeAttrs& attrs, - mxnet::ShapeVector *iattr, - mxnet::ShapeVector *oattr) { + mxnet::ShapeVector* iattr, + mxnet::ShapeVector* oattr) { auto finfer = [](const OperatorProperty* op, - mxnet::ShapeVector *in, - mxnet::ShapeVector *out, - mxnet::ShapeVector *aux) { - return op->InferShape(in, out, aux); - }; + mxnet::ShapeVector* in, + mxnet::ShapeVector* out, + mxnet::ShapeVector* aux) { return op->InferShape(in, out, aux); }; return OpPropInferAttr(attrs, iattr, oattr, finfer); } -bool OpPropInferType(const NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr) { +bool OpPropInferType(const NodeAttrs& attrs, std::vector* iattr, std::vector* oattr) { auto finfer = [](const OperatorProperty* op, - std::vector *in, - std::vector *out, - std::vector *aux) { - return op->InferType(in, out, aux); - }; + std::vector* in, + std::vector* out, + std::vector* aux) { return op->InferType(in, out, aux); }; return OpPropInferAttr(attrs, iattr, oattr, finfer); } @@ -265,7 +262,7 @@ std::vector OpPropMutateInputs(const NodeAttrs& attrs) { return ret; } -std::vector > OpPropInplaceOption(const NodeAttrs& attrs) { +std::vector> OpPropInplaceOption(const NodeAttrs& attrs) { auto& prop = nnvm::get(attrs.parsed); std::vector in_data(prop.arguments.size()); std::vector out_data(prop.outputs.size()); @@ -277,7 +274,7 @@ std::vector > OpPropInplaceOption(const NodeAttrs& attrs) { out_data[i] = static_cast(i); out_addr[i] = &out_data[i]; } - std::vector > forward_inplace; + std::vector> forward_inplace; for (auto& kv : prop.ptr->ForwardInplaceOption(in_data, out_addr)) { forward_inplace.emplace_back(kv.first, *static_cast(kv.second)); } @@ -307,30 +304,28 @@ OpStatePtr OpPropCreateLayerOp(const NodeAttrs& attrs, prop.ptr.get()); } -inline std::vector OpPropGradient( - const Op* back_op, - const ObjectPtr& ptr, - const std::vector& out_grads) { +inline std::vector OpPropGradient(const Op* back_op, + const ObjectPtr& ptr, + const std::vector& out_grads) { auto& prop = nnvm::get(ptr->attrs.parsed); std::vector out_data; out_data.reserve(prop.outputs.size()); for (size_t i = 0; i < prop.outputs.size(); ++i) out_data.emplace_back(ptr, i, 0); - std::vector in_data( - ptr->inputs.begin(), ptr->inputs.begin() + 
prop.arguments.size()); - std::vector ograd( - out_grads.begin(), out_grads.begin() + prop.ptr->NumVisibleOutputs()); + std::vector in_data(ptr->inputs.begin(), ptr->inputs.begin() + prop.arguments.size()); + std::vector ograd(out_grads.begin(), + out_grads.begin() + prop.ptr->NumVisibleOutputs()); auto inputs = prop.ptr->BackwardInputs(ograd, in_data, out_data); // add all the auxiliary data for (size_t i = 0; i < prop.aux_states.size(); ++i) { inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]); } ObjectPtr gnode = Node::Create(); - gnode->inputs = std::move(inputs); + gnode->inputs = std::move(inputs); gnode->control_deps.emplace_back(ptr); - gnode->attrs = ptr->attrs; - gnode->attrs.op = back_op; + gnode->attrs = ptr->attrs; + gnode->attrs.op = back_op; gnode->attrs.name = ptr->attrs.name + "_backward"; std::vector in_grad; in_grad.reserve(prop.arguments.size() + prop.aux_states.size()); @@ -358,12 +353,13 @@ std::vector OpBackListOutputNames(const NodeAttrs& attrs) { std::vector OpBackMutateInputs(const NodeAttrs& attrs) { auto& prop = nnvm::get(attrs.parsed); - if (prop.aux_states.size() == 0) return std::vector{}; + if (prop.aux_states.size() == 0) + return std::vector{}; std::vector out_grad_index(prop.ptr->NumVisibleOutputs()); std::vector in_data_index(prop.arguments.size()); std::vector out_data_index(prop.outputs.size()); - size_t arg_size = prop.ptr->DeclareBackwardDependency( - out_grad_index, in_data_index, out_data_index).size(); + size_t arg_size = + prop.ptr->DeclareBackwardDependency(out_grad_index, in_data_index, out_data_index).size(); std::vector ret; for (uint32_t i = 0; i < prop.aux_states.size(); ++i) { ret.push_back(static_cast(i + arg_size)); @@ -371,7 +367,7 @@ std::vector OpBackMutateInputs(const NodeAttrs& attrs) { return ret; } -std::vector > OpBackInplaceOption(const NodeAttrs& attrs) { +std::vector> OpBackInplaceOption(const NodeAttrs& attrs) { auto& prop = nnvm::get(attrs.parsed); std::vector out_grad_index(prop.ptr->NumVisibleOutputs()); std::vector in_data_index(prop.arguments.size()); @@ -388,8 +384,8 @@ std::vector > OpBackInplaceOption(const NodeAttrs& attrs) { out_data_index[i] = counter++; } - auto args_index = prop.ptr->DeclareBackwardDependency( - out_grad_index, in_data_index, out_data_index); + auto args_index = + prop.ptr->DeclareBackwardDependency(out_grad_index, in_data_index, out_data_index); std::vector args_array(counter, -1); for (size_t i = 0; i < args_index.size(); ++i) { args_array[args_index[i]] = static_cast(i); @@ -401,14 +397,14 @@ std::vector > OpBackInplaceOption(const NodeAttrs& attrs) { in_grad_ptr[i] = (void*)&in_data_index[i]; // NOLINT(*) } - auto remap_index = prop.ptr->BackwardInplaceOption( - out_grad_index, in_data_index, out_data_index, in_grad_ptr); - std::vector > remap(remap_index.size()); + auto remap_index = + prop.ptr->BackwardInplaceOption(out_grad_index, in_data_index, out_data_index, in_grad_ptr); + std::vector> remap(remap_index.size()); for (size_t i = 0; i < remap_index.size(); ++i) { if (args_array[remap_index[i].first] == -1) { LOG(FATAL) << "BackwardInplaceOption not consistent with DeclareBackwardDependency"; } - remap[i].first = args_array[remap_index[i].first]; + remap[i].first = args_array[remap_index[i].first]; remap[i].second = *static_cast(remap_index[i].second); } return remap; @@ -423,8 +419,9 @@ inline ExecType OpExecType(const NodeAttrs& attrs) { void RegisterLegacyOpProp() { for (auto reg : dmlc::Registry::List()) { Op& op = 
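The OpPropInferAttr wrapper earlier in this hunk bridges the flat NNVM attribute vectors to the legacy OperatorProperty interface: it splits the inputs into arguments followed by auxiliary states, calls the old InferShape/InferType on the pieces, and copies the results back into the flat vector. A simplified, self-contained sketch of that split-call-merge step (not the actual MXNet signature):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Split a flat per-input attribute vector (arguments followed by aux states),
// call a legacy-style infer function on the pieces, and merge the results back.
// `Attr` plays the role of a shape or a dtype flag.
template <typename Attr>
bool InferWithSplit(std::vector<Attr>* in_attrs,  // arguments then aux states
                    std::vector<Attr>* out_attrs,
                    size_t num_arguments,
                    const std::function<bool(std::vector<Attr>*,    // arguments
                                             std::vector<Attr>*,    // outputs
                                             std::vector<Attr>*)>&  // aux states
                        legacy_infer) {
  std::vector<Attr> args(in_attrs->begin(), in_attrs->begin() + num_arguments);
  std::vector<Attr> aux(in_attrs->begin() + num_arguments, in_attrs->end());
  if (!legacy_infer(&args, out_attrs, &aux)) {
    return false;
  }
  // Write the possibly-updated argument and aux attributes back into the flat vector.
  std::copy(args.begin(), args.end(), in_attrs->begin());
  std::copy(aux.begin(), aux.end(), in_attrs->begin() + num_arguments);
  return true;
}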
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name); - if (op.attr_parser != nullptr) continue; - auto creator = reg->body; + if (op.attr_parser != nullptr) + continue; + auto creator = reg->body; auto attr_parser = [creator](NodeAttrs* attrs) { if (attrs->parsed.empty()) { ParsedOpProp op; @@ -457,18 +454,17 @@ void RegisterLegacyOpProp() { // register BackwardOps std::string back_op_name = "_backward_" + reg->name; - Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name); - op.set_attr("FGradient", std::bind( - OpPropGradient, &back_op, - std::placeholders::_1, std::placeholders::_2)); + Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name); + op.set_attr( + "FGradient", + std::bind(OpPropGradient, &back_op, std::placeholders::_1, std::placeholders::_2)); back_op.set_attr_parser(attr_parser); back_op.set_num_inputs(nnvm::kVarg); back_op.set_num_outputs(OpBackNumOutputs); back_op.set_attr("FListOutputNames", OpBackListOutputNames); back_op.set_attr("FMutateInputs", OpBackMutateInputs); back_op.set_attr("FInplaceOption", OpBackInplaceOption); - back_op.set_attr( - "FResourceRequest", OpBackResourceRequest); + back_op.set_attr("FResourceRequest", OpBackResourceRequest); back_op.set_attr("TIsLayerOpBackward", true); back_op.set_attr("TIsBackward", true); back_op.set_attr("FExecType", OpExecType); @@ -479,64 +475,67 @@ void RegisterLegacyOpProp() { // no gradient operator NNVM_REGISTER_OP(_NoGradient) -.set_num_inputs(0) -.set_num_outputs(1) -.describe("Place holder for variable who cannot perform gradient"); + .set_num_inputs(0) + .set_num_outputs(1) + .describe("Place holder for variable who cannot perform gradient"); void RegisterLegacyNDFunc() { for (auto reg : dmlc::Registry::List()) { - if (reg->type_mask & kScalarArgBeforeNDArray) continue; + if (reg->type_mask & kScalarArgBeforeNDArray) + continue; Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name); - if (op.attr_parser != nullptr) continue; + if (op.attr_parser != nullptr) + continue; - CHECK_LE(reg->num_scalars + reg->num_use_vars, reg->arguments.size()) - << reg->name; + CHECK_LE(reg->num_scalars + reg->num_use_vars, reg->arguments.size()) << reg->name; auto func = reg->body; op.describe(reg->description); op.add_arguments(reg->arguments); op.set_num_inputs(reg->num_use_vars); op.set_num_outputs(reg->num_mutate_vars); - op.set_attr_parser([](NodeAttrs* attrs){}); - op.set_attr("FNDArrayFunction", [reg](const nnvm::NodeAttrs& attrs, - const std::vector& inputs, - std::vector* outputs) { - CHECK_EQ(inputs.size(), reg->num_use_vars); - CHECK_EQ(outputs->size(), reg->num_mutate_vars); - - int n_scalars = reg->num_scalars; - std::vector scalars; - scalars.reserve(n_scalars); - auto dict = attrs.dict; - for (int i = 0; i < n_scalars; ++i) { - const std::string& name = reg->arguments[i+reg->num_use_vars].name; - auto s = dict.find(name); - CHECK(s != dict.end()) << "Missing scalar param " << name; - scalars.push_back(std::stof(s->second)); - dict.erase(s); - } - - int n_params = dict.size(); - std::vector keys, vals; - keys.reserve(n_params); - vals.reserve(n_params); - for (auto& i : dict) { - keys.push_back(dmlc::BeginPtr(i.first)); - vals.push_back(dmlc::BeginPtr(i.second)); - } - std::vector input_ptrs, output_ptrs; - for (auto& i : inputs) { - input_ptrs.push_back(const_cast(&i)); - } - for (auto& i : *outputs) { - output_ptrs.push_back(&i); - } - reg->body(input_ptrs.data(), - scalars.data(), - output_ptrs.data(), - n_params, - 
const_cast(keys.data()), - const_cast(vals.data())); - }); + op.set_attr_parser([](NodeAttrs* attrs) {}); + op.set_attr("FNDArrayFunction", + [reg](const nnvm::NodeAttrs& attrs, + const std::vector& inputs, + std::vector* outputs) { + CHECK_EQ(inputs.size(), reg->num_use_vars); + CHECK_EQ(outputs->size(), reg->num_mutate_vars); + + int n_scalars = reg->num_scalars; + std::vector scalars; + scalars.reserve(n_scalars); + auto dict = attrs.dict; + for (int i = 0; i < n_scalars; ++i) { + const std::string& name = + reg->arguments[i + reg->num_use_vars].name; + auto s = dict.find(name); + CHECK(s != dict.end()) << "Missing scalar param " << name; + scalars.push_back(std::stof(s->second)); + dict.erase(s); + } + + int n_params = dict.size(); + std::vector keys, vals; + keys.reserve(n_params); + vals.reserve(n_params); + for (auto& i : dict) { + keys.push_back(dmlc::BeginPtr(i.first)); + vals.push_back(dmlc::BeginPtr(i.second)); + } + std::vector input_ptrs, output_ptrs; + for (auto& i : inputs) { + input_ptrs.push_back(const_cast(&i)); + } + for (auto& i : *outputs) { + output_ptrs.push_back(&i); + } + reg->body(input_ptrs.data(), + scalars.data(), + output_ptrs.data(), + n_params, + const_cast(keys.data()), + const_cast(vals.data())); + }); } } diff --git a/src/nnvm/low_precision_pass.cc b/src/nnvm/low_precision_pass.cc index a13344dfccf5..48575c00bccd 100644 --- a/src/nnvm/low_precision_pass.cc +++ b/src/nnvm/low_precision_pass.cc @@ -32,63 +32,65 @@ #include namespace mxnet { -using nnvm::Symbol; +using nnvm::Graph; using nnvm::Node; -using nnvm::ObjectPtr; using nnvm::NodeEntry; -using nnvm::Graph; +using nnvm::ObjectPtr; +using nnvm::Symbol; // create a node for operator : op_name with name : node_name static ObjectPtr CreateNode(std::string op_name, std::string node_name) { - ObjectPtr node = Node::Create(); + ObjectPtr node = Node::Create(); node->attrs.name = node_name; if (op_name == "nullptr") { node->attrs.op = nullptr; // ugly workaround because VariableParam is not exposed - node->attrs.parsed = nnvm::Symbol::CreateVariable(node->attrs.name) - .outputs[0] - .node->attrs.parsed; + node->attrs.parsed = + nnvm::Symbol::CreateVariable(node->attrs.name).outputs[0].node->attrs.parsed; } else { node->attrs.op = Op::Get(op_name); } return node; } -static ObjectPtr InsertNode(std::string op_name, std::string node_name, ObjectPtr current, - NodeEntry previous) { - ObjectPtr node = CreateNode(op_name, node_name); - node->inputs.emplace_back(previous); - if (current) current->inputs.emplace_back(NodeEntry{node, 0, 0}); - return node; +static ObjectPtr InsertNode(std::string op_name, + std::string node_name, + ObjectPtr current, + NodeEntry previous) { + ObjectPtr node = CreateNode(op_name, node_name); + node->inputs.emplace_back(previous); + if (current) + current->inputs.emplace_back(NodeEntry{node, 0, 0}); + return node; } // get suffix for a node entry so that it can be used for amp_cast/amp_multicast node name -static std::string GetSuffix(const nnvm::NodeEntry &node_entry, - const std::unordered_map &mirror_map) { - static const auto &flist_outputs = - nnvm::Op::GetAttr("FListOutputNames"); - std::string suffix = ""; - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); +static std::string GetSuffix(const nnvm::NodeEntry& node_entry, + const std::unordered_map& mirror_map) { + static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); + std::string suffix = ""; + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); if (mirror_node->op() != nullptr) { - 
auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr); - if (list_output_names_func != nullptr) { - std::vector names = list_output_names_func(node_entry.node->attrs); - suffix = "_" + names[node_entry.index]; - } else { - suffix = "_" + std::to_string(node_entry.index); - } + auto list_output_names_func = flist_outputs.get(node_entry.node->op(), nullptr); + if (list_output_names_func != nullptr) { + std::vector names = list_output_names_func(node_entry.node->attrs); + suffix = "_" + names[node_entry.index]; + } else { + suffix = "_" + std::to_string(node_entry.index); + } } return suffix; } // add amp_cast node between curr_node and input -static void AddCastNode(const nnvm::NodeEntry &e, const std::string &suffix, - const nnvm::NodeEntry &input, const std::string dtype, - nnvm::NodeEntryMap *mirror_entry_map, +static void AddCastNode(const nnvm::NodeEntry& e, + const std::string& suffix, + const nnvm::NodeEntry& input, + const std::string dtype, + nnvm::NodeEntryMap* mirror_entry_map, ObjectPtr curr_node) { ObjectPtr cast_node = - InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, - curr_node, input); + InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + dtype, curr_node, input); cast_node->attrs.dict["dtype"] = dtype; cast_node->op()->attr_parser(&(cast_node->attrs)); (*mirror_entry_map)[e] = NodeEntry{std::move(cast_node), 0, e.version}; @@ -96,34 +98,33 @@ static void AddCastNode(const nnvm::NodeEntry &e, const std::string &suffix, } // add amp_multicast node between curr_node and inputs -static void AddMultiCastNode(const std::vector &inputs, - const std::string &node_name, - const std::unordered_map &mirror_map, +static void AddMultiCastNode(const std::vector& inputs, + const std::string& node_name, + const std::unordered_map& mirror_map, ObjectPtr curr_node) { ObjectPtr node = - CreateNode("amp_multicast", - inputs[0].node->attrs.name + node_name + "_amp_multicast"); - for (const auto &node_entry : inputs) { + CreateNode("amp_multicast", inputs[0].node->attrs.name + node_name + "_amp_multicast"); + for (const auto& node_entry : inputs) { ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); - NodeEntry mirror_entry = NodeEntry{std::move(mirror_node), node_entry.index, - node_entry.version}; + NodeEntry mirror_entry = + NodeEntry{std::move(mirror_node), node_entry.index, node_entry.version}; node->inputs.emplace_back(mirror_entry); } node->attrs.dict["num_outputs"] = std::to_string(inputs.size()); node->op()->attr_parser(&(node->attrs)); for (uint32_t i = 0; i < inputs.size(); ++i) { - const auto &e = inputs[i]; - curr_node->inputs.emplace_back( - NodeEntry{node, static_cast(i), e.version}); + const auto& e = inputs[i]; + curr_node->inputs.emplace_back(NodeEntry{node, static_cast(i), e.version}); } return; } static bool CheckConditionalFP32( - const std::unordered_map< - std::string, std::unordered_map>> - &conditional_fp32_ops, - const std::unordered_set &excluded_syms, ObjectPtr node) { + const std::unordered_map>>& + conditional_fp32_ops, + const std::unordered_set& excluded_syms, + ObjectPtr node) { if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) || conditional_fp32_ops.count(node->op()->name) == 0) { return false; @@ -134,12 +135,12 @@ static bool CheckConditionalFP32( auto it_params = it->second; // For each param name, iterate through param values to check // if the provided param name is equal to any of the values - for (auto & it_param : it_params) { + for (auto& it_param : 
it_params) { auto param_key = node->attrs.dict.find(it_param.first); if (param_key != node->attrs.dict.end()) { auto it_param_vals = it_param.second; - if (std::find(it_param_vals.begin(), it_param_vals.end(), - param_key->second) != it_param_vals.end()) { + if (std::find(it_param_vals.begin(), it_param_vals.end(), param_key->second) != + it_param_vals.end()) { return true; } } @@ -149,19 +150,16 @@ static bool CheckConditionalFP32( } } -Graph ReducePrecision(Graph &&src) { +Graph ReducePrecision(Graph&& src) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - static auto& infertype = nnvm::Op::GetAttr("FInferType"); - const auto target_dtype_ops = - src.GetAttr>("target_dtype_ops"); - const auto fp32_ops = - src.GetAttr>("fp32_ops"); - const auto widest_dtype_ops = - src.GetAttr>("widest_dtype_ops"); - const auto target_dtype = src.GetAttr("target_dtype"); - const auto excluded_syms = src.GetAttr>("excluded_syms"); - const auto conditional_fp32_ops = src.GetAttr>>>( + static auto& infertype = nnvm::Op::GetAttr("FInferType"); + const auto target_dtype_ops = src.GetAttr>("target_dtype_ops"); + const auto fp32_ops = src.GetAttr>("fp32_ops"); + const auto widest_dtype_ops = src.GetAttr>("widest_dtype_ops"); + const auto target_dtype = src.GetAttr("target_dtype"); + const auto excluded_syms = src.GetAttr>("excluded_syms"); + const auto conditional_fp32_ops = src.GetAttr< + std::unordered_map>>>( "conditional_fp32_ops"); const auto data_name_types = src.GetAttr>("data_name_types"); const auto cast_optional_params = src.GetAttr("cast_optional_params"); @@ -177,12 +175,12 @@ Graph ReducePrecision(Graph &&src) { } // Additional data structures to share common cast node inputs among different nodes - std::unordered_map mirror_map; + std::unordered_map mirror_map; nnvm::NodeEntryMap mirror_fp32_map; nnvm::NodeEntryMap mirror_target_dtype_map; // Visit nodes in a topologically sorted order - DFSVisit(src.outputs, [&](const ObjectPtr &node) { + DFSVisit(src.outputs, [&](const ObjectPtr& node) { ObjectPtr new_node = Node::Create(*node); new_node->inputs.clear(); std::vector mutable_inputs; @@ -203,11 +201,11 @@ Graph ReducePrecision(Graph &&src) { (excluded_syms.count(node->attrs.name) > 0)) { // Add output entry to fp32_map for (size_t i = 0; i < node->num_outputs(); ++i) { - const auto out_entry = NodeEntry(node, i, 0); + const auto out_entry = NodeEntry(node, i, 0); mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0); } for (size_t i = 0; i < node->inputs.size(); ++i) { - const auto &node_entry = node->inputs[i]; + const auto& node_entry = node->inputs[i]; if (mirror_fp32_map.count(node_entry)) { new_node->inputs.emplace_back(mirror_fp32_map[node_entry]); } else if (node_entry.node->is_variable()) { @@ -215,9 +213,9 @@ Graph ReducePrecision(Graph &&src) { ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); new_node->inputs.emplace_back(mirror_node, node_entry.index, node_entry.version); } else { - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; - std::string suffix = GetSuffix(node_entry, mirror_map); + std::string suffix = GetSuffix(node_entry, mirror_map); AddCastNode(node_entry, suffix, mirror_entry, "float32", &mirror_fp32_map, new_node); } } @@ -228,7 +226,7 @@ Graph ReducePrecision(Graph &&src) { if (infertype.count(node->op())) { // Try to infertype with target dtype. 
And add output entry to mirror_target_dtype_map or // mirror_fp32_map based on infered result. - in_types[0] = target_dtype; + in_types[0] = target_dtype; bool infer_type_success = infertype[node->op()](node->attrs, &in_types, &out_types); CHECK(infer_type_success == true); for (size_t i = 0; i < node->num_outputs(); ++i) { @@ -241,7 +239,7 @@ Graph ReducePrecision(Graph &&src) { } } for (size_t i = 0; i < node->inputs.size(); ++i) { - const auto &node_entry = node->inputs[i]; + const auto& node_entry = node->inputs[i]; if (mirror_target_dtype_map.count(node_entry)) { new_node->inputs.emplace_back(mirror_target_dtype_map[node_entry]); } else if ((cast_optional_params && node_entry.node->is_variable() && @@ -255,7 +253,7 @@ Graph ReducePrecision(Graph &&src) { // 2. Mutable inputs. // 3. Even the input[0] is target dtype, some operations still require float32 for other // inputs. For example, Batchnorm. - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); const auto mirror_entry = NodeEntry(mirror_node, node_entry.index, node_entry.version); new_node->inputs.push_back(mirror_entry); if ((cast_optional_params && node_entry.node->is_variable())) { @@ -263,28 +261,30 @@ Graph ReducePrecision(Graph &&src) { mirror_target_dtype_map[node_entry] = mirror_entry; } } else { - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; - std::string suffix = GetSuffix(node_entry, mirror_map); - AddCastNode(node_entry, suffix, mirror_entry, target_dtype_str, &mirror_target_dtype_map, + std::string suffix = GetSuffix(node_entry, mirror_map); + AddCastNode(node_entry, + suffix, + mirror_entry, + target_dtype_str, + &mirror_target_dtype_map, new_node); } } - } else if (!node->is_variable() && - widest_dtype_ops.count(node->op()->name) > 0 && + } else if (!node->is_variable() && widest_dtype_ops.count(node->op()->name) > 0 && excluded_syms.count(node->attrs.name) == 0) { CHECK(node->inputs.size() > 0) - << "Please check the symbol. node name: " << node->attrs.name - << "op name " << node->op()->name << " has no inputs." + << "Please check the symbol. node name: " << node->attrs.name << "op name " + << node->op()->name << " has no inputs." << "It is likely that something went wrong during symbolic construction."; CHECK_EQ(mutable_inputs.size(), 0) << "can't handle the widest_dtype_ops with mutable inputs."; - int out_dtype = target_dtype; + int out_dtype = target_dtype; bool have_unknown_dtype = false; - for (auto & input : node->inputs) { + for (auto& input : node->inputs) { // Try to infer output dtype based on input dtype - if (!mirror_target_dtype_map.count(input) - && !mirror_fp32_map.count(input)) { + if (!mirror_target_dtype_map.count(input) && !mirror_fp32_map.count(input)) { have_unknown_dtype = true; break; } else if (mirror_fp32_map.count(input)) { @@ -293,7 +293,7 @@ Graph ReducePrecision(Graph &&src) { } if (have_unknown_dtype) { // We can't infer all dtype for inputs, so we need to add AddMultiCastNode here. - const auto &e = node->inputs[0]; + const auto& e = node->inputs[0]; std::string suffix = GetSuffix(e, mirror_map); AddMultiCastNode(node->inputs, suffix, mirror_map, new_node); } else { @@ -307,24 +307,28 @@ Graph ReducePrecision(Graph &&src) { } // we know all dtype from inputs, then we can use amp_cast instead. 
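// Illustrative sketch only (not part of the patch): for a widest-dtype op whose input dtypes are
// all known, each input entry that is not already cached in the matching mirror map gets an
// amp_cast node placed in front of it, roughly:
//
//   NodeEntry mirror = NodeEntry{mirror_map.at(in.node.get()), in.index, in.version};
//   AddCastNode(in, GetSuffix(in, mirror_map), mirror,
//               out_dtype == target_dtype ? target_dtype_str : "float32",
//               out_dtype == target_dtype ? &mirror_target_dtype_map : &mirror_fp32_map,
//               new_node);
//
// where `in` stands for node->inputs[i]; the cached entry is reused on later visits, so one cast
// node is shared by every consumer of the same entry.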
for (size_t i = 0; i < node->inputs.size(); ++i) { - const auto &node_entry = node->inputs[i]; + const auto& node_entry = node->inputs[i]; if (out_dtype == target_dtype) { if (mirror_target_dtype_map.count(node_entry)) { new_node->inputs.emplace_back(mirror_target_dtype_map[node_entry]); } else { - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; - std::string suffix = GetSuffix(node_entry, mirror_map); - AddCastNode(node_entry, suffix, mirror_entry, target_dtype_str, - &mirror_target_dtype_map, new_node); + std::string suffix = GetSuffix(node_entry, mirror_map); + AddCastNode(node_entry, + suffix, + mirror_entry, + target_dtype_str, + &mirror_target_dtype_map, + new_node); } } else { if (mirror_fp32_map.count(node_entry)) { new_node->inputs.emplace_back(mirror_fp32_map[node_entry]); } else { - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; - std::string suffix = GetSuffix(node_entry, mirror_map); + std::string suffix = GetSuffix(node_entry, mirror_map); AddCastNode(node_entry, suffix, mirror_entry, "float32", &mirror_fp32_map, new_node); } } @@ -332,11 +336,11 @@ Graph ReducePrecision(Graph &&src) { } } else if (CheckConditionalFP32(conditional_fp32_ops, excluded_syms, node)) { for (size_t i = 0; i < node->num_outputs(); ++i) { - const auto out_entry = NodeEntry(node, i, 0); + const auto out_entry = NodeEntry(node, i, 0); mirror_fp32_map[out_entry] = NodeEntry(new_node, i, 0); } for (size_t i = 0; i < node->inputs.size(); ++i) { - const auto &node_entry = node->inputs[i]; + const auto& node_entry = node->inputs[i]; if (mirror_fp32_map.count(node_entry)) { new_node->inputs.emplace_back(mirror_fp32_map[node_entry]); } else if (std::find(mutable_inputs.begin(), mutable_inputs.end(), i) != @@ -345,9 +349,9 @@ Graph ReducePrecision(Graph &&src) { ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); new_node->inputs.emplace_back(mirror_node, node_entry.index, node_entry.version); } else { - ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); + ObjectPtr mirror_node = mirror_map.at(node_entry.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, node_entry.index, node_entry.version}; - std::string suffix = GetSuffix(node_entry, mirror_map); + std::string suffix = GetSuffix(node_entry, mirror_map); AddCastNode(node_entry, suffix, mirror_entry, "float32", &mirror_fp32_map, new_node); } } @@ -361,7 +365,7 @@ Graph ReducePrecision(Graph &&src) { std::vector in_types(node->inputs.size(), -1); std::vector out_types(node->num_outputs(), -1); if (infertype.count(node->op())) { - in_types[0] = in_type; + in_types[0] = in_type; bool infer_type_success = infertype[node->op()](node->attrs, &in_types, &out_types); if (infer_type_success) { for (size_t i = 0; i < node->num_outputs(); ++i) { @@ -384,13 +388,13 @@ Graph ReducePrecision(Graph &&src) { }); std::vector outputs; - for (const auto &e : src.outputs) { + for (const auto& e : src.outputs) { if (mirror_fp32_map.count(e)) { outputs.emplace_back(mirror_fp32_map[e]); } else { - ObjectPtr mirror_node = mirror_map.at(e.node.get()); + ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; - std::string suffix = GetSuffix(e, mirror_map); + 
std::string suffix = GetSuffix(e, mirror_map); AddCastNode(e, suffix, mirror_entry, "float32", &mirror_fp32_map, nullptr); outputs.emplace_back(mirror_fp32_map[e]); } diff --git a/src/nnvm/node_op_util.h b/src/nnvm/node_op_util.h index cba6fb8286af..c3a277aa4e22 100644 --- a/src/nnvm/node_op_util.h +++ b/src/nnvm/node_op_util.h @@ -34,58 +34,54 @@ namespace util { class NodeOpGen { private: - const nnvm::ObjectPtr &dependent_node; + const nnvm::ObjectPtr& dependent_node; public: - explicit NodeOpGen(const nnvm::ObjectPtr &dependent_node) : dependent_node{dependent_node} {} + explicit NodeOpGen(const nnvm::ObjectPtr& dependent_node) : dependent_node{dependent_node} {} - nnvm::NodeEntry mul(const nnvm::NodeEntry &lhs, const nnvm::NodeEntry &rhs) { - return nnvm::NodeEntry{mxnet::op::MakeNode("elemwise_mul", - dependent_node->attrs.name + "_mul", - {lhs, rhs}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry mul(const nnvm::NodeEntry& lhs, const nnvm::NodeEntry& rhs) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "elemwise_mul", dependent_node->attrs.name + "_mul", {lhs, rhs}, nullptr, &dependent_node)}; + } - nnvm::NodeEntry mul(const nnvm::NodeEntry &x, double scalar) { - const std::unordered_map scalar_dict = - {{"scalar", std::to_string(scalar)}}; - return nnvm::NodeEntry{mxnet::op::MakeNode("_mul_scalar", - dependent_node->attrs.name + "_mul_scalar", - {x}, &scalar_dict, &dependent_node)}; - } + nnvm::NodeEntry mul(const nnvm::NodeEntry& x, double scalar) { + const std::unordered_map scalar_dict = { + {"scalar", std::to_string(scalar)}}; + return nnvm::NodeEntry{mxnet::op::MakeNode("_mul_scalar", + dependent_node->attrs.name + "_mul_scalar", + {x}, + &scalar_dict, + &dependent_node)}; + } - nnvm::NodeEntry mul(double scalar, const nnvm::NodeEntry &x) { - return NodeOpGen::mul(x, scalar); - } + nnvm::NodeEntry mul(double scalar, const nnvm::NodeEntry& x) { + return NodeOpGen::mul(x, scalar); + } - nnvm::NodeEntry div(const nnvm::NodeEntry &lhs, const nnvm::NodeEntry &rhs) { - return nnvm::NodeEntry{mxnet::op::MakeNode("elemwise_div", - dependent_node->attrs.name + "_div", - {lhs, rhs}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry div(const nnvm::NodeEntry& lhs, const nnvm::NodeEntry& rhs) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "elemwise_div", dependent_node->attrs.name + "_div", {lhs, rhs}, nullptr, &dependent_node)}; + } - nnvm::NodeEntry square(const nnvm::NodeEntry &x) { - return nnvm::NodeEntry{mxnet::op::MakeNode("square", - dependent_node->attrs.name + "_square", - {x}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry square(const nnvm::NodeEntry& x) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "square", dependent_node->attrs.name + "_square", {x}, nullptr, &dependent_node)}; + } - nnvm::NodeEntry reciprocal(const nnvm::NodeEntry &x) { - return nnvm::NodeEntry{mxnet::op::MakeNode("reciprocal", - dependent_node->attrs.name + "_reciprocal", - {x}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry reciprocal(const nnvm::NodeEntry& x) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "reciprocal", dependent_node->attrs.name + "_reciprocal", {x}, nullptr, &dependent_node)}; + } - nnvm::NodeEntry negative(const nnvm::NodeEntry &x) { - return nnvm::NodeEntry{mxnet::op::MakeNode("negative", - dependent_node->attrs.name + "_negative", - {x}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry negative(const nnvm::NodeEntry& x) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "negative", dependent_node->attrs.name + "_negative", {x}, nullptr, &dependent_node)}; + } - 
nnvm::NodeEntry ones_like(const nnvm::NodeEntry &x) { - return nnvm::NodeEntry{mxnet::op::MakeNode("ones_like", - dependent_node->attrs.name + "_oneslike", - {x}, nullptr, &dependent_node)}; - } + nnvm::NodeEntry ones_like(const nnvm::NodeEntry& x) { + return nnvm::NodeEntry{mxnet::op::MakeNode( + "ones_like", dependent_node->attrs.name + "_oneslike", {x}, nullptr, &dependent_node)}; + } }; } // namespace util diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc index f6c13e247e6a..4c9ce06d3362 100644 --- a/src/nnvm/plan_memory.cc +++ b/src/nnvm/plan_memory.cc @@ -79,19 +79,22 @@ class MXGraphAllocator { // request a free storage StorageID Request(int dev_id, int dtype, mxnet::TShape shape, uint32_t node_id) { - if (!mxnet::shape_is_known(shape)) return kBadStorageID; + if (!mxnet::shape_is_known(shape)) + return kBadStorageID; // search memory block in [size / match_range_, size * match_range_) size_t size = shape.Size() * MXGetDTypeSize(dtype); - if (match_range_ == 0) return this->Alloc(dev_id, size); + if (match_range_ == 0) + return this->Alloc(dev_id, size); auto begin = free_.lower_bound(size / match_range_); - auto mid = free_.lower_bound(size); - auto end = free_.upper_bound(size * match_range_); + auto mid = free_.lower_bound(size); + auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; - if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + StorageEntry* e = it->second; + if (e->device_id != dev_id) + continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // find a exact match, erase from map and return @@ -101,10 +104,11 @@ class MXGraphAllocator { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; - if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + StorageEntry* e = it->second; + if (e->device_id != dev_id) + continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // erase from map and return @@ -117,8 +121,9 @@ class MXGraphAllocator { // release a memory space. 
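// Rough picture of the pooling scheme implemented here: Request() looks for a recycled
// StorageEntry in the free_ multimap whose recorded max_bytes lies in
// [size / match_range_, size * match_range_) and whose device id and node color match, grows its
// max_bytes to cover the new request and erases it from free_; Release() below is the inverse,
// stamping the entry with the releasing node id and reinserting it into free_ keyed by max_bytes.
// Usage, as seen in MXAllocMemory further down:
//
//   StorageID sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
//   ...                             // may return kBadStorageID when the shape is unknown
//   allocator->Release(sid, nid);   // once the entry's reference count drops to zero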
void Release(StorageID id, uint32_t node_id) { CHECK_NE(id, kBadStorageID); - if (id == kExternalStorageID || id == kDynamicStorageID) return; - StorageEntry *e = data_[id].get(); + if (id == kExternalStorageID || id == kDynamicStorageID) + return; + StorageEntry* e = data_[id].get(); e->released_by_node = node_id; free_.insert({e->max_bytes, e}); } @@ -126,7 +131,7 @@ class MXGraphAllocator { // totoal number of bytes allocated size_t TotalAllocBytes() const { size_t total = 0; - for (auto &p : data_) { + for (auto& p : data_) { total += p->max_bytes; } return total; @@ -140,23 +145,23 @@ class MXGraphAllocator { private: // initialize the graph allocator void Init(const size_t match_range, const uint32_t num_match_color) { - match_range_ = match_range; + match_range_ = match_range; num_match_color_ = num_match_color; if (num_match_color_ > 1) { std::vector importance(idx_->num_nodes(), 0); for (uint32_t nid = 0; nid < idx_->num_nodes(); ++nid) { - if ((*idx_)[nid].source->is_variable()) continue; + if ((*idx_)[nid].source->is_variable()) + continue; importance[nid] = 1; } - num_match_color_ = pass::MXColorNodeGroup( - *idx_, importance, num_match_color_, &node_color_); + num_match_color_ = pass::MXColorNodeGroup(*idx_, importance, num_match_color_, &node_color_); } } StorageID Alloc(int dev_id, size_t size) { StorageID id = static_cast(data_.size()); std::unique_ptr ptr(new StorageEntry()); - ptr->id = id; + ptr->id = id; ptr->device_id = dev_id; ptr->max_bytes = size; data_.emplace_back(std::move(ptr)); @@ -192,24 +197,25 @@ class MXGraphAllocator { /* * Internal method to perform the memory allocation for a graph * */ -size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, - const std::pair& node_range, - StorageVector* storage_ptr, - std::vector* storage_inplace_index_ptr, - const std::vector& entry_ref_count, - MXGraphAllocator* allocator) { - static auto& finplace_option = Op::GetAttr("FInplaceOption"); +size_t MXAllocMemory(const Graph& ret, + const IndexedGraph& idx, + const std::pair& node_range, + StorageVector* storage_ptr, + std::vector* storage_inplace_index_ptr, + const std::vector& entry_ref_count, + MXGraphAllocator* allocator) { + static auto& finplace_option = Op::GetAttr("FInplaceOption"); static auto& finplace_identity = Op::GetAttr("FInplaceIdentity"); - static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); + static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); // Get reference - auto &storage = *storage_ptr; - auto &storage_inplace_index = *storage_inplace_index_ptr; + auto& storage = *storage_ptr; + auto& storage_inplace_index = *storage_inplace_index_ptr; // Get attributes from the graph const mxnet::ShapeVector& shape_vec = ret.GetAttr("shape"); - const DTypeVector& dtype_vec = ret.GetAttr("dtype"); - const DeviceVector* device_vec = nullptr; + const DTypeVector& dtype_vec = ret.GetAttr("dtype"); + const DeviceVector* device_vec = nullptr; if (ret.attrs.count("device") != 0) { device_vec = &(ret.GetAttr("device")); @@ -219,7 +225,8 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, for (uint32_t nid = node_range.first; nid < node_range.second; ++nid) { const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; + if (inode.source->is_variable()) + continue; // check inplace option if (finplace_option.count(inode.source->op()) != 0) { auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs); @@ -234,30 +241,26 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, } std::vector 
taken(inode.inputs.size(), false); for (size_t ipair = 0; ipair < inplace_pairs.size(); ++ipair) { - const auto& kv = inplace_pairs[ipair]; - uint32_t eid_out = idx.entry_id(nid, kv.second); - uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); - auto sid_out = storage[eid_out]; - auto sid_in = storage[eid_in]; + const auto& kv = inplace_pairs[ipair]; + uint32_t eid_out = idx.entry_id(nid, kv.second); + uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); + auto sid_out = storage[eid_out]; + auto sid_in = storage[eid_in]; bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && - fignore_inputs[inode.source->op()]( - inode.source->attrs).size() == inode.source->num_inputs()); + fignore_inputs[inode.source->op()](inode.source->attrs).size() == + inode.source->num_inputs()); // Identity should only be true if shape.Size() and types match - bool real_identity = identity[ipair] && - ndim_is_known(shape_vec[eid_out]) && + bool real_identity = identity[ipair] && ndim_is_known(shape_vec[eid_out]) && ndim_is_known(shape_vec[eid_in]) && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && dtype_vec[eid_out] == dtype_vec[eid_in]; - if (taken[kv.first] == false && - sid_out == MXGraphAllocator::kBadStorageID && - sid_in >= 0 && + if (taken[kv.first] == false && sid_out == MXGraphAllocator::kBadStorageID && sid_in >= 0 && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) && - entry_ref_count[eid_out] > 0 && - shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && - (dtype_vec[eid_out] == dtype_vec[eid_in] || + entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + (dtype_vec[eid_out] == dtype_vec[eid_in] || MXGetDTypeSize(dtype_vec[eid_out]) == MXGetDTypeSize(dtype_vec[eid_in]))) { // inplace optimization - taken[kv.first] = true; + taken[kv.first] = true; storage[eid_out] = sid_in; // Reuse storage for output and add ref count of output // to storage. This will get substracted later in free @@ -275,18 +278,18 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, uint32_t eid = idx.entry_id(nid, index); // only request memory for kBadStorageID if (storage[eid] == MXGraphAllocator::kBadStorageID) { - auto &eshape = shape_vec[eid]; + auto& eshape = shape_vec[eid]; size_t esize = ndim_is_known(shape_vec[eid]) ? eshape.Size() : 0; eids.insert(std::make_pair(esize, eid)); } } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { - uint32_t eid = rit->second; - auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); - if (sid >= 0) { - storage_ref_count[sid] = entry_ref_count[eid]; - } - storage[eid] = sid; + uint32_t eid = rit->second; + auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); + if (sid >= 0) { + storage_ref_count[sid] = entry_ref_count[eid]; + } + storage[eid] = sid; } // check if certain inputs is ignored. std::vector ignore_inputs; @@ -297,12 +300,14 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, // then free inputs for (size_t i = 0; i < inode.inputs.size(); ++i) { // ref counter of ignored input is already decreased. 
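// Sketch of the bookkeeping done by the input-freeing loop below: inputs listed by FIgnoreInputs
// are skipped (their counts were already dropped), entries with a negative storage id are skipped,
// and every remaining input has --storage_ref_count[sid] applied; once the count reaches zero the
// block is handed back to the allocator so later nodes can reuse it.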
- if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue; + if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) + continue; const auto& e = inode.inputs[i]; - uint32_t eid = idx.entry_id(e); - auto sid = storage[eid]; + uint32_t eid = idx.entry_id(e); + auto sid = storage[eid]; // storage_ref_count == 0 means it is taken by inplace op - if (sid < 0) continue; + if (sid < 0) + continue; // if we decrease it to zero, means we are ready to release --storage_ref_count[sid]; if (storage_ref_count[sid] == 0) { @@ -313,7 +318,7 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, // these output are not referenced by any operator. for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index); - auto sid = storage[eid]; + auto sid = storage[eid]; if (sid >= 0 && storage_ref_count[sid] == 0) { allocator->Release(sid, nid); // use -2 to indicate that the node was never touched. @@ -327,12 +332,11 @@ size_t MXAllocMemory(const Graph& ret, const IndexedGraph& idx, return num_not_allocated; } - // function to plan memory Graph MXPlanMemory(Graph ret) { // setup ref counter - const IndexedGraph& idx = ret.indexed_graph(); - static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); + const IndexedGraph& idx = ret.indexed_graph(); + static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); std::pair node_range = {0, idx.num_nodes()}; if (ret.attrs.count("node_range")) { node_range = ret.MoveCopyAttr >("node_range"); @@ -346,7 +350,8 @@ Graph MXPlanMemory(Graph ret) { ref_count.resize(idx.num_node_entries(), 0); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; - if (inode.source->is_variable()) continue; + if (inode.source->is_variable()) + continue; for (const auto& e : inode.inputs) { ++ref_count[idx.entry_id(e)]; } @@ -373,10 +378,11 @@ Graph MXPlanMemory(Graph ret) { // Search the best NNVM_EXEC_MATCH_RANGE parameter. This is turned off by default size_t min_allocated_bytes = -1; - size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); + size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("MXNET_MEMORY_OPT", 0) || - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("MXNET_MEMORY_OPT", 0) || dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) + ? 1 + : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); @@ -386,18 +392,17 @@ Graph MXPlanMemory(Graph ret) { MXGraphAllocator allocator(&idx, match_range); // number of entries that are not statically allocated. 
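// For each candidate match_range (doubling from min_match_range up to max_match_range) the loop
// re-runs the allocation below on scratch copies of the storage vectors and keeps whichever plan
// yields the smallest TotalAllocBytes(); with neither MXNET_MEMORY_OPT nor
// NNVM_AUTO_SEARCH_MATCH_RANGE set, min_match_range equals max_match_range and only one pass runs.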
- size_t storage_num_not_allocated = - MXAllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index, - ref_count, &allocator); + size_t storage_num_not_allocated = MXAllocMemory( + ret, idx, node_range, &storage_vec, &storage_inplace_index, ref_count, &allocator); size_t storage_allocated_bytes = allocator.TotalAllocBytes(); // Choose the plan which leads to minimal memory usage if (min_allocated_bytes > storage_allocated_bytes) { - ret.attrs["storage_id"] = std::make_shared(std::move(storage_vec)); + ret.attrs["storage_id"] = std::make_shared(std::move(storage_vec)); ret.attrs["storage_inplace_index"] = std::make_shared(std::move(storage_inplace_index)); - ret.attrs["storage_allocated_bytes"] = std::make_shared(storage_allocated_bytes); + ret.attrs["storage_allocated_bytes"] = std::make_shared(storage_allocated_bytes); ret.attrs["storage_num_not_allocated"] = std::make_shared(storage_num_not_allocated); - min_allocated_bytes = storage_allocated_bytes; + min_allocated_bytes = storage_allocated_bytes; } if (max_match_range == 0) { @@ -408,13 +413,13 @@ Graph MXPlanMemory(Graph ret) { } NNVM_REGISTER_PASS(MXPlanMemory) -.describe("Plan the memory allocation of each node entries.") -.set_body(MXPlanMemory) -.set_change_graph(false) -.depend_graph_attr("dtype") -.depend_graph_attr("shape") -.provide_graph_attr("storage_id") -.provide_graph_attr("storage_inplace_index"); + .describe("Plan the memory allocation of each node entries.") + .set_body(MXPlanMemory) + .set_change_graph(false) + .depend_graph_attr("dtype") + .depend_graph_attr("shape") + .provide_graph_attr("storage_id") + .provide_graph_attr("storage_inplace_index"); } // namespace } // namespace pass diff --git a/src/nnvm/tvm_bridge.cc b/src/nnvm/tvm_bridge.cc index 20700f60054e..27c42aaabbb6 100644 --- a/src/nnvm/tvm_bridge.cc +++ b/src/nnvm/tvm_bridge.cc @@ -65,30 +65,26 @@ class TVMFunctor { values_.clear(); type_codes_.clear(); values_.insert(values_.end(), args.values, args.values + args.size()); - type_codes_.insert( - type_codes_.end(), args.type_codes, args.type_codes + args.size()); + type_codes_.insert(type_codes_.end(), args.type_codes, args.type_codes + args.size()); size_t const_loc_ptr = 0; for (int i = 0; i < args.size(); ++i) { if (args.type_codes[i] == kTVMNDArrayTypeCode) { - const NDArray& nd = - static_cast(args.values[i].v_handle)[0]; + const NDArray& nd = static_cast(args.values[i].v_handle)[0]; // We cannot set the value until type_codes_[i] = kTVMDLTensorHandle; array_data_.push_back(nd); array_loc_.push_back(i); // check if there is read or mutate // by default assume we mutate the array. 
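// const_loc holds the sorted argument positions that are only read; walking it in lockstep with
// the argument list below classifies each NDArray's engine variable as a const dependency or, by
// default, a mutable one, and those lists are what Engine::PushSync later uses to order this
// asynchronous call against other work touching the same arrays.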
- if (const_loc_ptr < const_loc.size() && - i == const_loc[const_loc_ptr]) { + if (const_loc_ptr < const_loc.size() && i == const_loc[const_loc_ptr]) { const_vars->push_back(nd.var()); ++const_loc_ptr; } else { mutate_vars->push_back(nd.var()); } } else { - CHECK_LT(args.type_codes[i], kTVMDataType) - << "Only allow POD type in mxnet async call"; + CHECK_LT(args.type_codes[i], kTVMDataType) << "Only allow POD type in mxnet async call"; } } } @@ -100,8 +96,7 @@ class TVMFunctor { void Run(const RunContext& rctx) { // setup DLTensor for (size_t i = 0; i < array_loc_.size(); ++i) { - values_[array_loc_[i]].v_handle = - const_cast(&(array_data_[i].data().dltensor())); + values_[array_loc_[i]].v_handle = const_cast(&(array_data_[i].data().dltensor())); } // run the packed function TVMRetValue rv; @@ -109,7 +104,7 @@ class TVMFunctor { if (ctx().dev_type == Context::kGPU) { #if MXNET_USE_CUDA // pass stream via last argument. - void* strm = static_cast(rctx.get_stream()->stream_); + void* strm = static_cast(rctx.get_stream()->stream_); int dev_type = kDLGPU; fset_stream_(dev_type, rctx.ctx.dev_id, strm); func_.CallPacked(args, &rv); @@ -137,14 +132,13 @@ class TVMFunctor { std::vector array_loc_; }; - // Wrap a TVM function to a function that invokes MXNet's Engine // It does two things: call the engine properly // set up the NDArray to DLTensor during invocation. void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) { - PackedFunc f = wrap_args[0]; - PackedFunc fset_stream = wrap_args[1]; - int num_const = wrap_args[2]; + PackedFunc f = wrap_args[0]; + PackedFunc fset_stream = wrap_args[1]; + int num_const = wrap_args[2]; // sorted position of constant arguments std::vector const_loc; @@ -156,15 +150,13 @@ void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) { // wrapped function // This is the function that called by the user. auto wrapped = [f, fset_stream, const_loc](TVMArgs args, TVMRetValue* rv) { - std::shared_ptr func = - std::make_shared(f, fset_stream); + std::shared_ptr func = std::make_shared(f, fset_stream); std::vector const_vars, mutate_vars; func->Init(args, const_loc, &const_vars, &mutate_vars); - Engine *engine = Engine::Get(); + Engine* engine = Engine::Get(); engine->DeduplicateVarHandle(&const_vars, &mutate_vars); - engine->PushSync([func](RunContext ctx) { - func->Run(ctx); - }, func->ctx(), const_vars, mutate_vars); + engine->PushSync( + [func](RunContext ctx) { func->Run(ctx); }, func->ctx(), const_vars, mutate_vars); }; *wrap_rv = PackedFunc(wrapped); } @@ -175,8 +167,7 @@ void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) { // the WrapAsyncCall function. 
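// Sketch of the handshake (names as used in this file): TVM invokes MXTVMBridge below with a
// registration PackedFunc, MXNet answers by registering WrapAsyncCall, and the TVM side can then
// hand any compiled PackedFunc `f` plus a set-stream helper to WrapAsyncCall to obtain a wrapped
// function whose calls go through Engine::PushSync with the const/mutate variables collected by
// TVMFunctor::Init, so the kernel runs asynchronously but correctly ordered on MXNet's engine.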
extern "C" MXNET_DLL int MXTVMBridge(TVMFunctionHandle pregister) { using tvm::runtime::PackedFunc; - const PackedFunc& fregister = - *static_cast(pregister); + const PackedFunc& fregister = *static_cast(pregister); fregister("WrapAsyncCall", PackedFunc(mxnet::WrapAsyncCall)); return 0; } diff --git a/src/operator/all_finite-inl.h b/src/operator/all_finite-inl.h index d646d5b19336..68ecf0648b0a 100644 --- a/src/operator/all_finite-inl.h +++ b/src/operator/all_finite-inl.h @@ -44,12 +44,10 @@ namespace mxnet { namespace op { -struct AllFiniteParam: public dmlc::Parameter { +struct AllFiniteParam : public dmlc::Parameter { bool init_output; DMLC_DECLARE_PARAMETER(AllFiniteParam) { - DMLC_DECLARE_FIELD(init_output) - .set_default(true) - .describe("Initialize output to 1."); + DMLC_DECLARE_FIELD(init_output).set_default(true).describe("Initialize output to 1."); } }; @@ -57,32 +55,28 @@ struct MultiAllFiniteParam : public dmlc::Parameter { int num_arrays; bool init_output; DMLC_DECLARE_PARAMETER(MultiAllFiniteParam) { - DMLC_DECLARE_FIELD(num_arrays) - .set_default(1) - .describe("Number of arrays."); - DMLC_DECLARE_FIELD(init_output) - .set_default(true) - .describe("Initialize output to 1."); + DMLC_DECLARE_FIELD(num_arrays).set_default(1).describe("Number of arrays."); + DMLC_DECLARE_FIELD(init_output).set_default(true).describe("Initialize output to 1."); } }; -template +template struct MultiAllFiniteKernelParam { static const int N = 200; int count; size_t max_size; size_t sizes[N]; - DType *arrays[N]; + DType* arrays[N]; }; -template +template MultiAllFiniteKernelParam FillMultiAllFiniteParam(const MultiAllFiniteParam& op_param, - const OpContext &ctx, - const std::vector &inputs) { + const OpContext& ctx, + const std::vector& inputs) { MultiAllFiniteKernelParam param; using namespace mxnet_op; Stream* s = ctx.get_stream(); - param.count = op_param.num_arrays; + param.count = op_param.num_arrays; param.max_size = 0; for (int i = 0; i < param.count; ++i) { param.sizes[i] = inputs[i].shape_.Size(); diff --git a/src/operator/all_finite.cc b/src/operator/all_finite.cc index 8ef36711e888..3bb4336f74e7 100644 --- a/src/operator/all_finite.cc +++ b/src/operator/all_finite.cc @@ -19,7 +19,7 @@ /*! * Copyright (c) 2019 by Contributors - * \file all_finite.cc + * \file all_finite.cc * \brief operator for checking if a group of array is all finite * \author Clement Fuji Tsang */ @@ -29,11 +29,11 @@ namespace mxnet { namespace op { -template +template struct AllFiniteCPUKernel { MSHADOW_XINLINE static void Map(int i, const DType* in, float* out) { bool is_finite = true; - is_finite = std::isfinite(static_cast(in[i])) ? is_finite : false; + is_finite = std::isfinite(static_cast(in[i])) ? 
is_finite : false; if (!is_finite) { out[0] = 0.; } @@ -41,28 +41,27 @@ struct AllFiniteCPUKernel { }; inline void AllFiniteCPU(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); const AllFiniteParam& op_param = nnvm::get(attrs.parsed); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); if (op_param.init_output) { out = 1.; } MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { Tensor in = inputs[0].FlatTo2D(s); - const int n = in.shape_.Size(); + const int n = in.shape_.Size(); Kernel, cpu>::Launch(s, n, in.dptr_, out.dptr_); }); } -template +template struct MultiAllFiniteCPUKernel { - MSHADOW_XINLINE static void Map(int i, const MultiAllFiniteKernelParam param, - float* out) { + MSHADOW_XINLINE static void Map(int i, const MultiAllFiniteKernelParam param, float* out) { bool is_finite = true; for (int index = 0; index < param.count; ++index) { if ((size_t)i < param.sizes[index]) { @@ -76,95 +75,95 @@ struct MultiAllFiniteCPUKernel { }; inline void MultiAllFiniteCPU(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); const MultiAllFiniteParam& op_param = nnvm::get(attrs.parsed); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); if (op_param.init_output) out = 1.; MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { MultiAllFiniteKernelParam param = - FillMultiAllFiniteParam(op_param, ctx, inputs); - Kernel, cpu>::Launch(s, param.max_size, - param, out.dptr_); + FillMultiAllFiniteParam(op_param, ctx, inputs); + Kernel, cpu>::Launch(s, param.max_size, param, out.dptr_); }); } DMLC_REGISTER_PARAMETER(AllFiniteParam); NNVM_REGISTER_OP(all_finite) -.add_alias("_npi_all_finite") -.describe(R"code(Check if all the float numbers in the array are finite (used for AMP) + .add_alias("_npi_all_finite") + .describe(R"code(Check if all the float numbers in the array are finite (used for AMP) )code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs){ - (*out_attrs)[0] = TShape({1}); - return true; - }) -.set_attr("FInferType", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs){ - (*out_attrs)[0] = mshadow::kFloat32; - return true; - }) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - std::vector ret; - ret.emplace_back("data"); - return ret; - }) -.add_argument("data", "NDArray", "Array") -.add_arguments(AllFiniteParam::__FIELDS__()) -.set_attr("FCompute", AllFiniteCPU); + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + (*out_attrs)[0] = TShape({1}); + return true; + }) + .set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + (*out_attrs)[0] = mshadow::kFloat32; + return true; + }) + 
.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + std::vector ret; + ret.emplace_back("data"); + return ret; + }) + .add_argument("data", "NDArray", "Array") + .add_arguments(AllFiniteParam::__FIELDS__()) + .set_attr("FCompute", AllFiniteCPU); DMLC_REGISTER_PARAMETER(MultiAllFiniteParam); NNVM_REGISTER_OP(multi_all_finite) -.add_alias("_npi_multi_all_finite") -.describe(R"code(Check if all the float numbers in all the arrays are finite (used for AMP) + .add_alias("_npi_multi_all_finite") + .describe(R"code(Check if all the float numbers in all the arrays are finite (used for AMP) )code" ADD_FILELINE) -.set_num_inputs([](const nnvm::NodeAttrs& attrs) { - const MultiAllFiniteParam& param = dmlc::get(attrs.parsed); - return static_cast(param.num_arrays); - }) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - (*out_attrs)[0] = TShape({1}); - return true; - }) -.set_attr("FInferType", - [](const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - (*out_attrs)[0] = mshadow::kFloat32; - return true; - }) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - uint32_t num_args = dmlc::get(attrs.parsed).num_arrays; - std::vector ret; - for (uint32_t i = 0; i < num_args; ++i) { - ret.push_back(std::string("array_") + std::to_string(i)); - } - return ret; - }) -.add_argument("data", "NDArray-or-Symbol[]", "Arrays") -.add_arguments(MultiAllFiniteParam::__FIELDS__()) -.set_attr("FCompute", MultiAllFiniteCPU); + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiAllFiniteParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_arrays); + }) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + (*out_attrs)[0] = TShape({1}); + return true; + }) + .set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + (*out_attrs)[0] = mshadow::kFloat32; + return true; + }) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = + dmlc::get(attrs.parsed).num_arrays; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("array_") + std::to_string(i)); + } + return ret; + }) + .add_argument("data", "NDArray-or-Symbol[]", "Arrays") + .add_arguments(MultiAllFiniteParam::__FIELDS__()) + .set_attr("FCompute", MultiAllFiniteCPU); } // namespace op } // namespace mxnet diff --git a/src/operator/all_finite.cu b/src/operator/all_finite.cu index 69ba35f0844a..ad48c93b1b26 100644 --- a/src/operator/all_finite.cu +++ b/src/operator/all_finite.cu @@ -42,21 +42,22 @@ __global__ void AllFiniteGPUKernel(const int size, const DType* in, float* out) } inline void AllFiniteGPU(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); const AllFiniteParam& op_param = nnvm::get(attrs.parsed); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); if (op_param.init_output) out = 1.; MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { Tensor in = inputs[0].FlatTo2D(s); - const int n = in.shape_.Size(); + const int n 
= in.shape_.Size(); AllFiniteGPUKernel<<::GetStream(s)>>>(n, in.dptr_, out.dptr_); MSHADOW_CUDA_POST_KERNEL_CHECK(AllFiniteGPUKernel); }); @@ -77,31 +78,30 @@ __global__ void MultiAllFiniteGPUKernel(const MultiAllFiniteKernelParam p } inline void MultiAllFiniteGPU(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mxnet_op; - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); const MultiAllFiniteParam& op_param = nnvm::get(attrs.parsed); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); if (op_param.init_output) out = 1.; MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { MultiAllFiniteKernelParam param = - FillMultiAllFiniteParam(op_param, ctx, inputs); + FillMultiAllFiniteParam(op_param, ctx, inputs); MultiAllFiniteGPUKernel<<::GetStream(s)>>>(param, out.dptr_); MSHADOW_CUDA_POST_KERNEL_CHECK(MultiAllFiniteGPUKernel); }); } -NNVM_REGISTER_OP(all_finite) -.set_attr("FCompute", AllFiniteGPU); +NNVM_REGISTER_OP(all_finite).set_attr("FCompute", AllFiniteGPU); -NNVM_REGISTER_OP(multi_all_finite) -.set_attr("FCompute", MultiAllFiniteGPU); +NNVM_REGISTER_OP(multi_all_finite).set_attr("FCompute", MultiAllFiniteGPU); } // namespace op } // namespace mxnet diff --git a/src/operator/amp_graph_pass.cc b/src/operator/amp_graph_pass.cc index 5b397e746c03..66d1546915fd 100644 --- a/src/operator/amp_graph_pass.cc +++ b/src/operator/amp_graph_pass.cc @@ -30,10 +30,9 @@ namespace mxnet { namespace op { +using nnvm::Graph; using nnvm::Node; using nnvm::ObjectPtr; -using nnvm::Graph; - /* * \brief Remove amp_cast and amp_multicast and replug the fp32 weights @@ -52,10 +51,7 @@ Graph RemoveAmpCast(Graph&& g) { return std::move(g); } -NNVM_REGISTER_PASS(RemoveAmpCast) -.describe("") -.set_body(RemoveAmpCast) -.set_change_graph(true); +NNVM_REGISTER_PASS(RemoveAmpCast).describe("").set_body(RemoveAmpCast).set_change_graph(true); } // namespace op } // namespace mxnet diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h index cec27faa1968..585bc8ca66ec 100644 --- a/src/operator/bilinear_sampler-inl.h +++ b/src/operator/bilinear_sampler-inl.h @@ -22,7 +22,7 @@ * \file bilinear_Sampler-inl.h * \brief * \author Xu Dong -*/ + */ #ifndef MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ #define MXNET_OPERATOR_BILINEAR_SAMPLER_INL_H_ @@ -39,62 +39,63 @@ namespace mxnet { namespace op { namespace bs { -enum BilinearSamplerOpInputs {kData, kGrid}; -enum BilinearSamplerOpOutputs {kOut, kTmp}; -} +enum BilinearSamplerOpInputs { kData, kGrid }; +enum BilinearSamplerOpOutputs { kOut, kTmp }; +} // namespace bs struct BilinearSamplerParam : public dmlc::Parameter { dmlc::optional cudnn_off; DMLC_DECLARE_PARAMETER(BilinearSamplerParam) { - DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional()) + DMLC_DECLARE_FIELD(cudnn_off) + .set_default(dmlc::optional()) .describe("whether to turn cudnn off"); } }; -template +template class BilinearSamplerOp : public Operator { public: explicit BilinearSamplerOp(BilinearSamplerParam p) { this->param_ = p; } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + virtual void Forward(const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data, + 
const std::vector& aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[bs::kOut], kWriteTo); CHECK_EQ(in_data.size(), 2U); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); Tensor data = in_data[bs::kData].get(s); Tensor grid = in_data[bs::kGrid].get(s); - Tensor out = out_data[bs::kOut].get(s); + Tensor out = out_data[bs::kOut].get(s); BilinearSamplerForward(out, data, grid); } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + virtual void Backward(const OpContext& ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 2U); CHECK_NE(req[bs::kData], kWriteInplace); CHECK_NE(req[bs::kGrid], kWriteInplace); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); - Tensor data = in_data[bs::kData].get(s); - Tensor grid = in_data[bs::kGrid].get(s); + Tensor data = in_data[bs::kData].get(s); + Tensor grid = in_data[bs::kGrid].get(s); Tensor gdata = in_grad[bs::kData].get(s); Tensor ggrid = in_grad[bs::kGrid].get(s); - Tensor grad = out_grad[bs::kOut].get(s); + Tensor grad = out_grad[bs::kOut].get(s); if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { return; } else { @@ -112,7 +113,7 @@ class BilinearSamplerOp : public Operator { BilinearSamplerParam param_; }; // class BilinearSamplerOp -template +template Operator* CreateOp(BilinearSamplerParam param, int dtype); #if DMLC_USE_CXX11 @@ -142,27 +143,25 @@ class BilinearSamplerProp : public OperatorProperty { return param_.__DICT__(); } - bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + bool InferShape(mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape, + mxnet::ShapeVector* aux_shape) const override { using namespace mshadow; CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]"; - const mxnet::TShape &dshape = (*in_shape)[bs::kData]; - const mxnet::TShape &lshape = (*in_shape)[bs::kGrid]; - if (!shape_is_known(dshape)) return false; - CHECK_EQ(dshape.ndim(), 4U) \ - << "input data should be 4D in batch-num_filter-y-x"; - if (!shape_is_known(lshape)) return false; - CHECK_EQ(lshape.ndim(), 4U) \ - << "Sampler grid should be 4D in batch-2-y-x"; + const mxnet::TShape& dshape = (*in_shape)[bs::kData]; + const mxnet::TShape& lshape = (*in_shape)[bs::kGrid]; + if (!shape_is_known(dshape)) + return false; + CHECK_EQ(dshape.ndim(), 4U) << "input data should be 4D in batch-num_filter-y-x"; + if (!shape_is_known(lshape)) + return false; + CHECK_EQ(lshape.ndim(), 4U) << "Sampler grid should be 4D in batch-2-y-x"; CHECK_EQ(dshape[0], lshape[0]); CHECK_EQ(lshape[1], 2U) << "incorrect grid shape[1], should be 2"; // target height - CHECK_GT(lshape[2], 0U) \ - << "incorrect grid_shape: " << lshape[2]; + CHECK_GT(lshape[2], 0U) << "incorrect grid_shape: " << lshape[2]; // target width - CHECK_GT(lshape[3], 0U) \ - << "incorrect grid_shape: " << lshape[3]; + CHECK_GT(lshape[3], 0U) << "incorrect grid_shape: " << lshape[3]; out_shape->clear(); // output_shape : (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]) out_shape->push_back(dshape); @@ -172,37 +171,38 @@ class BilinearSamplerProp : public OperatorProperty { 
return true; } - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { - int dtype = -1; - for (int type : *in_type) { - if (dtype == -1) { - dtype = type; - } else { - CHECK(type == dtype || - type == -1) << - "Non-uniform data type in BilinearSampler"; - } - } + bool InferType(std::vector* in_type, + std::vector* out_type, + std::vector* aux_type) const override { + int dtype = -1; + for (int type : *in_type) { if (dtype == -1) { - LOG(FATAL) << "Not enough information to infer type in BilinearSampler."; - return false; + dtype = type; + } else { + CHECK(type == dtype || type == -1) << "Non-uniform data type in BilinearSampler"; } - size_t nin = this->ListArguments().size(); - in_type->clear(); - for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype); - size_t naux = this->ListAuxiliaryStates().size(); - aux_type->clear(); - for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype); - size_t nout = this->ListOutputs().size(); - out_type->clear(); - for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype); - return true; } + if (dtype == -1) { + LOG(FATAL) << "Not enough information to infer type in BilinearSampler."; + return false; + } + size_t nin = this->ListArguments().size(); + in_type->clear(); + for (size_t i = 0; i < nin; ++i) + in_type->push_back(dtype); + size_t naux = this->ListAuxiliaryStates().size(); + aux_type->clear(); + for (size_t i = 0; i < naux; ++i) + aux_type->push_back(dtype); + size_t nout = this->ListOutputs().size(); + out_type->clear(); + for (size_t i = 0; i < nout; ++i) + out_type->push_back(dtype); + return true; + } OperatorProperty* Copy() const override { - auto ptr = new BilinearSamplerProp(); + auto ptr = new BilinearSamplerProp(); ptr->param_ = param_; return ptr; } @@ -211,14 +211,10 @@ class BilinearSamplerProp : public OperatorProperty { return "BilinearSampler"; } - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {out_grad[bs::kOut], - in_data[bs::kData], - out_data[bs::kTmp], - in_data[bs::kGrid]}; + std::vector DeclareBackwardDependency(const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data) const override { + return {out_grad[bs::kOut], in_data[bs::kData], out_data[bs::kTmp], in_data[bs::kGrid]}; } Operator* CreateOperator(Context ctx) const override { @@ -226,12 +222,13 @@ class BilinearSamplerProp : public OperatorProperty { return nullptr; } - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + Operator* CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const override; private: BilinearSamplerParam param_; -}; // class BilinearSamplerProp +}; // class BilinearSamplerProp #endif // DMLC_USE_CXX11 } // namespace op } // namespace mxnet diff --git a/src/operator/bilinear_sampler.cc b/src/operator/bilinear_sampler.cc index d8ce15486d8c..7148db5bc35c 100644 --- a/src/operator/bilinear_sampler.cc +++ b/src/operator/bilinear_sampler.cc @@ -22,173 +22,175 @@ * \file bilinear_sampler.cc * \brief * \author Xu Dong -*/ + */ #include #include "./bilinear_sampler-inl.h" namespace mshadow { -template +template bool between(DType value, int lowerBound, int upperBound) { return (value >= lowerBound && value <= upperBound); } -template -inline void BilinearSamplerForward(const Tensor &output, - const Tensor &input, - const Tensor &grid_src) { - DType *out = 
output.dptr_; - const DType *data = input.dptr_; - const DType *grid = grid_src.dptr_; +template +inline void BilinearSamplerForward(const Tensor& output, + const Tensor& input, + const Tensor& grid_src) { + DType* out = output.dptr_; + const DType* data = input.dptr_; + const DType* grid = grid_src.dptr_; int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3); int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); for (index_t n = 0; n < static_cast(o_n); ++n) { for (index_t c = 0; c < static_cast(o_c); ++c) { for (index_t h = 0; h < static_cast(o_h); ++h) { for (index_t w = 0; w < static_cast(o_w); ++w) { - index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; + index_t out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; index_t grid_index = n * o_h * o_w * 2 + h * o_w + w; - DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; - DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; + DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; + DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; // NOLINTNEXTLINE int top_left_y = static_cast(std::floor(y_real)); // NOLINTNEXTLINE - int top_left_x = static_cast(std::floor(x_real)); - DType top_left_y_w = 1.0 - (y_real - top_left_y); - DType top_left_x_w = 1.0 - (x_real - top_left_x); - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + - top_left_y * i_w + top_left_x; - DType top_left_v = 0; - DType top_right_v = 0; + int top_left_x = static_cast(std::floor(x_real)); + DType top_left_y_w = 1.0 - (y_real - top_left_y); + DType top_left_x_w = 1.0 - (x_real - top_left_x); + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; + DType top_left_v = 0; + DType top_right_v = 0; DType bottom_left_v = 0; DType bottom_right_v = 0; - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + if (between(top_left_x, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) top_left_v = *(data + data_index); - if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) top_right_v = *(data + data_index + 1); - if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) bottom_left_v = *(data + data_index + i_w); - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) bottom_right_v = *(data + data_index + i_w + 1); - *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w + - top_right_v * top_left_y_w * (1.0 - top_left_x_w) + - bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w + - bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); + *(out + out_index) = top_left_v * top_left_y_w * top_left_x_w + + top_right_v * top_left_y_w * (1.0 - top_left_x_w) + + bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w + + bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); } } } } } -template -inline void BilinearSamplerBackward(const Tensor &gdata, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid, +template +inline void BilinearSamplerBackward(const Tensor& gdata, + const Tensor& ggrid, + const Tensor& output_grad, + const Tensor& input_data, + const Tensor& grid, const mxnet::OpReqType data_req, const mxnet::OpReqType grid_req) { - DType *g_input = 
gdata.dptr_; - DType *grad_grid = ggrid.dptr_; - const DType *grid_src = grid.dptr_; - const DType *grad = output_grad.dptr_; - const DType *data = input_data.dptr_; - int o_n = output_grad.size(0), o_c = output_grad.size(1), - o_h = output_grad.size(2), o_w = output_grad.size(3); + DType* g_input = gdata.dptr_; + DType* grad_grid = ggrid.dptr_; + const DType* grid_src = grid.dptr_; + const DType* grad = output_grad.dptr_; + const DType* data = input_data.dptr_; + int o_n = output_grad.size(0), o_c = output_grad.size(1), o_h = output_grad.size(2), + o_w = output_grad.size(3); int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3); for (index_t n = 0; n < static_cast(o_n); ++n) { - for (index_t h = 0; h < static_cast(o_h); ++h) { - for (index_t w = 0; w < static_cast(o_w); ++w) { - DType top_left_y_gw = 0.0; - DType top_left_x_gw = 0.0; - index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w; - DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; - DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; - // NOLINTNEXTLINE - int top_left_y = static_cast(std::floor(y_real)); - // NOLINTNEXTLINE - int top_left_x = static_cast(std::floor(x_real)); - DType top_left_y_w = 1.0 - (y_real - top_left_y); - DType top_left_x_w = 1.0 - (x_real - top_left_x); - for (index_t c = 0; c < static_cast(o_c); ++c) { - index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; - // calc 4 vertex value in input data - DType top_left_v = 0; - DType top_right_v = 0; - DType bottom_left_v = 0; - DType bottom_right_v = 0; - // calc input grad - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - if (data_req != mxnet::kNullOp) { - *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; - } - top_left_v = *(data + data_index); + for (index_t h = 0; h < static_cast(o_h); ++h) { + for (index_t w = 0; w < static_cast(o_w); ++w) { + DType top_left_y_gw = 0.0; + DType top_left_x_gw = 0.0; + index_t grid_src_index = n * o_h * o_w * 2 + h * o_w + w; + DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; + DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; + // NOLINTNEXTLINE + int top_left_y = static_cast(std::floor(y_real)); + // NOLINTNEXTLINE + int top_left_x = static_cast(std::floor(x_real)); + DType top_left_y_w = 1.0 - (y_real - top_left_y); + DType top_left_x_w = 1.0 - (x_real - top_left_x); + for (index_t c = 0; c < static_cast(o_c); ++c) { + index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; + // calc 4 vertex value in input data + DType top_left_v = 0; + DType top_right_v = 0; + DType bottom_left_v = 0; + DType bottom_right_v = 0; + // calc input grad + if (between(top_left_x, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) { + if (data_req != mxnet::kNullOp) { + *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - if (data_req != mxnet::kNullOp) { - *(g_input + data_index + 1) += + top_left_v = *(data + data_index); + } + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) { + if (data_req != mxnet::kNullOp) { + *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w); - } - top_right_v = *(data + 
data_index + 1); } - if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - if (data_req != mxnet::kNullOp) { - *(g_input + data_index+ i_w) += + top_right_v = *(data + data_index + 1); + } + if (between(top_left_x, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) { + if (data_req != mxnet::kNullOp) { + *(g_input + data_index + i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w; - } - bottom_left_v = *(data + data_index + i_w); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - if (data_req != mxnet::kNullOp) { - *(g_input + data_index+ i_w + 1) += + bottom_left_v = *(data + data_index + i_w); + } + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) { + if (data_req != mxnet::kNullOp) { + *(g_input + data_index + i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); - } - bottom_right_v = *(data + data_index + i_w + 1); } - // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src - top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_x_w); - top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_y_w); - } - if (grid_req != mxnet::kNullOp) { - // calc grad of grid - *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; - *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + bottom_right_v = *(data + data_index + i_w + 1); } + // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src + top_left_y_gw -= + *(grad + grad_index) * + (top_right_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_x_w); + top_left_x_gw -= + *(grad + grad_index) * + (bottom_left_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_y_w); + } + if (grid_req != mxnet::kNullOp) { + // calc grad of grid + *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; + *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; } } } } +} } // namespace mshadow namespace mxnet { namespace op { -template<> +template <> Operator* CreateOp(BilinearSamplerParam param, int dtype) { - Operator *op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new BilinearSamplerOp(param); - }) + Operator* op = nullptr; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new BilinearSamplerOp(param); }) return op; } -Operator *BilinearSamplerProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const { +Operator* BilinearSamplerProp::CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const { DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } DMLC_REGISTER_PARAMETER(BilinearSamplerParam); MXNET_REGISTER_OP_PROPERTY(BilinearSampler, BilinearSamplerProp) -.add_argument("data", "NDArray-or-Symbol", "Input data to the BilinearsamplerOp.") -.add_argument("grid", "NDArray-or-Symbol", "Input grid to the BilinearsamplerOp." - "grid has two channels: x_src, y_src") -.add_arguments(BilinearSamplerParam::__FIELDS__()) -.describe(R"code(Applies bilinear sampling to input feature map. + .add_argument("data", "NDArray-or-Symbol", "Input data to the BilinearsamplerOp.") + .add_argument("grid", + "NDArray-or-Symbol", + "Input grid to the BilinearsamplerOp." 
+ "grid has two channels: x_src, y_src") + .add_arguments(BilinearSamplerParam::__FIELDS__()) + .describe(R"code(Applies bilinear sampling to input feature map. Bilinear Sampling is the key of [NIPS2015] \"Spatial Transformer Networks\". The usage of the operator is very similar to remap function in OpenCV, except that the operator has the backward pass. diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index dae14a645fd8..5b66a86f22ef 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -22,7 +22,7 @@ * \file bilinear_sampler.cu * \brief * \author Xu Dong -*/ + */ #include "./bilinear_sampler-inl.h" #include @@ -33,120 +33,130 @@ namespace mshadow { namespace cuda { -template +template __device__ bool between(DType value, int lowerBound, int upperBound) { return (value >= lowerBound && value <= upperBound); } -template -__global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h, - const int i_w, const DType* data, - const DType* grid, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* out) { +template +__global__ void BilinearSamplerForwardKernel(const int i_c, + const int i_h, + const int i_w, + const DType* data, + const DType* grid, + const int o_n, + const int o_c, + const int o_h, + const int o_w, + DType* out) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_c * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) { // (n, c, h, w) is the element in out - int w = index % o_w; - int h = (index / o_w) % o_h; - int c = (index / o_w / o_h) % o_c; - int n = index / o_w / o_h / o_c; - int out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - int grid_index = n * o_h * o_w * 2 + h * o_w + w; - DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; - DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; - int top_left_y = static_cast(floor(y_real)); - int top_left_x = static_cast(floor(x_real)); - DType top_left_y_w = 1.0 - (y_real - top_left_y); - DType top_left_x_w = 1.0 - (x_real - top_left_x); - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; - DType top_left_v = 0; - DType top_right_v = 0; - DType bottom_left_v = 0; + int w = index % o_w; + int h = (index / o_w) % o_h; + int c = (index / o_w / o_h) % o_c; + int n = index / o_w / o_h / o_c; + int out_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; + int grid_index = n * o_h * o_w * 2 + h * o_w + w; + DType y_real = (*(grid + grid_index + o_h * o_w) + 1) * (i_h - 1) / 2; + DType x_real = (*(grid + grid_index) + 1) * (i_w - 1) / 2; + int top_left_y = static_cast(floor(y_real)); + int top_left_x = static_cast(floor(x_real)); + DType top_left_y_w = 1.0 - (y_real - top_left_y); + DType top_left_x_w = 1.0 - (x_real - top_left_x); + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; + DType top_left_v = 0; + DType top_right_v = 0; + DType bottom_left_v = 0; DType bottom_right_v = 0; - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + if (between(top_left_x, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) top_left_v = *(data + data_index); - if (between(top_left_x + 1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) top_right_v = *(data + data_index + 1); - if (between(top_left_x, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x, 0, i_w - 1) && 
between(top_left_y + 1, 0, i_h - 1)) bottom_left_v = *(data + data_index + i_w); - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y + 1, 0, i_h-1)) + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) bottom_right_v = *(data + data_index + i_w + 1); - *(out+out_index) = top_left_v * top_left_y_w * top_left_x_w + - top_right_v * top_left_y_w * (1.0 - top_left_x_w) + - bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w + - bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); + *(out + out_index) = top_left_v * top_left_y_w * top_left_x_w + + top_right_v * top_left_y_w * (1.0 - top_left_x_w) + + bottom_left_v * (1.0 - top_left_y_w) * top_left_x_w + + bottom_right_v * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); } } -template -__global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, - const int i_w, const DType* grad, - const DType* data, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* g_input, +template +__global__ void BilinearSamplerBackwardKernel(const int i_c, + const int i_h, + const int i_w, + const DType* grad, + const DType* data, + const int o_n, + const int o_c, + const int o_h, + const int o_w, + DType* g_input, const DType* grid_src, DType* grad_grid) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) { // (n, c, h, w) is the element in grad - int w = index % o_w; - int h = (index / o_w) % o_h; - int n = index / o_w / o_h; + int w = index % o_w; + int h = (index / o_w) % o_h; + int n = index / o_w / o_h; DType top_left_y_gw = 0.0; DType top_left_x_gw = 0.0; - int grid_src_index = n * o_h * o_w * 2 + h * o_w + w; - DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; - DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; + int grid_src_index = n * o_h * o_w * 2 + h * o_w + w; + DType y_real = (*(grid_src + grid_src_index + o_h * o_w) + 1) * (i_h - 1) / 2; + DType x_real = (*(grid_src + grid_src_index) + 1) * (i_w - 1) / 2; - int top_left_y = static_cast(floor(y_real)); - int top_left_x = static_cast(floor(x_real)); + int top_left_y = static_cast(floor(y_real)); + int top_left_x = static_cast(floor(x_real)); DType top_left_y_w = 1.0 - (y_real - top_left_y); DType top_left_x_w = 1.0 - (x_real - top_left_x); for (int c = 0; c < o_c; ++c) { int grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; // calc 4 vertex value in input data - DType top_left_v = 0; - DType top_right_v = 0; - DType bottom_left_v = 0; + DType top_left_v = 0; + DType top_right_v = 0; + DType bottom_left_v = 0; DType bottom_right_v = 0; // calc input grad - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + if (between(top_left_x, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) { if (Req1 != mxnet::kNullOp) { atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w); } top_left_v = *(data + data_index); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y, 0, i_h - 1)) { if (Req1 != mxnet::kNullOp) { atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w)); } top_right_v = *(data + data_index + 1); } - if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + if (between(top_left_x, 0, i_w - 1) && 
between(top_left_y + 1, 0, i_h - 1)) { if (Req1 != mxnet::kNullOp) { - atomicAdd(&g_input[data_index+ i_w], + atomicAdd(&g_input[data_index + i_w], *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w); } bottom_left_v = *(data + data_index + i_w); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + if (between(top_left_x + 1, 0, i_w - 1) && between(top_left_y + 1, 0, i_h - 1)) { if (Req1 != mxnet::kNullOp) { - atomicAdd(&g_input[data_index+ i_w + 1], + atomicAdd(&g_input[data_index + i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w)); } bottom_right_v = *(data + data_index + i_w + 1); } // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src - top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_x_w); - top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_y_w); + top_left_y_gw -= *(grad + grad_index) * + (top_right_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_x_w); + top_left_x_gw -= *(grad + grad_index) * + (bottom_left_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_y_w); } if (Req2 != mxnet::kNullOp) { // calc grad of grid @@ -157,54 +167,54 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, } } // namespace cuda -template -inline void BilinearSamplerForward(const Tensor &output, - const Tensor &input, - const Tensor &grid_src) { - DType *out = output.dptr_; - const DType *data = input.dptr_; - const DType *grid = grid_src.dptr_; - int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3); - int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); - using namespace cuda; - const int max_block = (output.shape_.Size() + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; - const int grid_dim_x = (max_block > kMaxGridDim) ? kMaxGridDim : max_block; - const int grid_dim_y = +template +inline void BilinearSamplerForward(const Tensor& output, + const Tensor& input, + const Tensor& grid_src) { + DType* out = output.dptr_; + const DType* data = input.dptr_; + const DType* grid = grid_src.dptr_; + int o_n = output.size(0), o_c = output.size(1), o_h = output.size(2), o_w = output.size(3); + int i_c = input.size(1), i_h = input.size(2), i_w = input.size(3); + using namespace cuda; + const int max_block = (output.shape_.Size() + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + const int grid_dim_x = (max_block > kMaxGridDim) ? kMaxGridDim : max_block; + const int grid_dim_y = (max_block > kMaxGridDim) ? 
(max_block + kMaxGridDim - 1) / kMaxGridDim : 1; - dim3 num_blocks(grid_dim_x, grid_dim_y); - dim3 threads_per_block(kMaxThreadsPerBlock); - CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler forward"); - cudaStream_t stream = Stream::GetStream(output.stream_); - cuda::BilinearSamplerForwardKernel << > >( + dim3 num_blocks(grid_dim_x, grid_dim_y); + dim3 threads_per_block(kMaxThreadsPerBlock); + CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler forward"); + cudaStream_t stream = Stream::GetStream(output.stream_); + cuda::BilinearSamplerForwardKernel<<>>( i_c, i_h, i_w, data, grid, o_n, o_c, o_h, o_w, out); - // post kernel check - cudaError err = cudaGetLastError(); - CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err); + // post kernel check + cudaError err = cudaGetLastError(); + CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err); } -template -inline void BilinearSamplerBackward(const Tensor &input_grad, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid, +template +inline void BilinearSamplerBackward(const Tensor& input_grad, + const Tensor& ggrid, + const Tensor& output_grad, + const Tensor& input_data, + const Tensor& grid, const mxnet::OpReqType data_req, const mxnet::OpReqType grid_req) { using namespace mxnet; - DType *g_input = input_grad.dptr_; - DType *grad_grid = ggrid.dptr_; - const DType *grid_src = grid.dptr_; - const DType *grad = output_grad.dptr_; - const DType *data = input_data.dptr_; - int o_n = output_grad.size(0), o_c = output_grad.size(1), - o_h = output_grad.size(2), o_w = output_grad.size(3); + DType* g_input = input_grad.dptr_; + DType* grad_grid = ggrid.dptr_; + const DType* grid_src = grid.dptr_; + const DType* grad = output_grad.dptr_; + const DType* data = input_data.dptr_; + int o_n = output_grad.size(0), o_c = output_grad.size(1), o_h = output_grad.size(2), + o_w = output_grad.size(3); int i_c = input_data.size(1), i_h = input_data.size(2), i_w = input_data.size(3); using namespace cuda; - const int max_block = (output_grad.shape_.Size() / o_c + kMaxThreadsPerBlock - 1) - / kMaxThreadsPerBlock; + const int max_block = + (output_grad.shape_.Size() / o_c + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; const int grid_dim_x = (max_block > kMaxGridDim) ? kMaxGridDim : max_block; const int grid_dim_y = - (max_block > kMaxGridDim) ? (max_block + kMaxGridDim - 1) / kMaxGridDim : 1; + (max_block > kMaxGridDim) ? 
(max_block + kMaxGridDim - 1) / kMaxGridDim : 1; dim3 num_blocks(grid_dim_x, grid_dim_y); dim3 threads_per_block(kMaxThreadsPerBlock); CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward"); @@ -212,8 +222,8 @@ inline void BilinearSamplerBackward(const Tensor &input_grad, MXNET_REQ_TYPE_SWITCH(data_req, Req1, { MXNET_REQ_TYPE_SWITCH(grid_req, Req2, { cuda::BilinearSamplerBackwardKernel - <<>>( - i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); + <<>>( + i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); }); }); // post kernel check @@ -225,9 +235,9 @@ inline void BilinearSamplerBackward(const Tensor &input_grad, namespace mxnet { namespace op { -template<> +template <> Operator* CreateOp(BilinearSamplerParam param, int dtype) { - Operator *op = nullptr; + Operator* op = nullptr; #if MXNET_USE_CUDNN == 1 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { if (param.cudnn_off.has_value() && param.cudnn_off.value()) { @@ -237,9 +247,7 @@ Operator* CreateOp(BilinearSamplerParam param, int dtype) { } }) #else - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new BilinearSamplerOp(param); - }) + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new BilinearSamplerOp(param); }) #endif // MXNET_USE_CUDNN return op; } diff --git a/src/operator/c_lapack_api.cc b/src/operator/c_lapack_api.cc index 2d2109d05427..982e0353220e 100644 --- a/src/operator/c_lapack_api.cc +++ b/src/operator/c_lapack_api.cc @@ -22,151 +22,199 @@ #if (MXNET_USE_LAPACK && (MSHADOW_USE_MKL || MXNET_USE_LAPACKE_INTERFACE)) #elif MXNET_USE_LAPACK #else - // use pragma message instead of warning - #pragma message("Warning: lapack usage not enabled, linalg-operators will not be available." \ - " Ensure that lapack library is installed and build with USE_LAPACK=1 to get lapack" \ - " functionalities.") - - // Define compilable stubs. - #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \ +// use pragma message instead of warning +#pragma message( \ + "Warning: lapack usage not enabled, linalg-operators will not be available." \ + " Ensure that lapack library is installed and build with USE_LAPACK=1 to get lapack" \ + " functionalities.") + +// Define compilable stubs. +#define MXNET_LAPACK_CWRAPPER1(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER2(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork) { \ - LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ +#define MXNET_LAPACK_CWRAPPER2(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork) { \ + LOG(FATAL) << "MXNet build without lapack. 
Function " << #func << " is not available."; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER3(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \ - int lda, dtype *w, dtype *work, int lwork, \ - int *iwork, int liwork) { \ +#define MXNET_LAPACK_CWRAPPER3(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + char uplo, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* w, \ + dtype* work, \ + int lwork, \ + int* iwork, \ + int liwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER4(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ - dtype *a, int lda, int *ipiv) { \ +#define MXNET_LAPACK_CWRAPPER4(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, int lda, int* ipiv) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER5(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \ - int *ipiv, dtype *work, int lwork) { \ +#define MXNET_LAPACK_CWRAPPER5(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int n, dtype* a, int lda, int* ipiv, dtype* work, int lwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER6(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \ - int ldut, dtype* s, dtype* v, int ldv, \ - dtype* work, int lwork) { \ +#define MXNET_LAPACK_CWRAPPER6(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* ut, \ + int ldut, \ + dtype* s, \ + dtype* v, \ + int ldv, \ + dtype* work, \ + int lwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER7(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ - int lda, int *ipiv, dtype *b, int ldb) { \ +#define MXNET_LAPACK_CWRAPPER7(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_order, int n, int nrhs, dtype* a, int lda, int* ipiv, dtype* b, int ldb) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER8(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \ - int n, dtype *a, int lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, int ldvl, dtype *vr, int ldvr, \ - dtype *work, int lwork) { \ +#define MXNET_LAPACK_CWRAPPER8(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + char jobvl, \ + char jobvr, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + int ldvl, \ + dtype* vr, \ + int ldvr, \ + dtype* work, \ + int lwork) { \ LOG(FATAL) << "MXNet build without lapack. 
Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype *s, \ - dtype *u, int ldu, \ - dtype *vt, int ldvt, \ - dtype *work, int lwork, int *iwork) { \ +#define MXNET_LAPACK_CWRAPPER9(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* s, \ + dtype* u, \ + int ldu, \ + dtype* vt, \ + int ldvt, \ + dtype* work, \ + int lwork, \ + int* iwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER10(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork) { \ - LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ +#define MXNET_LAPACK_CWRAPPER10(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork) { \ + LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER11(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, int nrhs, \ - dtype *a, int lda, dtype *b, int ldb, \ - dtype *s, dtype rcond, int *rank, \ - dtype *work, int lwork, int *iwork) { \ +#define MXNET_LAPACK_CWRAPPER11(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + int nrhs, \ + dtype* a, \ + int lda, \ + dtype* b, \ + int ldb, \ + dtype* s, \ + dtype rcond, \ + int* rank, \ + dtype* work, \ + int lwork, \ + int* iwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_CWRAPPER12(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, int k, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork) { \ +#define MXNET_LAPACK_CWRAPPER12(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + int k, \ + dtype* a, \ + int lda, \ + dtype* tau, \ + dtype* work, \ + int lwork) { \ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - #define MXNET_LAPACK_UNAVAILABLE(func) \ - int mxnet_lapack_##func(...) { \ +#define MXNET_LAPACK_UNAVAILABLE(func) \ + int mxnet_lapack_##func(...) { \ LOG(FATAL) << "MXNet build without lapack. 
Function " << #func << " is not available."; \ - return 1; \ + return 1; \ } - MXNET_LAPACK_CWRAPPER1(spotrf, float) - MXNET_LAPACK_CWRAPPER1(dpotrf, double) - MXNET_LAPACK_CWRAPPER1(spotri, float) - MXNET_LAPACK_CWRAPPER1(dpotri, double) +MXNET_LAPACK_CWRAPPER1(spotrf, float) // NOLINT +MXNET_LAPACK_CWRAPPER1(dpotrf, double) // NOLINT +MXNET_LAPACK_CWRAPPER1(spotri, float) // NOLINT +MXNET_LAPACK_CWRAPPER1(dpotri, double) // NOLINT - MXNET_LAPACK_UNAVAILABLE(sposv) - MXNET_LAPACK_UNAVAILABLE(dposv) +MXNET_LAPACK_UNAVAILABLE(sposv) // NOLINT +MXNET_LAPACK_UNAVAILABLE(dposv) // NOLINT - MXNET_LAPACK_CWRAPPER2(sgelqf, float) - MXNET_LAPACK_CWRAPPER2(dgelqf, double) - MXNET_LAPACK_CWRAPPER2(sorglq, float) - MXNET_LAPACK_CWRAPPER2(dorglq, double) +MXNET_LAPACK_CWRAPPER2(sgelqf, float) // NOLINT +MXNET_LAPACK_CWRAPPER2(dgelqf, double) // NOLINT +MXNET_LAPACK_CWRAPPER2(sorglq, float) // NOLINT +MXNET_LAPACK_CWRAPPER2(dorglq, double) // NOLINT - MXNET_LAPACK_CWRAPPER3(ssyevd, float) - MXNET_LAPACK_CWRAPPER3(dsyevd, double) +MXNET_LAPACK_CWRAPPER3(ssyevd, float) // NOLINT +MXNET_LAPACK_CWRAPPER3(dsyevd, double) // NOLINT - MXNET_LAPACK_CWRAPPER4(sgetrf, float) - MXNET_LAPACK_CWRAPPER4(dgetrf, double) +MXNET_LAPACK_CWRAPPER4(sgetrf, float) // NOLINT +MXNET_LAPACK_CWRAPPER4(dgetrf, double) // NOLINT - MXNET_LAPACK_CWRAPPER5(sgetri, float) - MXNET_LAPACK_CWRAPPER5(dgetri, double) +MXNET_LAPACK_CWRAPPER5(sgetri, float) // NOLINT +MXNET_LAPACK_CWRAPPER5(dgetri, double) // NOLINT - MXNET_LAPACK_CWRAPPER6(sgesvd, float) - MXNET_LAPACK_CWRAPPER6(dgesvd, double) +MXNET_LAPACK_CWRAPPER6(sgesvd, float) // NOLINT +MXNET_LAPACK_CWRAPPER6(dgesvd, double) // NOLINT - MXNET_LAPACK_CWRAPPER7(sgesv, float) - MXNET_LAPACK_CWRAPPER7(dgesv, double) +MXNET_LAPACK_CWRAPPER7(sgesv, float) // NOLINT +MXNET_LAPACK_CWRAPPER7(dgesv, double) // NOLINT - MXNET_LAPACK_CWRAPPER8(sgeev, float) - MXNET_LAPACK_CWRAPPER8(dgeev, double) +MXNET_LAPACK_CWRAPPER8(sgeev, float) // NOLINT +MXNET_LAPACK_CWRAPPER8(dgeev, double) // NOLINT - MXNET_LAPACK_CWRAPPER9(sgesdd, float) - MXNET_LAPACK_CWRAPPER9(dgesdd, double) +MXNET_LAPACK_CWRAPPER9(sgesdd, float) // NOLINT +MXNET_LAPACK_CWRAPPER9(dgesdd, double) // NOLINT - MXNET_LAPACK_CWRAPPER10(sgeqrf, float) - MXNET_LAPACK_CWRAPPER10(dgeqrf, double) +MXNET_LAPACK_CWRAPPER10(sgeqrf, float) // NOLINT +MXNET_LAPACK_CWRAPPER10(dgeqrf, double) // NOLINT - MXNET_LAPACK_CWRAPPER11(sgelsd, float) - MXNET_LAPACK_CWRAPPER11(dgelsd, double) +MXNET_LAPACK_CWRAPPER11(sgelsd, float) // NOLINT +MXNET_LAPACK_CWRAPPER11(dgelsd, double) // NOLINT - MXNET_LAPACK_CWRAPPER12(sorgqr, float) - MXNET_LAPACK_CWRAPPER12(dorgqr, double) +MXNET_LAPACK_CWRAPPER12(sorgqr, float) // NOLINT +MXNET_LAPACK_CWRAPPER12(dorgqr, double) // NOLINT #endif // MSHADOW_USE_MKL == 0 diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h index 71161f108ce6..ee750013d80f 100644 --- a/src/operator/c_lapack_api.h +++ b/src/operator/c_lapack_api.h @@ -75,152 +75,229 @@ using namespace mshadow; extern "C" { - // Fortran signatures - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ - int func##_(char* uplo, int* n, dtype* a, int* lda, int *info); - #else - #define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ - void func##_(char* uplo, int* n, dtype* a, int* lda, int *info); - #endif - - MXNET_LAPACK_FSIGNATURE1(spotrf, float) - MXNET_LAPACK_FSIGNATURE1(dpotrf, double) - MXNET_LAPACK_FSIGNATURE1(spotri, float) - MXNET_LAPACK_FSIGNATURE1(dpotri, double) - - void dposv_(char *uplo, int *n, int *nrhs, - double *a, 
int *lda, double *b, int *ldb, int *info); - - void sposv_(char *uplo, int *n, int *nrhs, - float *a, int *lda, float *b, int *ldb, int *info); - - // Note: GELQF in row-major (MXNet) becomes GEQRF in column-major (LAPACK). - // Also, m and n are flipped, compared to the row-major version - #define MXNET_LAPACK_FSIG_GEQRF(func, dtype) \ - void func##_(int *m, int *n, dtype *a, int *lda, dtype *tau, dtype *work, \ - int *lwork, int *info); - - MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float) - MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double) - - // Note: ORGLQ in row-major (MXNet) becomes ORGQR in column-major (LAPACK) - // Also, m and n are flipped, compared to the row-major version - #define MXNET_LAPACK_FSIG_ORGQR(func, dtype) \ - void func##_(int *m, int *n, int *k, dtype *a, int *lda, dtype *tau, \ - dtype *work, int *lwork, int *info); - - MXNET_LAPACK_FSIG_ORGQR(sorgqr, float) - MXNET_LAPACK_FSIG_ORGQR(dorgqr, double) - - #define MXNET_LAPACK_FSIG_SYEVD(func, dtype) \ - void func##_(char *jobz, char *uplo, int *n, dtype *a, int *lda, dtype *w, \ - dtype *work, int *lwork, int *iwork, int *liwork, int *info); - - MXNET_LAPACK_FSIG_SYEVD(ssyevd, float) - MXNET_LAPACK_FSIG_SYEVD(dsyevd, double) - - #define MXNET_LAPACK_FSIG_GESVD(func, dtype) \ - void func##_(char *jobu, char *jobvt, int *m, int *n, dtype *a, int *lda, dtype *s, \ - dtype* u, int *ldu, dtype *vt, int *ldvt, dtype *work, int *lwork, int *info); - - MXNET_LAPACK_FSIG_GESVD(sgesvd, float) - MXNET_LAPACK_FSIG_GESVD(dgesvd, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \ - int func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info); - #else - #define MXNET_LAPACK_FSIG_GETRF(func, dtype) \ - void func##_(int *m, int *n, dtype *a, int *lda, int *ipiv, int *info); - #endif - - MXNET_LAPACK_FSIG_GETRF(sgetrf, float) - MXNET_LAPACK_FSIG_GETRF(dgetrf, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \ - int func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \ - int *lwork, int *info); - #else - #define MXNET_LAPACK_FSIG_GETRI(func, dtype) \ - void func##_(int *n, dtype *a, int *lda, int *ipiv, dtype *work, \ - int *lwork, int *info); - #endif - - MXNET_LAPACK_FSIG_GETRI(sgetri, float) - MXNET_LAPACK_FSIG_GETRI(dgetri, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GESV(func, dtype) \ - int func##_(int *n, int *nrhs, dtype *a, int *lda, \ - int *ipiv, dtype *b, int *ldb, int *info); - #else - #define MXNET_LAPACK_FSIG_GESV(func, dtype) \ - void func##_(int *n, int *nrhs, dtype *a, int *lda, \ - int *ipiv, dtype *b, int *ldb, int *info); - #endif - - MXNET_LAPACK_FSIG_GESV(sgesv, float) - MXNET_LAPACK_FSIG_GESV(dgesv, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ - int func##_(char *jobz, int *m, int *n, dtype *a, int *lda, dtype *s, \ - dtype *u, int *ldu, \ - dtype *vt, int *ldvt, \ - dtype *work, int *lwork, int *iwork, int *info); - #else - #define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ - void func##_(char *jobz, int *m, int *n, dtype *a, int *lda, dtype *s, \ - dtype *u, int *ldu, \ - dtype *vt, int *ldvt, \ - dtype *work, int *lwork, int *iwork, int *info); - #endif - - MXNET_LAPACK_FSIG_GESDD(sgesdd, float) - MXNET_LAPACK_FSIG_GESDD(dgesdd, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GEEV(func, dtype) \ - int func##_(char *jobvl, char *jobvr, int *n, dtype *a, int *lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, int *ldvl, dtype *vr, int *ldvr, \ - dtype *work, int *lwork, int *info); - #else - #define 
MXNET_LAPACK_FSIG_GEEV(func, dtype) \ - void func##_(char *jobvl, char *jobvr, int *n, dtype *a, int *lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, int *ldvl, dtype *vr, int *ldvr, \ - dtype *work, int *lwork, int *info); - #endif - - MXNET_LAPACK_FSIG_GEEV(sgeev, float) - MXNET_LAPACK_FSIG_GEEV(dgeev, double) - - #ifdef __ANDROID__ - #define MXNET_LAPACK_FSIG_GELSD(func, dtype) \ - int func##_(int *m, int *n, int *nrhs, dtype *a, int *lda, \ - dtype *b, int *ldb, dtype *s, dtype *rcond, int *rank, \ - dtype *work, int *lwork, int *iwork, int *info); - #else - #define MXNET_LAPACK_FSIG_GELSD(func, dtype) \ - void func##_(int *m, int *n, int *nrhs, dtype *a, int *lda, \ - dtype *b, int *ldb, dtype *s, dtype *rcond, int *rank, \ - dtype *work, int *lwork, int *iwork, int *info); - #endif - - MXNET_LAPACK_FSIG_GELSD(sgelsd, float) - MXNET_LAPACK_FSIG_GELSD(dgelsd, double) +// Fortran signatures +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ + int func##_(char* uplo, int* n, dtype* a, int* lda, int* info); +#else +#define MXNET_LAPACK_FSIGNATURE1(func, dtype) \ + void func##_(char* uplo, int* n, dtype* a, int* lda, int* info); +#endif + +MXNET_LAPACK_FSIGNATURE1(spotrf, float) +MXNET_LAPACK_FSIGNATURE1(dpotrf, double) +MXNET_LAPACK_FSIGNATURE1(spotri, float) +MXNET_LAPACK_FSIGNATURE1(dpotri, double) + +void dposv_(char* uplo, int* n, int* nrhs, double* a, int* lda, double* b, int* ldb, int* info); + +void sposv_(char* uplo, int* n, int* nrhs, float* a, int* lda, float* b, int* ldb, int* info); + +// Note: GELQF in row-major (MXNet) becomes GEQRF in column-major (LAPACK). +// Also, m and n are flipped, compared to the row-major version +#define MXNET_LAPACK_FSIG_GEQRF(func, dtype) \ + void func##_(int* m, int* n, dtype* a, int* lda, dtype* tau, dtype* work, int* lwork, int* info); + +MXNET_LAPACK_FSIG_GEQRF(sgeqrf, float) +MXNET_LAPACK_FSIG_GEQRF(dgeqrf, double) + +// Note: ORGLQ in row-major (MXNet) becomes ORGQR in column-major (LAPACK) +// Also, m and n are flipped, compared to the row-major version +#define MXNET_LAPACK_FSIG_ORGQR(func, dtype) \ + void func##_( \ + int* m, int* n, int* k, dtype* a, int* lda, dtype* tau, dtype* work, int* lwork, int* info); + +MXNET_LAPACK_FSIG_ORGQR(sorgqr, float) +MXNET_LAPACK_FSIG_ORGQR(dorgqr, double) + +#define MXNET_LAPACK_FSIG_SYEVD(func, dtype) \ + void func##_(char* jobz, \ + char* uplo, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* w, \ + dtype* work, \ + int* lwork, \ + int* iwork, \ + int* liwork, \ + int* info); + +MXNET_LAPACK_FSIG_SYEVD(ssyevd, float) +MXNET_LAPACK_FSIG_SYEVD(dsyevd, double) + +#define MXNET_LAPACK_FSIG_GESVD(func, dtype) \ + void func##_(char* jobu, \ + char* jobvt, \ + int* m, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* s, \ + dtype* u, \ + int* ldu, \ + dtype* vt, \ + int* ldvt, \ + dtype* work, \ + int* lwork, \ + int* info); + +MXNET_LAPACK_FSIG_GESVD(sgesvd, float) +MXNET_LAPACK_FSIG_GESVD(dgesvd, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \ + int func##_(int* m, int* n, dtype* a, int* lda, int* ipiv, int* info); +#else +#define MXNET_LAPACK_FSIG_GETRF(func, dtype) \ + void func##_(int* m, int* n, dtype* a, int* lda, int* ipiv, int* info); +#endif + +MXNET_LAPACK_FSIG_GETRF(sgetrf, float) +MXNET_LAPACK_FSIG_GETRF(dgetrf, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \ + int func##_(int* n, dtype* a, int* lda, int* ipiv, dtype* work, int* lwork, int* info); +#else +#define MXNET_LAPACK_FSIG_GETRI(func, dtype) \ + void 
func##_(int* n, dtype* a, int* lda, int* ipiv, dtype* work, int* lwork, int* info); +#endif + +MXNET_LAPACK_FSIG_GETRI(sgetri, float) +MXNET_LAPACK_FSIG_GETRI(dgetri, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GESV(func, dtype) \ + int func##_(int* n, int* nrhs, dtype* a, int* lda, int* ipiv, dtype* b, int* ldb, int* info); +#else +#define MXNET_LAPACK_FSIG_GESV(func, dtype) \ + void func##_(int* n, int* nrhs, dtype* a, int* lda, int* ipiv, dtype* b, int* ldb, int* info); +#endif + +MXNET_LAPACK_FSIG_GESV(sgesv, float) +MXNET_LAPACK_FSIG_GESV(dgesv, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ + int func##_(char* jobz, \ + int* m, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* s, \ + dtype* u, \ + int* ldu, \ + dtype* vt, \ + int* ldvt, \ + dtype* work, \ + int* lwork, \ + int* iwork, \ + int* info); +#else +#define MXNET_LAPACK_FSIG_GESDD(func, dtype) \ + void func##_(char* jobz, \ + int* m, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* s, \ + dtype* u, \ + int* ldu, \ + dtype* vt, \ + int* ldvt, \ + dtype* work, \ + int* lwork, \ + int* iwork, \ + int* info); +#endif + +MXNET_LAPACK_FSIG_GESDD(sgesdd, float) +MXNET_LAPACK_FSIG_GESDD(dgesdd, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GEEV(func, dtype) \ + int func##_(char* jobvl, \ + char* jobvr, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + int* ldvl, \ + dtype* vr, \ + int* ldvr, \ + dtype* work, \ + int* lwork, \ + int* info); +#else +#define MXNET_LAPACK_FSIG_GEEV(func, dtype) \ + void func##_(char* jobvl, \ + char* jobvr, \ + int* n, \ + dtype* a, \ + int* lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + int* ldvl, \ + dtype* vr, \ + int* ldvr, \ + dtype* work, \ + int* lwork, \ + int* info); +#endif + +MXNET_LAPACK_FSIG_GEEV(sgeev, float) +MXNET_LAPACK_FSIG_GEEV(dgeev, double) + +#ifdef __ANDROID__ +#define MXNET_LAPACK_FSIG_GELSD(func, dtype) \ + int func##_(int* m, \ + int* n, \ + int* nrhs, \ + dtype* a, \ + int* lda, \ + dtype* b, \ + int* ldb, \ + dtype* s, \ + dtype* rcond, \ + int* rank, \ + dtype* work, \ + int* lwork, \ + int* iwork, \ + int* info); +#else +#define MXNET_LAPACK_FSIG_GELSD(func, dtype) \ + void func##_(int* m, \ + int* n, \ + int* nrhs, \ + dtype* a, \ + int* lda, \ + dtype* b, \ + int* ldb, \ + dtype* s, \ + dtype* rcond, \ + int* rank, \ + dtype* work, \ + int* lwork, \ + int* iwork, \ + int* info); +#endif + +MXNET_LAPACK_FSIG_GELSD(sgelsd, float) +MXNET_LAPACK_FSIG_GELSD(dgelsd, double) } #endif // MSHADOW_USE_MKL == 0 - #define CHECK_LAPACK_UPLO(a) \ CHECK(a == 'U' || a == 'L') << "neither L nor U specified as triangle in lapack call"; -inline char loup(char uplo, bool invert) { return invert ? (uplo == 'U' ? 'L' : 'U') : uplo; } +inline char loup(char uplo, bool invert) { + return invert ? (uplo == 'U' ? 'L' : 'U') : uplo; +} /*! * \brief Transpose matrix data in memory @@ -236,594 +313,772 @@ inline char loup(char uplo, bool invert) { return invert ? (uplo == 'U' ? 
'L' : * \param lda leading dimension of a */ template -inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) { +inline void flip(int m, int n, DType* b, int ldb, DType* a, int lda) { for (int i = 0; i < m; ++i) for (int j = 0; j < n; ++j) b[j * ldb + i] = a[i * lda + j]; } - #if (MXNET_USE_LAPACK && (MSHADOW_USE_MKL || MXNET_USE_LAPACKE_INTERFACE)) - #if MSHADOW_USE_MKL - #include - #else - #if MXNET_USE_ILP64_LAPACKE - #define lapack_int int64_t - #endif - // prevent multiple inclusion of complex.h in lapacke.h - #define lapack_complex_float float _Complex - #define lapack_complex_double double _Complex - #include - #endif - - #define MXNET_LAPACK_ROW_MAJOR LAPACK_ROW_MAJOR - #define MXNET_LAPACK_COL_MAJOR LAPACK_COL_MAJOR - - // These function have already matching signature. - #define MXNET_LAPACK_spotrf LAPACKE_spotrf - #define MXNET_LAPACK_dpotrf LAPACKE_dpotrf - #define MXNET_LAPACK_spotri LAPACKE_spotri - #define MXNET_LAPACK_dpotri LAPACKE_dpotri - #define mxnet_lapack_sposv LAPACKE_sposv - #define mxnet_lapack_dposv LAPACKE_dposv - #define MXNET_LAPACK_dgesv LAPACKE_dgesv - #define MXNET_LAPACK_sgesv LAPACKE_sgesv - - // The following functions differ in signature from the - // MXNET_LAPACK-signature and have to be wrapped. - #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - dtype *a, lapack_index_t lda, dtype *tau, \ - dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ +#if MSHADOW_USE_MKL +#include +#else +#if MXNET_USE_ILP64_LAPACKE +#define lapack_int int64_t +#endif +// prevent multiple inclusion of complex.h in lapacke.h +#define lapack_complex_float float _Complex +#define lapack_complex_double double _Complex +#include +#endif + +#define MXNET_LAPACK_ROW_MAJOR LAPACK_ROW_MAJOR +#define MXNET_LAPACK_COL_MAJOR LAPACK_COL_MAJOR + +// These function have already matching signature. +#define MXNET_LAPACK_spotrf LAPACKE_spotrf +#define MXNET_LAPACK_dpotrf LAPACKE_dpotrf +#define MXNET_LAPACK_spotri LAPACKE_spotri +#define MXNET_LAPACK_dpotri LAPACKE_dpotri +#define mxnet_lapack_sposv LAPACKE_sposv +#define mxnet_lapack_dposv LAPACKE_dposv +#define MXNET_LAPACK_dgesv LAPACKE_dgesv +#define MXNET_LAPACK_sgesv LAPACKE_sgesv + +// The following functions differ in signature from the +// MXNET_LAPACK-signature and have to be wrapped. 
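[Editorial aside, not part of the original patch.] All of the LAPACKE wrappers that follow keep LAPACK's two-pass workspace convention: a call with lwork == -1 only reports a workspace size (the wrappers write it into *work and return 0), and a second call with a real lwork performs the computation. A minimal caller-side sketch of that pattern, using the sgelqf wrapper defined below; buffer sizes and variable names are illustrative only, and it assumes "operator/c_lapack_api.h", <vector>, and <algorithm> are included and MXNet was built with USE_LAPACK=1:

  // Query pass: lwork == -1 asks the wrapper for a workspace size only.
  int m = 2, n = 3;
  std::vector<float> a(m * n), tau(std::min(m, n));
  float wkopt = 0.0f;
  MXNET_LAPACK_sgelqf(MXNET_LAPACK_ROW_MAJOR, m, n, a.data(), n, tau.data(), &wkopt, -1);
  // Compute pass: hand the wrapper at least the reported amount of workspace.
  std::vector<float> work(std::max<std::size_t>(1, static_cast<std::size_t>(wkopt)));
  MXNET_LAPACK_sgelqf(MXNET_LAPACK_ROW_MAJOR, m, n, a.data(), n, tau.data(),
                      work.data(), static_cast<int>(work.size()));

(In the LAPACKE-backed build shown here the compute pass lets LAPACKE allocate its own workspace internally, so the buffer is only there to satisfy the common signature.)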
+#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* tau, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ return LAPACKE_##prefix##gelqf(matrix_layout, m, n, a, lda, tau); \ - } \ - *work = 0; \ - return 0; \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GELQF(s, float) - MXNET_LAPACK_CWRAP_GELQF(d, double) - - #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - dtype *a, lapack_index_t lda, dtype *tau, \ - dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ +MXNET_LAPACK_CWRAP_GELQF(s, float) +MXNET_LAPACK_CWRAP_GELQF(d, double) + +#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* tau, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ return LAPACKE_##prefix##orglq(matrix_layout, m, n, m, a, lda, tau); \ - } \ - *work = 0; \ - return 0; \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_ORGLQ(s, float) - MXNET_LAPACK_CWRAP_ORGLQ(d, double) - - #define MXNET_LAPACK_CWRAP_GEQRF(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##geqrf(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - dtype *a, lapack_index_t lda, dtype *tau, \ - dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ +MXNET_LAPACK_CWRAP_ORGLQ(s, float) +MXNET_LAPACK_CWRAP_ORGLQ(d, double) + +#define MXNET_LAPACK_CWRAP_GEQRF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geqrf(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* tau, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ return LAPACKE_##prefix##geqrf(matrix_layout, m, n, a, lda, tau); \ - } \ - *work = 0; \ - return 0; \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GEQRF(s, float) - MXNET_LAPACK_CWRAP_GEQRF(d, double) - - #define MXNET_LAPACK_CWRAP_ORGQR(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - lapack_index_t k, dtype *a, lapack_index_t lda, \ - dtype *tau, dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ +MXNET_LAPACK_CWRAP_GEQRF(s, float) +MXNET_LAPACK_CWRAP_GEQRF(d, double) + +#define MXNET_LAPACK_CWRAP_ORGQR(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + lapack_index_t k, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* tau, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ return LAPACKE_##prefix##orgqr(matrix_layout, m, n, k, a, lda, tau); \ - } \ - *work = 0; \ - return 0; \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_ORGQR(s, float) - MXNET_LAPACK_CWRAP_ORGQR(d, double) - - // This has to be called internally in COL_MAJOR format even when matrix_layout - // is row-major as otherwise the eigenvectors would be returned as cols in a - // row-major matrix layout (see MKL documentation). - // We also have to allocate at least one DType element as workspace as the - // calling code assumes that the workspace has at least that size. 
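[Editorial aside, not part of the original patch.] The uplo flip that loup() above performs, and that the syevd wrapper below relies on, exists because reinterpreting a row-major buffer as column-major transposes the matrix, so the triangle that actually holds the data swaps sides. A tiny illustration of the helper's behavior:

  char o1 = loup('U', true);   // row-major caller, col-major LAPACKE call: 'U' becomes 'L'
  char o2 = loup('L', false);  // layouts already agree: 'L' stays 'L'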
- #define MXNET_LAPACK_CWRAP_SYEVD(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, char uplo, lapack_index_t n, \ - dtype *a, lapack_index_t lda, dtype *w, \ - dtype *work, lapack_index_t lwork, \ - lapack_index_t *iwork, lapack_index_t liwork) { \ - if (lwork != -1) { \ - char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ +MXNET_LAPACK_CWRAP_ORGQR(s, float) +MXNET_LAPACK_CWRAP_ORGQR(d, double) + +// This has to be called internally in COL_MAJOR format even when matrix_layout +// is row-major as otherwise the eigenvectors would be returned as cols in a +// row-major matrix layout (see MKL documentation). +// We also have to allocate at least one DType element as workspace as the +// calling code assumes that the workspace has at least that size. +#define MXNET_LAPACK_CWRAP_SYEVD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##syevd(int matrix_layout, \ + char uplo, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* w, \ + dtype* work, \ + lapack_index_t lwork, \ + lapack_index_t* iwork, \ + lapack_index_t liwork) { \ + if (lwork != -1) { \ + char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ return LAPACKE_##prefix##syevd(LAPACK_COL_MAJOR, 'V', o, n, a, lda, w); \ - } \ - *work = 1; \ - *iwork = 0; \ - return 0; \ + } \ + *work = 1; \ + *iwork = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_SYEVD(s, float) - MXNET_LAPACK_CWRAP_SYEVD(d, double) - - #define MXNET_LAPACK_sgetrf LAPACKE_sgetrf - #define MXNET_LAPACK_dgetrf LAPACKE_dgetrf - - // Internally A is factorized as U * L * VT, and (according to the tech report) - // we want to factorize it as UT * L * V, so we pass ut as u and v as vt. - // We also have to allocate at least m - 1 DType elements as workspace as the internal - // LAPACKE function needs it to store `superb`. (see MKL documentation) - #define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - dtype* ut, lapack_index_t ldut, dtype* s, dtype* v, \ - lapack_index_t ldv, dtype* work, lapack_index_t lwork) { \ - if (lwork != -1) { \ - return LAPACKE_##prefix##gesvd(matrix_layout, 'S', 'O', m, n, v, ldv, s, ut, ldut, \ - v, ldv, work); \ - } \ - *work = m - 1; \ - return 0; \ +MXNET_LAPACK_CWRAP_SYEVD(s, float) +MXNET_LAPACK_CWRAP_SYEVD(d, double) + +#define MXNET_LAPACK_sgetrf LAPACKE_sgetrf +#define MXNET_LAPACK_dgetrf LAPACKE_dgetrf + +// Internally A is factorized as U * L * VT, and (according to the tech report) +// we want to factorize it as UT * L * V, so we pass ut as u and v as vt. +// We also have to allocate at least m - 1 DType elements as workspace as the internal +// LAPACKE function needs it to store `superb`. (see MKL documentation) +#define MXNET_LAPACK_CWRAP_GESVD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gesvd(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + dtype* ut, \ + lapack_index_t ldut, \ + dtype* s, \ + dtype* v, \ + lapack_index_t ldv, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##gesvd( \ + matrix_layout, 'S', 'O', m, n, v, ldv, s, ut, ldut, v, ldv, work); \ + } \ + *work = m - 1; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GESVD(s, float) - MXNET_LAPACK_CWRAP_GESVD(d, double) - - // Computes the singular value decomposition of a general rectangular matrix - // using a divide and conquer method. 
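[Editorial aside, not part of the original patch.] For reference, the factorization the gesdd wrappers below compute is the standard singular value decomposition, with the singular values returned in descending order:

  A = U \, \operatorname{diag}(s_1, \dots, s_{\min(m,n)}) \, V^{\mathsf T}, \qquad s_1 \ge s_2 \ge \dots \ge s_{\min(m,n)} \ge 0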
- #define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - dtype *a, lapack_index_t lda, dtype *s, \ - dtype *u, lapack_index_t ldu, \ - dtype *vt, lapack_index_t ldvt, \ - dtype *work, lapack_index_t lwork, \ - lapack_index_t *iwork) { \ - if (lwork != -1) { \ - return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, \ - s, u, ldu, vt, ldvt); \ - } \ - *work = 0; \ - return 0; \ +MXNET_LAPACK_CWRAP_GESVD(s, float) +MXNET_LAPACK_CWRAP_GESVD(d, double) + +// Computes the singular value decomposition of a general rectangular matrix +// using a divide and conquer method. +#define MXNET_LAPACK_CWRAP_GESDD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gesdd(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* s, \ + dtype* u, \ + lapack_index_t ldu, \ + dtype* vt, \ + lapack_index_t ldvt, \ + dtype* work, \ + lapack_index_t lwork, \ + lapack_index_t* iwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##gesdd(matrix_layout, 'O', m, n, a, lda, s, u, ldu, vt, ldvt); \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GESDD(s, float) - MXNET_LAPACK_CWRAP_GESDD(d, double) - - #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, lapack_index_t n, dtype *a, \ - lapack_index_t lda, lapack_index_t *ipiv, \ - dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ +MXNET_LAPACK_CWRAP_GESDD(s, float) +MXNET_LAPACK_CWRAP_GESDD(d, double) + +#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + lapack_index_t* ipiv, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ return LAPACKE_##prefix##getri(matrix_layout, n, a, lda, ipiv); \ - } \ - *work = 0; \ - return 0; \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GETRI(s, float) - MXNET_LAPACK_CWRAP_GETRI(d, double) - - #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \ - lapack_index_t n, dtype *a, lapack_index_t lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, lapack_index_t ldvl, dtype *vr, \ - lapack_index_t ldvr, \ - dtype *work, lapack_index_t lwork) { \ - if (lwork != -1) { \ - return LAPACKE_##prefix##geev(matrix_layout, jobvl, jobvr, \ - n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \ - } \ - *work = 0; \ - return 0; \ +MXNET_LAPACK_CWRAP_GETRI(s, float) +MXNET_LAPACK_CWRAP_GETRI(d, double) + +#define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, \ + char jobvl, \ + char jobvr, \ + lapack_index_t n, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + lapack_index_t ldvl, \ + dtype* vr, \ + lapack_index_t ldvr, \ + dtype* work, \ + lapack_index_t lwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##geev( \ + matrix_layout, jobvl, jobvr, n, a, lda, wr, wi, vl, ldvl, vr, ldvr); \ + } \ + *work = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GEEV(s, float) - MXNET_LAPACK_CWRAP_GEEV(d, double) - - #define MXNET_LAPACK_CWRAP_GELSD(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gelsd(int matrix_layout, lapack_index_t m, lapack_index_t n, \ - lapack_index_t nrhs, dtype *a, lapack_index_t lda, \ - dtype *b, lapack_index_t ldb, dtype *s, dtype rcond, \ - lapack_index_t *rank, dtype *work, 
lapack_index_t lwork, \ - lapack_index_t *iwork) { \ - if (lwork != -1) { \ - return LAPACKE_##prefix##gelsd(matrix_layout, m, n, nrhs, a, lda, b, ldb, \ - s, rcond, rank); \ - } \ - *work = 0; \ - *iwork = 0; \ - return 0; \ +MXNET_LAPACK_CWRAP_GEEV(s, float) +MXNET_LAPACK_CWRAP_GEEV(d, double) + +#define MXNET_LAPACK_CWRAP_GELSD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gelsd(int matrix_layout, \ + lapack_index_t m, \ + lapack_index_t n, \ + lapack_index_t nrhs, \ + dtype* a, \ + lapack_index_t lda, \ + dtype* b, \ + lapack_index_t ldb, \ + dtype* s, \ + dtype rcond, \ + lapack_index_t* rank, \ + dtype* work, \ + lapack_index_t lwork, \ + lapack_index_t* iwork) { \ + if (lwork != -1) { \ + return LAPACKE_##prefix##gelsd(matrix_layout, m, n, nrhs, a, lda, b, ldb, s, rcond, rank); \ + } \ + *work = 0; \ + *iwork = 0; \ + return 0; \ } - MXNET_LAPACK_CWRAP_GELSD(s, float) - MXNET_LAPACK_CWRAP_GELSD(d, double) +MXNET_LAPACK_CWRAP_GELSD(s, float) +MXNET_LAPACK_CWRAP_GELSD(d, double) #elif MXNET_USE_LAPACK - #define MXNET_LAPACK_ROW_MAJOR 101 - #define MXNET_LAPACK_COL_MAJOR 102 - - // These functions can be called with either row- or col-major format. - #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, int lda) { \ - CHECK_LAPACK_UPLO(uplo); \ - char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ - int ret(0); \ - func##_(&o, &n, a, &lda, &ret); \ - return ret; \ +#define MXNET_LAPACK_ROW_MAJOR 101 +#define MXNET_LAPACK_COL_MAJOR 102 + +// These functions can be called with either row- or col-major format. +#define MXNET_LAPACK_CWRAPPER1(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \ + CHECK_LAPACK_UPLO(uplo); \ + char o(loup(uplo, (matrix_layout == MXNET_LAPACK_ROW_MAJOR))); \ + int ret(0); \ + func##_(&o, &n, a, &lda, &ret); \ + return ret; \ } - MXNET_LAPACK_CWRAPPER1(spotrf, float) - MXNET_LAPACK_CWRAPPER1(dpotrf, double) - MXNET_LAPACK_CWRAPPER1(spotri, float) - MXNET_LAPACK_CWRAPPER1(dpotri, double) - - inline int mxnet_lapack_sposv(int matrix_layout, char uplo, int n, int nrhs, - float *a, int lda, float *b, int ldb) { - int info; - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { - // Transpose b to b_t of shape (nrhs, n) - float *b_t = new float[nrhs * n]; - flip(n, nrhs, b_t, n, b, ldb); - sposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info); - flip(nrhs, n, b, ldb, b_t, n); - delete [] b_t; - return info; - } - sposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info); +MXNET_LAPACK_CWRAPPER1(spotrf, float) +MXNET_LAPACK_CWRAPPER1(dpotrf, double) +MXNET_LAPACK_CWRAPPER1(spotri, float) +MXNET_LAPACK_CWRAPPER1(dpotri, double) + +inline int mxnet_lapack_sposv(int matrix_layout, + char uplo, + int n, + int nrhs, + float* a, + int lda, + float* b, + int ldb) { + int info; + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { + // Transpose b to b_t of shape (nrhs, n) + float* b_t = new float[nrhs * n]; + flip(n, nrhs, b_t, n, b, ldb); + sposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info); + flip(nrhs, n, b, ldb, b_t, n); + delete[] b_t; return info; } + sposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info); + return info; +} - inline int mxnet_lapack_dposv(int matrix_layout, char uplo, int n, int nrhs, - double *a, int lda, double *b, int ldb) { - int info; - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { - // Transpose b to b_t of shape (nrhs, n) - double *b_t = new double[nrhs * n]; - flip(n, nrhs, b_t, n, b, ldb); - dposv_(&uplo, &n, &nrhs, a, &lda, 
b_t, &n, &info); - flip(nrhs, n, b, ldb, b_t, n); - delete [] b_t; - return info; - } - dposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info); +inline int mxnet_lapack_dposv(int matrix_layout, + char uplo, + int n, + int nrhs, + double* a, + int lda, + double* b, + int ldb) { + int info; + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { + // Transpose b to b_t of shape (nrhs, n) + double* b_t = new double[nrhs * n]; + flip(n, nrhs, b_t, n, b, ldb); + dposv_(&uplo, &n, &nrhs, a, &lda, b_t, &n, &info); + flip(nrhs, n, b, ldb, b_t, n); + delete[] b_t; return info; } + dposv_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info); + return info; +} - // Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with - // row-major format (MXNet). Internally, the QR variants are done in column-major. - // In particular, the matrix dimensions m and n are flipped. - #define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gelqf(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype* tau, \ - dtype* work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ - int info(0); \ - prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \ - return info; \ - } else { \ +// Note: Both MXNET_LAPACK_*gelqf, MXNET_LAPACK_*orglq can only be called with +// row-major format (MXNet). Internally, the QR variants are done in column-major. +// In particular, the matrix dimensions m and n are flipped. +#define MXNET_LAPACK_CWRAP_GELQF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gelqf( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + prefix##geqrf_(&n, &m, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } else { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "gelqf implemented for row-major layout only"; \ - return 1; \ - } \ + return 1; \ + } \ } - MXNET_LAPACK_CWRAP_GELQF(s, float) - MXNET_LAPACK_CWRAP_GELQF(d, double) - - // Note: The k argument (rank) is equal to m as well - #define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##orglq(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype* tau, \ - dtype* work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ - int info(0); \ - prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info); \ - return info; \ - } else { \ +MXNET_LAPACK_CWRAP_GELQF(s, float) +MXNET_LAPACK_CWRAP_GELQF(d, double) + +// Note: The k argument (rank) is equal to m as well +#define MXNET_LAPACK_CWRAP_ORGLQ(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##orglq( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + prefix##orgqr_(&n, &m, &m, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } else { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "orglq implemented for row-major layout only"; \ - return 1; \ - } \ + return 1; \ + } \ } - MXNET_LAPACK_CWRAP_ORGLQ(s, float) - MXNET_LAPACK_CWRAP_ORGLQ(d, double) - - #define MXNET_LAPACK_CWRAP_GEQRF(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##geqrf(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype* tau, \ - dtype* work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_ORGLQ(s, float) +MXNET_LAPACK_CWRAP_ORGLQ(d, double) + +#define MXNET_LAPACK_CWRAP_GEQRF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geqrf( \ + int 
matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "geqrf implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##geqrf_(&m, &n, a, &lda, tau, work, &lwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##geqrf_(&m, &n, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GEQRF(s, float) - MXNET_LAPACK_CWRAP_GEQRF(d, double) - - #define MXNET_LAPACK_CWRAP_ORGQR(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, int m, int n, int k, \ - dtype *a, int lda, dtype* tau, \ - dtype* work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_GEQRF(s, float) +MXNET_LAPACK_CWRAP_GEQRF(d, double) + +#define MXNET_LAPACK_CWRAP_ORGQR(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##orgqr(int matrix_layout, \ + int m, \ + int n, \ + int k, \ + dtype* a, \ + int lda, \ + dtype* tau, \ + dtype* work, \ + int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "orgqr implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##orgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##orgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_ORGQR(s, float) - MXNET_LAPACK_CWRAP_ORGQR(d, double) - - // Note: Supports row-major format only. Internally, column-major is used, so all - // inputs/outputs are flipped (in particular, uplo is flipped). - #define MXNET_LAPACK_CWRAP_SYEVD(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \ - int lda, dtype *w, dtype *work, int lwork, \ - int *iwork, int liwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ - int info(0); \ - char jobz('V'); \ - char uplo_(loup(uplo, true)); \ - func##_(&jobz, &uplo_, &n, a, &lda, w, work, &lwork, iwork, &liwork, &info); \ - return info; \ - } else { \ +MXNET_LAPACK_CWRAP_ORGQR(s, float) +MXNET_LAPACK_CWRAP_ORGQR(d, double) + +// Note: Supports row-major format only. Internally, column-major is used, so all +// inputs/outputs are flipped (in particular, uplo is flipped). +#define MXNET_LAPACK_CWRAP_SYEVD(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, \ + char uplo, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* w, \ + dtype* work, \ + int lwork, \ + int* iwork, \ + int liwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + char jobz('V'); \ + char uplo_(loup(uplo, true)); \ + func##_(&jobz, &uplo_, &n, a, &lda, w, work, &lwork, iwork, &liwork, &info); \ + return info; \ + } else { \ CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \ - return 1; \ - } \ + return 1; \ + } \ } - MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float) - MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double) - - // Note: Supports row-major format only. Internally, column-major is used, so all - // inputs/outputs are flipped and transposed. m and n are flipped as well. 
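// An illustrative note (not part of the reformatted sources) on the row-major
// trick used by the wrappers in this branch: a row-major m x n buffer read as
// column-major is the n x m transpose A^T, so an LQ factorization of A can be
// obtained by handing A^T to Fortran geqrf with m and n swapped
// (A^T = Q * R implies A = R^T * Q^T = L * Q'). A sketch with hypothetical
// names and data:
#include <algorithm>
#include <vector>

inline void gelqf_usage_sketch() {
  const int m = 2, n = 3, lda = n;              // row-major 2 x 3 input
  std::vector<float> a = {1.f, 2.f, 3.f,
                          4.f, 5.f, 6.f};
  std::vector<float> tau(std::min(m, n));
  float wk_size = 0.0f;
  // The workspace query (lwork == -1) is forwarded to the Fortran routine.
  MXNET_LAPACK_sgelqf(MXNET_LAPACK_ROW_MAJOR, m, n, a.data(), lda, tau.data(), &wk_size, -1);
  std::vector<float> work(std::max(1, static_cast<int>(wk_size)));
  MXNET_LAPACK_sgelqf(MXNET_LAPACK_ROW_MAJOR, m, n, a.data(), lda, tau.data(),
                      work.data(), static_cast<int>(work.size()));
  // `a` now holds L in its lower triangle plus the Householder reflectors for Q.
}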
- #define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \ - int ldut, dtype* s, dtype* v, int ldv, \ - dtype* work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ - int info(0); \ - char jobu('O'); \ - char jobvt('S'); \ - func##_(&jobu, &jobvt, &n, &m, v, &ldv, s, v, &ldv, ut, &ldut, work, &lwork, &info); \ - return info; \ - } else { \ +MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float) +MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double) + +// Note: Supports row-major format only. Internally, column-major is used, so all +// inputs/outputs are flipped and transposed. m and n are flipped as well. +#define MXNET_LAPACK_CWRAP_GESVD(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* ut, \ + int ldut, \ + dtype* s, \ + dtype* v, \ + int ldv, \ + dtype* work, \ + int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ + int info(0); \ + char jobu('O'); \ + char jobvt('S'); \ + func##_(&jobu, &jobvt, &n, &m, v, &ldv, s, v, &ldv, ut, &ldut, work, &lwork, &info); \ + return info; \ + } else { \ CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \ - return 1; \ - } \ + return 1; \ + } \ } - MXNET_LAPACK_CWRAP_GESVD(sgesvd, float) - MXNET_LAPACK_CWRAP_GESVD(dgesvd, double) - - #define MXNET_LAPACK_CWRAP_GESDD(func, dtype) \ - inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype *s, \ - dtype *u, int ldu, \ - dtype *vt, int ldvt, \ - dtype *work, int lwork, int *iwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_GESVD(sgesvd, float) +MXNET_LAPACK_CWRAP_GESVD(dgesvd, double) + +#define MXNET_LAPACK_CWRAP_GESDD(func, dtype) \ + inline int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* s, \ + dtype* u, \ + int ldu, \ + dtype* vt, \ + int ldvt, \ + dtype* work, \ + int lwork, \ + int* iwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - char jobz('O'); \ - func##_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + char jobz('O'); \ + func##_(&jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GESDD(sgesdd, float) - MXNET_LAPACK_CWRAP_GESDD(dgesdd, double) - - #define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, char jobvl, char jobvr, \ - int n, dtype *a, int lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, int ldvl, dtype *vr, int ldvr, \ - dtype *work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_GESDD(sgesdd, float) +MXNET_LAPACK_CWRAP_GESDD(dgesdd, double) + +#define MXNET_LAPACK_CWRAP_GEEV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##geev(int matrix_layout, \ + char jobvl, \ + char jobvr, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + int ldvl, \ + dtype* vr, \ + int ldvr, \ + dtype* work, \ + int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "geev implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##geev_(&jobvl, &jobvr, \ - &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, 
work, &lwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##geev_( \ + &jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GEEV(s, float) - MXNET_LAPACK_CWRAP_GEEV(d, double) +MXNET_LAPACK_CWRAP_GEEV(s, float) +MXNET_LAPACK_CWRAP_GEEV(d, double) - #define MXNET_LAPACK +#define MXNET_LAPACK - // Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format - // (MXNet) for performance. - #define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##getrf(int matrix_layout, int m, int n, \ - dtype *a, int lda, int *ipiv) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +// Note: Both MXNET_LAPACK_*getrf, MXNET_LAPACK_*getri can only be called with col-major format +// (MXNet) for performance. +#define MXNET_LAPACK_CWRAP_GETRF(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##getrf( \ + int matrix_layout, int m, int n, dtype* a, int lda, int* ipiv) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##getrf_(&m, &n, a, &lda, ipiv, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GETRF(s, float) - MXNET_LAPACK_CWRAP_GETRF(d, double) +MXNET_LAPACK_CWRAP_GETRF(s, float) +MXNET_LAPACK_CWRAP_GETRF(d, double) - #define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##getri(int matrix_layout, int n, dtype *a, int lda, \ - int *ipiv, dtype *work, int lwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +#define MXNET_LAPACK_CWRAP_GETRI(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##getri( \ + int matrix_layout, int n, dtype* a, int lda, int* ipiv, dtype* work, int lwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "getri implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##getri_(&n, a, &lda, ipiv, work, &lwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GETRI(s, float) - MXNET_LAPACK_CWRAP_GETRI(d, double) - - #define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gesv(int matrix_layout, \ - int n, int nrhs, dtype *a, int lda, \ - int *ipiv, dtype *b, int ldb) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_GETRI(s, float) +MXNET_LAPACK_CWRAP_GETRI(d, double) + +#define MXNET_LAPACK_CWRAP_GESV(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gesv( \ + int matrix_layout, int n, int nrhs, dtype* a, int lda, int* ipiv, dtype* b, int ldb) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##gesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GESV(s, float) - MXNET_LAPACK_CWRAP_GESV(d, double) - - #define MXNET_LAPACK_CWRAP_GELSD(prefix, dtype) \ - inline int MXNET_LAPACK_##prefix##gelsd(int 
matrix_layout, int m, int n, int nrhs, \ - dtype *a, int lda, dtype *b, int ldb, \ - dtype *s, dtype rcond, int *rank, \ - dtype *work, int lwork, int *iwork) { \ - if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ +MXNET_LAPACK_CWRAP_GESV(s, float) +MXNET_LAPACK_CWRAP_GESV(d, double) + +#define MXNET_LAPACK_CWRAP_GELSD(prefix, dtype) \ + inline int MXNET_LAPACK_##prefix##gelsd(int matrix_layout, \ + int m, \ + int n, \ + int nrhs, \ + dtype* a, \ + int lda, \ + dtype* b, \ + int ldb, \ + dtype* s, \ + dtype rcond, \ + int* rank, \ + dtype* work, \ + int lwork, \ + int* iwork) { \ + if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \ CHECK(false) << "MXNET_LAPACK_" << #prefix << "gesv implemented for col-major layout only"; \ - return 1; \ - } else { \ - int info(0); \ - prefix##gelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, \ - work, &lwork, iwork, &info); \ - return info; \ - } \ + return 1; \ + } else { \ + int info(0); \ + prefix##gelsd_( \ + &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, iwork, &info); \ + return info; \ + } \ } - MXNET_LAPACK_CWRAP_GELSD(s, float) - MXNET_LAPACK_CWRAP_GELSD(d, double) +MXNET_LAPACK_CWRAP_GELSD(s, float) +MXNET_LAPACK_CWRAP_GELSD(d, double) #else - #define MXNET_LAPACK_ROW_MAJOR 101 - #define MXNET_LAPACK_COL_MAJOR 102 +#define MXNET_LAPACK_ROW_MAJOR 101 +#define MXNET_LAPACK_COL_MAJOR 102 - // Define compilable stubs. - #define MXNET_LAPACK_CWRAPPER1(func, dtype) \ +// Define compilable stubs. +#define MXNET_LAPACK_CWRAPPER1(func, dtype) \ int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda); - #define MXNET_LAPACK_CWRAPPER2(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork); - - #define MXNET_LAPACK_CWRAPPER3(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \ - int lda, dtype *w, dtype *work, int lwork, \ - int *iwork, int liwork); - - #define MXNET_LAPACK_CWRAPPER4(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ - dtype *a, int lda, int *ipiv); - - #define MXNET_LAPACK_CWRAPPER5(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int n, dtype *a, int lda, \ - int *ipiv, dtype *work, int lwork); - - #define MXNET_LAPACK_CWRAPPER6(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* ut, \ - int ldut, dtype* s, dtype* v, int ldv, \ - dtype* work, int lwork); - - #define MXNET_LAPACK_CWRAPPER7(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_order, int n, int nrhs, dtype *a, \ - int lda, int *ipiv, dtype *b, int ldb); \ - - #define MXNET_LAPACK_CWRAPPER8(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, char jobvl, char jobvr, \ - int n, dtype *a, int lda, \ - dtype *wr, dtype *wi, \ - dtype *vl, int ldvl, dtype *vr, int ldvr, \ - dtype *work, int lwork); \ - - #define MXNET_LAPACK_CWRAPPER9(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, \ - dtype *a, int lda, dtype *s, \ - dtype *u, int ldu, \ - dtype *vt, int ldvt, \ - dtype *work, int lwork, int *iwork); - - #define MXNET_LAPACK_CWRAPPER10(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork); - - #define MXNET_LAPACK_CWRAPPER11(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, int nrhs, \ - dtype *a, int lda, dtype *b, int ldb, \ - dtype *s, dtype rcond, int *rank, \ - dtype *work, int lwork, int *iwork); - - #define 
MXNET_LAPACK_CWRAPPER12(func, dtype) \ - int MXNET_LAPACK_##func(int matrix_layout, int m, int n, int k, dtype* a, \ - int lda, dtype* tau, dtype* work, int lwork); - - #define MXNET_LAPACK_UNAVAILABLE(func) \ - int mxnet_lapack_##func(...); - MXNET_LAPACK_CWRAPPER1(spotrf, float) - MXNET_LAPACK_CWRAPPER1(dpotrf, double) - MXNET_LAPACK_CWRAPPER1(spotri, float) - MXNET_LAPACK_CWRAPPER1(dpotri, double) - - MXNET_LAPACK_UNAVAILABLE(sposv) - MXNET_LAPACK_UNAVAILABLE(dposv) - - MXNET_LAPACK_CWRAPPER2(sgelqf, float) - MXNET_LAPACK_CWRAPPER2(dgelqf, double) - MXNET_LAPACK_CWRAPPER2(sorglq, float) - MXNET_LAPACK_CWRAPPER2(dorglq, double) - - MXNET_LAPACK_CWRAPPER3(ssyevd, float) - MXNET_LAPACK_CWRAPPER3(dsyevd, double) - - MXNET_LAPACK_CWRAPPER4(sgetrf, float) - MXNET_LAPACK_CWRAPPER4(dgetrf, double) - - MXNET_LAPACK_CWRAPPER5(sgetri, float) - MXNET_LAPACK_CWRAPPER5(dgetri, double) - - MXNET_LAPACK_CWRAPPER6(sgesvd, float) - MXNET_LAPACK_CWRAPPER6(dgesvd, double) - - MXNET_LAPACK_CWRAPPER7(sgesv, float) - MXNET_LAPACK_CWRAPPER7(dgesv, double) - - MXNET_LAPACK_CWRAPPER8(sgeev, float) - MXNET_LAPACK_CWRAPPER8(dgeev, double) - - MXNET_LAPACK_CWRAPPER9(sgesdd, float) - MXNET_LAPACK_CWRAPPER9(dgesdd, double) - - MXNET_LAPACK_CWRAPPER10(sgeqrf, float) - MXNET_LAPACK_CWRAPPER10(dgeqrf, double) - - MXNET_LAPACK_CWRAPPER11(sgelsd, float) - MXNET_LAPACK_CWRAPPER11(dgelsd, double) - - MXNET_LAPACK_CWRAPPER12(sorgqr, float) - MXNET_LAPACK_CWRAPPER12(dorgqr, double) - - #undef MXNET_LAPACK_CWRAPPER1 - #undef MXNET_LAPACK_CWRAPPER2 - #undef MXNET_LAPACK_CWRAPPER3 - #undef MXNET_LAPACK_CWRAPPER4 - #undef MXNET_LAPACK_UNAVAILABLE +#define MXNET_LAPACK_CWRAPPER2(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork); + +#define MXNET_LAPACK_CWRAPPER3(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + char uplo, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* w, \ + dtype* work, \ + int lwork, \ + int* iwork, \ + int liwork); + +#define MXNET_LAPACK_CWRAPPER4(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, int lda, int* ipiv); + +#define MXNET_LAPACK_CWRAPPER5(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int n, dtype* a, int lda, int* ipiv, dtype* work, int lwork); + +#define MXNET_LAPACK_CWRAPPER6(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* ut, \ + int ldut, \ + dtype* s, \ + dtype* v, \ + int ldv, \ + dtype* work, \ + int lwork); + +#define MXNET_LAPACK_CWRAPPER7(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_order, int n, int nrhs, dtype* a, int lda, int* ipiv, dtype* b, int ldb); + +#define MXNET_LAPACK_CWRAPPER8(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + char jobvl, \ + char jobvr, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* wr, \ + dtype* wi, \ + dtype* vl, \ + int ldvl, \ + dtype* vr, \ + int ldvr, \ + dtype* work, \ + int lwork); + +#define MXNET_LAPACK_CWRAPPER9(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + dtype* a, \ + int lda, \ + dtype* s, \ + dtype* u, \ + int ldu, \ + dtype* vt, \ + int ldvt, \ + dtype* work, \ + int lwork, \ + int* iwork); + +#define MXNET_LAPACK_CWRAPPER10(func, dtype) \ + int MXNET_LAPACK_##func( \ + int matrix_layout, int m, int n, dtype* a, int lda, dtype* tau, dtype* work, int lwork); + +#define MXNET_LAPACK_CWRAPPER11(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + int 
nrhs, \ + dtype* a, \ + int lda, \ + dtype* b, \ + int ldb, \ + dtype* s, \ + dtype rcond, \ + int* rank, \ + dtype* work, \ + int lwork, \ + int* iwork); + +#define MXNET_LAPACK_CWRAPPER12(func, dtype) \ + int MXNET_LAPACK_##func(int matrix_layout, \ + int m, \ + int n, \ + int k, \ + dtype* a, \ + int lda, \ + dtype* tau, \ + dtype* work, \ + int lwork); + +#define MXNET_LAPACK_UNAVAILABLE(func) int mxnet_lapack_##func(...); +MXNET_LAPACK_CWRAPPER1(spotrf, float) +MXNET_LAPACK_CWRAPPER1(dpotrf, double) +MXNET_LAPACK_CWRAPPER1(spotri, float) +MXNET_LAPACK_CWRAPPER1(dpotri, double) + +MXNET_LAPACK_UNAVAILABLE(sposv) +MXNET_LAPACK_UNAVAILABLE(dposv) + +MXNET_LAPACK_CWRAPPER2(sgelqf, float) +MXNET_LAPACK_CWRAPPER2(dgelqf, double) +MXNET_LAPACK_CWRAPPER2(sorglq, float) +MXNET_LAPACK_CWRAPPER2(dorglq, double) + +MXNET_LAPACK_CWRAPPER3(ssyevd, float) +MXNET_LAPACK_CWRAPPER3(dsyevd, double) + +MXNET_LAPACK_CWRAPPER4(sgetrf, float) +MXNET_LAPACK_CWRAPPER4(dgetrf, double) + +MXNET_LAPACK_CWRAPPER5(sgetri, float) +MXNET_LAPACK_CWRAPPER5(dgetri, double) + +MXNET_LAPACK_CWRAPPER6(sgesvd, float) +MXNET_LAPACK_CWRAPPER6(dgesvd, double) + +MXNET_LAPACK_CWRAPPER7(sgesv, float) +MXNET_LAPACK_CWRAPPER7(dgesv, double) + +MXNET_LAPACK_CWRAPPER8(sgeev, float) +MXNET_LAPACK_CWRAPPER8(dgeev, double) + +MXNET_LAPACK_CWRAPPER9(sgesdd, float) +MXNET_LAPACK_CWRAPPER9(dgesdd, double) + +MXNET_LAPACK_CWRAPPER10(sgeqrf, float) +MXNET_LAPACK_CWRAPPER10(dgeqrf, double) + +MXNET_LAPACK_CWRAPPER11(sgelsd, float) +MXNET_LAPACK_CWRAPPER11(dgelsd, double) + +MXNET_LAPACK_CWRAPPER12(sorgqr, float) +MXNET_LAPACK_CWRAPPER12(dorgqr, double) + +#undef MXNET_LAPACK_CWRAPPER1 +#undef MXNET_LAPACK_CWRAPPER2 +#undef MXNET_LAPACK_CWRAPPER3 +#undef MXNET_LAPACK_CWRAPPER4 +#undef MXNET_LAPACK_UNAVAILABLE #endif template -inline int MXNET_LAPACK_posv(int matrix_layout, char uplo, int n, int nrhs, - DType *a, int lda, DType *b, int ldb); +inline int MXNET_LAPACK_posv(int matrix_layout, + char uplo, + int n, + int nrhs, + DType* a, + int lda, + DType* b, + int ldb); template <> -inline int MXNET_LAPACK_posv(int matrix_layout, char uplo, int n, - int nrhs, float *a, int lda, float *b, int ldb) { +inline int MXNET_LAPACK_posv(int matrix_layout, + char uplo, + int n, + int nrhs, + float* a, + int lda, + float* b, + int ldb) { return mxnet_lapack_sposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb); } template <> -inline int MXNET_LAPACK_posv(int matrix_layout, char uplo, int n, - int nrhs, double *a, int lda, double *b, int ldb) { +inline int MXNET_LAPACK_posv(int matrix_layout, + char uplo, + int n, + int nrhs, + double* a, + int lda, + double* b, + int ldb) { return mxnet_lapack_dposv(matrix_layout, uplo, n, nrhs, a, lda, b, ldb); } diff --git a/src/operator/channel_op_common.h b/src/operator/channel_op_common.h index 43f689d2defa..64ff16e1e749 100644 --- a/src/operator/channel_op_common.h +++ b/src/operator/channel_op_common.h @@ -22,7 +22,7 @@ * \file channel_op_common.h * \brief common function used for concat and split channel * \author Bing Xu -*/ + */ #ifndef MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ #define MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ #include @@ -33,20 +33,22 @@ namespace mxnet { namespace op { -template -inline void concatenate_helper(const std::vector > &input, - mshadow::Tensor *output, const int dimension, +template +inline void concatenate_helper(const std::vector >& input, + mshadow::Tensor* output, + const int dimension, const OpReqType req) { using mshadow::expr::concat; using mshadow::expr::slice; if (dimension == cdim) { 
mshadow::Tensor out = *output; - size_t size = input.size(); - index_t begin = 0; + size_t size = input.size(); + index_t begin = 0; for (size_t i = 0; i < size; ++i) { // If input[i] is a zero-size tensor, do nothing. - if (input[i].shape_.Size() == 0) continue; + if (input[i].shape_.Size() == 0) + continue; index_t end = begin + input[i].size(cdim); Assign(slice(out, begin, end), req, input[i]); begin = end; @@ -56,34 +58,36 @@ inline void concatenate_helper(const std::vector -inline void Concatenate(const std::vector > &input, - mshadow::Tensor *output, const int dimension, +template +inline void Concatenate(const std::vector >& input, + mshadow::Tensor* output, + const int dimension, const OpReqType req) { if (dimension < 0) { LOG(FATAL) << "dimension (" << dimension << ") must be greater than 0"; } else if (dimension >= dim) { LOG(FATAL) << "dimension (" << dimension << ") must be smaller than dim (" << dim << ")"; } else { - concatenate_helper(input, output, dimension, req); + concatenate_helper(input, output, dimension, req); } } - -template -void split_helper(const mshadow::Tensor &input, - std::vector > *output, - const int dimension, const std::vector &req) { +template +void split_helper(const mshadow::Tensor& input, + std::vector >* output, + const int dimension, + const std::vector& req) { using mshadow::expr::concat; using mshadow::expr::slice; if (dimension == cdim) { std::vector > out = *output; - size_t size = out.size(); - index_t begin = 0; + size_t size = out.size(); + index_t begin = 0; for (size_t i = 0; i < size; ++i) { // If out[i] is a zero-size tensor, do nothing. - if (out[i].shape_.Size() == 0) continue; + if (out[i].shape_.Size() == 0) + continue; index_t end = begin + out[i].size(cdim); Assign(out[i], req[i], slice(input, begin, end)); begin = end; @@ -93,16 +97,17 @@ void split_helper(const mshadow::Tensor &input, } } -template -void Split(const mshadow::Tensor &input, - std::vector > *output, - const int dimension, const std::vector &req) { +template +void Split(const mshadow::Tensor& input, + std::vector >* output, + const int dimension, + const std::vector& req) { if (dimension < 0) { LOG(FATAL) << "dimension (" << dimension << ") must be greater than 0"; } else if (dimension >= dim) { LOG(FATAL) << "dimension (" << dimension << ") must be smaller than dim (" << dim << ")"; } else { - split_helper(input, output, dimension, req); + split_helper(input, output, dimension, req); } } } // namespace op diff --git a/src/operator/contrib/adabelief-inl.h b/src/operator/contrib/adabelief-inl.h index 2f282158e4ce..3f403fd55ae8 100644 --- a/src/operator/contrib/adabelief-inl.h +++ b/src/operator/contrib/adabelief-inl.h @@ -43,29 +43,24 @@ struct AdaBeliefParam : public dmlc::Parameter { float eta; float clip_gradient; DMLC_DECLARE_PARAMETER(AdaBeliefParam) { - DMLC_DECLARE_FIELD(lr) - .describe("Learning rate"); - DMLC_DECLARE_FIELD(beta1) - .set_default(0.9f) - .describe("The decay rate for the 1st moment estimates."); - DMLC_DECLARE_FIELD(beta2) - .set_default(0.999f) - .describe("The decay rate for the 2nd moment estimates."); - DMLC_DECLARE_FIELD(epsilon) - .set_default(1e-8f) - .describe("A small constant for numerical stability."); - DMLC_DECLARE_FIELD(wd) - .set_default(0.0f) - .describe("Weight decay augments the objective function with a " - "regularization term that penalizes large weights. 
" - "The penalty scales with the square of the magnitude of each weight."); - DMLC_DECLARE_FIELD(eta) - .describe("Learning rate schedule multiplier"); + DMLC_DECLARE_FIELD(lr).describe("Learning rate"); + DMLC_DECLARE_FIELD(beta1).set_default(0.9f).describe( + "The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2).set_default(0.999f).describe( + "The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f).describe( + "A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wd).set_default(0.0f).describe( + "Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(eta).describe("Learning rate schedule multiplier"); DMLC_DECLARE_FIELD(clip_gradient) - .set_default(-1.0f) - .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " - "If clip_gradient <= 0, gradient clipping is turned off. " - "grad = max(min(grad, clip_gradient), -clip_gradient)."); + .set_default(-1.0f) + .describe( + "Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); } }; @@ -73,10 +68,10 @@ struct AdaBeliefParam : public dmlc::Parameter { // n_in = 2: weight, grad (fp16) // n_out = 1: weight (fp16) // total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32) -template +template inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape()); @@ -85,10 +80,10 @@ inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, mxnet::TShape()); } -template +template inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; for (int i = n_in; i < total_in; ++i) { @@ -98,56 +93,77 @@ inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } -template +template struct MPAdaBeliefKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data, - float* var_data, const DType* weight_data, const DType* grad_data, float* weight32, - const float param_clip_gradient, const float param_beta1, const float param_beta2, - const float param_eta, const float param_lr, const float param_wd, - const float param_rescale_grad, const float param_epsilon) { - float w = weight32[i]; - float scaled_grad = param_rescale_grad*static_cast(grad_data[i]); + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + float* mean_data, + float* var_data, + const DType* weight_data, + const DType* grad_data, + float* weight32, + const float param_clip_gradient, + const float param_beta1, + const float param_beta2, + const float param_eta, + const float param_lr, + const float param_wd, + const float param_rescale_grad, + const float param_epsilon) { + float w = 
weight32[i]; + float scaled_grad = param_rescale_grad * static_cast(grad_data[i]); scaled_grad += param_wd * w; if (param_clip_gradient >= 0.f) scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient); const float mean = param_beta1 * (mean_data[i] - scaled_grad) + scaled_grad; - const float adj = mshadow_op::square::Map(scaled_grad - mean); - const float var = param_beta2*(var_data[i] - adj) + adj + param_epsilon; + const float adj = mshadow_op::square::Map(scaled_grad - mean); + const float var = param_beta2 * (var_data[i] - adj) + adj + param_epsilon; w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon)); mean_data[i] = mean; - var_data[i] = var; - weight32[i] = w; + var_data[i] = var; + weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } }; -template +template struct MPAdaBeliefUpdate { static inline void Forward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, - const float rescale_grad) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const float rescale_grad) { using namespace mxnet_op; const auto& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor mean = inputs[2].FlatTo2D(s); - Tensor var = inputs[3].FlatTo2D(s); + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); Tensor weight32 = inputs[4].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch( - s, weight.shape_.Size(), out.dptr_, mean.dptr_, var.dptr_, - weight.dptr_, grad.dptr_, weight32.dptr_, - param.clip_gradient, param.beta1, param.beta2, param.eta, - param.lr, param.wd, rescale_grad, param.epsilon); + Kernel, xpu>::Launch(s, + weight.shape_.Size(), + out.dptr_, + mean.dptr_, + var.dptr_, + weight.dptr_, + grad.dptr_, + weight32.dptr_, + param.clip_gradient, + param.beta1, + param.beta2, + param.eta, + param.lr, + param.wd, + rescale_grad, + param.epsilon); }); }); } @@ -157,39 +173,40 @@ struct MPAdaBeliefUpdate { * \brief adabelief update. 
* */ -template +template struct AdaBeliefUpdate { static inline void Forward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, const float rescale_grad) { using namespace mshadow; using namespace mshadow::expr; using namespace mshadow_op; - const auto ¶m = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); + const auto& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - const Tensor &weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor mean = inputs[2].FlatTo2D(s); - Tensor var = inputs[3].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); + const Tensor& weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); grad = scalar(rescale_grad) * grad + scalar(param.wd) * weight; if (param.clip_gradient >= 0.0f) grad = F(grad, DType(param.clip_gradient)); - mean = scalar(param.beta1) * mean + scalar(1.f-param.beta1) * grad; - var = scalar(param.beta2) * var + - scalar(1.f-param.beta2) * F(grad - mean) + + mean = scalar(param.beta1) * mean + scalar(1.f - param.beta1) * grad; + var = scalar(param.beta2) * var + + scalar(1.f - param.beta2) * F(grad - mean) + scalar(param.epsilon); - Assign(out, req[0], + Assign(out, + req[0], weight - - scalar(param.eta) * (scalar(param.lr) * - mean / (F(var) + scalar(param.epsilon)))); + scalar(param.eta) * (scalar(param.lr) * mean / + (F(var) + scalar(param.epsilon)))); }); } }; @@ -207,62 +224,55 @@ struct MultiAdaBeliefParam : public dmlc::Parameter { float clip_gradient; int num_weights; DMLC_DECLARE_PARAMETER(MultiAdaBeliefParam) { - DMLC_DECLARE_FIELD(lrs) - .describe("Learning rates"); - DMLC_DECLARE_FIELD(beta1) - .set_default(0.9f) - .describe("The decay rate for the 1st moment estimates."); - DMLC_DECLARE_FIELD(beta2) - .set_default(0.999f) - .describe("The decay rate for the 2nd moment estimates."); - DMLC_DECLARE_FIELD(epsilon) - .set_default(1e-8f) - .describe("A small constant for numerical stability."); - DMLC_DECLARE_FIELD(wds) - .describe("Weight decay augments the objective function with a " - "regularization term that penalizes large weights. " - "The penalty scales with the square of the magnitude of each weight."); - DMLC_DECLARE_FIELD(etas) - .describe("Learning rates schedule multiplier"); + DMLC_DECLARE_FIELD(lrs).describe("Learning rates"); + DMLC_DECLARE_FIELD(beta1).set_default(0.9f).describe( + "The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2).set_default(0.999f).describe( + "The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f).describe( + "A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wds).describe( + "Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(etas).describe("Learning rates schedule multiplier"); DMLC_DECLARE_FIELD(clip_gradient) - .set_default(-1.0f) - .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " - "If clip_gradient <= 0, gradient clipping is turned off. 
" - "grad = max(min(grad, clip_gradient), -clip_gradient)."); - DMLC_DECLARE_FIELD(num_weights) - .set_default(1) - .describe("Number of updated weights."); + .set_default(-1.0f) + .describe( + "Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights).set_default(1).describe("Number of updated weights."); } }; - -template +template inline bool MP_MultiAdaBelief_InferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { const ParamType& param = dmlc::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1); CHECK_EQ(out_attrs->size(), param.num_weights); - bool all_inferred = true; - auto& input_shapes = *in_attrs; + bool all_inferred = true; + auto& input_shapes = *in_attrs; auto& output_shapes = *out_attrs; // Learning rates CHECK_EQ(param.lrs.ndim(), param.num_weights) - << "Number of learning rates is inconsistent with num_weights " - << "parameter passed. Expected number of learning rates: " - << param.num_weights << ", and got " << param.lrs.ndim(); + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " << param.num_weights + << ", and got " << param.lrs.ndim(); // Weight decays CHECK_EQ(param.wds.ndim(), param.num_weights) - << "Number of weight decays is inconsistent with num_weights " - << "parameter passed. Expected number of weight decays: " - << param.num_weights << ", and got " << param.wds.ndim(); + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. Expected number of weight decays: " << param.num_weights << ", and got " + << param.wds.ndim(); // Learning rates schedule multiplier CHECK_EQ(param.etas.ndim(), param.num_weights) - << "Number of learning rates schedule multiplier is inconsistent with num_weights " - << "parameter passed. Expected number of learning rates schedule multiplier: " - << param.num_weights << ", and got " << param.lrs.ndim(); + << "Number of learning rates schedule multiplier is inconsistent with num_weights " + << "parameter passed. 
Expected number of learning rates schedule multiplier: " + << param.num_weights << ", and got " << param.lrs.ndim(); // Weights, gradients, mean and variance for (int i = 0; i < param.num_weights; ++i) { @@ -274,20 +284,20 @@ inline bool MP_MultiAdaBelief_InferShape(const nnvm::NodeAttrs& attrs, all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); } - SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights*input_stride, mxnet::TShape()); + SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights * input_stride, mxnet::TShape()); return all_inferred; } template inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { const ParamType& param = dmlc::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1); CHECK_EQ(out_attrs->size(), param.num_weights); - bool all_inferred = true; - auto& input_types = *in_attrs; + bool all_inferred = true; + auto& input_types = *in_attrs; auto& output_types = *out_attrs; // Weights, gradients, @@ -297,13 +307,13 @@ inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs, for (int j = 0; j < input_stride - 2 - num_fp32_inputs; ++j) { input_vec.push_back(input_types[i * input_stride + j]); } - all_inferred = all_inferred && - ElemwiseType(attrs, &input_vec, &output_vec); + all_inferred = all_inferred && ElemwiseType( + attrs, &input_vec, &output_vec); } // mean, var for (int i = 0; i < param.num_weights; ++i) { - TYPE_ASSIGN_CHECK(input_types, input_stride * i +2, mshadow::kFloat32); - TYPE_ASSIGN_CHECK(input_types, input_stride * i +3, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32); } // master copies of weights @@ -313,25 +323,23 @@ inline bool MP_MultiAdaBelief_InferType(const nnvm::NodeAttrs& attrs, } } - TYPE_ASSIGN_CHECK(input_types, param.num_weights*input_stride, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, param.num_weights * input_stride, mshadow::kFloat32); return all_inferred; } - -template +template class _type_identity { public: using type = T; }; - -template +template class _single_precision { public: using type = float; }; -template +template struct MultiKernelParam { static const int N = 50; int count; @@ -352,30 +360,32 @@ struct MultiKernelParam { MPDType epsilon; }; -template +template struct MultiMPAdaBeliefKernel { - template - MSHADOW_XINLINE static void Map(int i, const MultiKernelParam& param, - const OpReqType req, const float rescale_grad) { + template + MSHADOW_XINLINE static void Map(int i, + const MultiKernelParam& param, + const OpReqType req, + const float rescale_grad) { for (int index = 0; index < param.count; ++index) { if ((size_t)i < param.sizes[index]) { - MPDType w = has_mixed_precision ? param.weights32[index][i]: - MPDType(param.weights[index][i]); - MPDType scaled_grad = static_cast(rescale_grad)* - static_cast(param.grad_data[index][i]); + MPDType w = + has_mixed_precision ? 
param.weights32[index][i] : MPDType(param.weights[index][i]); + MPDType scaled_grad = + static_cast(rescale_grad) * static_cast(param.grad_data[index][i]); scaled_grad += param.wds[index] * w; if (param.clip_gradient >= 0.f) scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient); const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad; - const auto adj = mshadow_op::square::Map(mean - scaled_grad); - const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon; + const auto adj = mshadow_op::square::Map(mean - scaled_grad); + const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj + param.epsilon; param.mean_data[index][i] = mean; - param.var_data[index][i] = var; - w = w - param.etas[index] * (param.lrs[index] * - mean / (mshadow_op::square_root::Map(var) + param.epsilon)); + param.var_data[index][i] = var; + w = w - param.etas[index] * + (param.lrs[index] * mean / (mshadow_op::square_root::Map(var) + param.epsilon)); if (has_mixed_precision) param.weights32[index][i] = w; @@ -385,34 +395,34 @@ struct MultiMPAdaBeliefKernel { } }; -template +template void FillMultiKernelParam(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - MultiKernelParam *pParam) { - const ParamType& p = nnvm::get(attrs.parsed); + const OpContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + MultiKernelParam* pParam) { + const ParamType& p = nnvm::get(attrs.parsed); mxnet_op::Stream* s = ctx.get_stream(); - pParam->clip_gradient = p.clip_gradient; - pParam->beta1 = p.beta1; - pParam->beta2 = p.beta2; + pParam->clip_gradient = p.clip_gradient; + pParam->beta1 = p.beta1; + pParam->beta2 = p.beta2; pParam->epsilon = p.epsilon; - pParam->count = p.num_weights; - pParam->max_size = 0; + pParam->count = p.num_weights; + pParam->max_size = 0; constexpr bool isSame = std::is_same::value; for (int i = 0; i < pParam->count; ++i) { - const auto idx = i * input_stride; + const auto idx = i * input_stride; pParam->sizes[i] = inputs[idx].shape_.Size(); if (pParam->max_size < pParam->sizes[i]) pParam->max_size = pParam->sizes[i]; - pParam->weights[i] = inputs[idx].FlatTo2D(s).dptr_; + pParam->weights[i] = inputs[idx].FlatTo2D(s).dptr_; pParam->grad_data[i] = inputs[idx + 1].FlatTo2D(s).dptr_; pParam->mean_data[i] = inputs[idx + 2].FlatTo2D(s).dptr_; pParam->var_data[i] = inputs[idx + 3].FlatTo2D(s).dptr_; @@ -428,34 +438,34 @@ void FillMultiKernelParam(const nnvm::NodeAttrs& attrs, memcpy(pParam->wds, p.wds.begin(), pParam->count * sizeof(p.wds[0])); } -template class MPTypeChooser, int input_stride> +template class MPTypeChooser, int input_stride> static inline void MultiAdaBeliefUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, - const float rescale_grad) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const float rescale_grad) { using namespace mxnet_op; Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { using MPDType = typename MPTypeChooser::type; MultiKernelParam param; - FillMultiKernelParam - (attrs, ctx, inputs, outputs, ¶m); + FillMultiKernelParam( + attrs, ctx, inputs, outputs, ¶m); - Kernel::value>, xpu>:: - Launch(s, param.max_size, param, req[0], rescale_grad); + Kernel::value>, xpu>::Launch( + s, param.max_size, param, req[0], rescale_grad); }); } 
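// A scalar walk-through (illustrative, not part of the reformatted sources) of
// the element-wise update that MPAdaBeliefKernel / MultiMPAdaBeliefKernel
// apply above; the free function and the way hyper-parameters are passed here
// are hypothetical.
#include <algorithm>
#include <cmath>

inline float adabelief_step(float w, float grad, float* mean, float* var,
                            float lr, float eta, float wd,
                            float beta1, float beta2, float epsilon,
                            float rescale_grad, float clip_gradient) {
  float g = rescale_grad * grad + wd * w;                       // rescale, then weight decay
  if (clip_gradient >= 0.f)
    g = std::max(std::min(g, clip_gradient), -clip_gradient);   // optional clipping
  *mean = beta1 * *mean + (1.f - beta1) * g;                    // 1st moment (mean)
  const float centered = g - *mean;
  *var = beta2 * *var + (1.f - beta2) * centered * centered + epsilon;  // "belief" variance
  return w - eta * (lr * *mean / (std::sqrt(*var) + epsilon));
}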
-template -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef); +template +void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef); -template -bool PrepareInputBlobs(const OpContext &ctx, - const std::vector &inputs, - std::vector *inputs_wo_scale, - float *pScalef) { +template +bool PrepareInputBlobs(const OpContext& ctx, + const std::vector& inputs, + std::vector* inputs_wo_scale, + float* pScalef) { const size_t num_in = inputs.size() - 1; adabelief::GetScaleFloat(ctx.get_stream(), inputs[num_in], pScalef); if (!std::isfinite(*pScalef) || *pScalef == 0) @@ -468,12 +478,12 @@ bool PrepareInputBlobs(const OpContext &ctx, return true; } -template +template inline void MPUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { std::vector inputs_wo_scale; float scalef; if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) @@ -482,23 +492,22 @@ inline void MPUpdate(const nnvm::NodeAttrs& attrs, F::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef); } -template +template inline void multiMPUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { std::vector inputs_wo_scale; float scalef; if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) return; if (!MP) - MultiAdaBeliefUpdate - (attrs, ctx, inputs_wo_scale, req, outputs, scalef); + MultiAdaBeliefUpdate(attrs, ctx, inputs_wo_scale, req, outputs, scalef); else - MultiAdaBeliefUpdate - (attrs, ctx, inputs_wo_scale, req, outputs, scalef); + MultiAdaBeliefUpdate( + attrs, ctx, inputs_wo_scale, req, outputs, scalef); } } // namespace adabelief diff --git a/src/operator/contrib/adabelief.cc b/src/operator/contrib/adabelief.cc index 06be7480f8a7..fe86a4a7ad30 100644 --- a/src/operator/contrib/adabelief.cc +++ b/src/operator/contrib/adabelief.cc @@ -33,7 +33,7 @@ DMLC_REGISTER_PARAMETER(AdaBeliefParam); DMLC_REGISTER_PARAMETER(MultiAdaBeliefParam); NNVM_REGISTER_OP(_mp_adabelief_update) -.describe(R"code(Update function for multi-precision AdaBelief optimizer. + .describe(R"code(Update function for multi-precision AdaBelief optimizer. AdaBelief is seen as a modification of Adam with a different variance estimator. @@ -57,28 +57,29 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. )code" ADD_FILELINE) -.set_num_inputs(6) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MPUpdateInferShape<2, 1, 6>) -.set_attr("FInferType", MPUpdateInferType<2, 1, 6>) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3, 4}; - }) -.set_attr("FCompute", MPUpdate>) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_argument("mean", "NDArray-or-Symbol", "Moving mean") -.add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_argument("weight32", "NDArray-or-Symbol", "Weight32") -.add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. 
If NaN, Inf, or 0, " - "the update is skipped.") -.add_arguments(AdaBeliefParam::__FIELDS__()); + .set_num_inputs(6) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MPUpdateInferShape<2, 1, 6>) + .set_attr("FInferType", MPUpdateInferType<2, 1, 6>) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3, 4}; + }) + .set_attr("FCompute", MPUpdate>) + .add_argument("weight", "NDArray-or-Symbol", "Weight") + .add_argument("grad", "NDArray-or-Symbol", "Gradient") + .add_argument("mean", "NDArray-or-Symbol", "Moving mean") + .add_argument("var", "NDArray-or-Symbol", "Moving variance") + .add_argument("weight32", "NDArray-or-Symbol", "Weight32") + .add_argument("rescale_grad", + "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") + .add_arguments(AdaBeliefParam::__FIELDS__()); NNVM_REGISTER_OP(_adabelief_update) -.describe(R"code(Update function for AdaBelief optimizer. + .describe(R"code(Update function for AdaBelief optimizer. AdaBelief is seen as a modification of Adam with a different variance estimator. @@ -102,34 +103,35 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. ))code" ADD_FILELINE) -.set_num_inputs(5) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MPUpdateInferShape<4, 1, 5>) -.set_attr("FInferType", MPUpdateInferType<4, 1, 5>) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3}; - }) -.set_attr("FCompute", MPUpdate>) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_argument("mean", "NDArray-or-Symbol", "Moving mean") -.add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " - "the update is skipped.") -.add_arguments(AdaBeliefParam::__FIELDS__()); - -template<> -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { - MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, - *pScalef = static_cast(*scale_blob.dptr()); - ) + .set_num_inputs(5) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MPUpdateInferShape<4, 1, 5>) + .set_attr("FInferType", MPUpdateInferType<4, 1, 5>) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) + .set_attr("FCompute", MPUpdate>) + .add_argument("weight", "NDArray-or-Symbol", "Weight") + .add_argument("grad", "NDArray-or-Symbol", "Gradient") + .add_argument("mean", "NDArray-or-Symbol", "Moving mean") + .add_argument("var", "NDArray-or-Symbol", "Moving variance") + .add_argument("rescale_grad", + "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. 
If NaN, Inf, or 0, " + "the update is skipped.") + .add_arguments(AdaBeliefParam::__FIELDS__()); + +template <> +void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef) { + MSHADOW_REAL_TYPE_SWITCH( + scale_blob.type_flag_, DType, *pScalef = static_cast(*scale_blob.dptr());) } -static std::vector -ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) { +static std::vector ParamToVector(uint32_t num_args, + const char* pName[], + size_t nParams) { std::vector ret; for (uint32_t i = 0; i < num_args; ++i) { const auto idx = std::to_string(i); @@ -145,7 +147,7 @@ inline uint32_t num_weights(const nnvm::NodeAttrs& attrs) { } NNVM_REGISTER_OP(_multi_adabelief_update) -.describe(R"code(Update function for AdaBelief optimizer. + .describe(R"code(Update function for AdaBelief optimizer. AdaBelief is seen as a modification of Adam with a different variance estimator. @@ -169,39 +171,37 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. ))code" ADD_FILELINE) -.set_num_inputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs) * 4 + 1; - }) -.set_num_outputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs); - }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MP_MultiAdaBelief_InferShape) -.set_attr("FInferType", ElemwiseType<-1, -1>) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "rescale_grad_"}; - return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); - }) -// mutable: mean, var -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - std::vector ret; - const auto iMax = num_weights(attrs); - for (size_t i = 0; i < iMax; ++i) { - ret.push_back(i * 4 + 2); - ret.push_back(i * 4 + 3); - } - return ret; - }) - -.set_attr("FCompute", multiMPUpdate) -.add_argument("data", "NDArray-or-Symbol[]", "data") -.add_arguments(MultiAdaBeliefParam::__FIELDS__()); - + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs) * 4 + 1; }) + .set_num_outputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs); }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", + MP_MultiAdaBelief_InferShape) + .set_attr("FInferType", ElemwiseType<-1, -1>) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + const char* paramName[] = {"weight_", "grad_", "mean_", "var_", "rescale_grad_"}; + return ParamToVector( + num_weights(attrs), paramName, sizeof(paramName) / sizeof(paramName[0])); + }) + // mutable: mean, var + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) + + .set_attr("FCompute", multiMPUpdate) + .add_argument("data", "NDArray-or-Symbol[]", "data") + .add_arguments(MultiAdaBeliefParam::__FIELDS__()); NNVM_REGISTER_OP(_multi_mp_adabelief_update) -.describe(R"code(Update function for multi-precision AdaBelief optimizer. + .describe(R"code(Update function for multi-precision AdaBelief optimizer. AdaBelief is seen as a modification of Adam with a different variance estimator. @@ -225,36 +225,37 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. 
))code" ADD_FILELINE) -.set_num_inputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs) * 5 + 1; - }) -.set_num_outputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs); - }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MP_MultiAdaBelief_InferShape) -.set_attr("FInferType", MP_MultiAdaBelief_InferType) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"}; - return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); - }) -// mutable: mean, var, weights32 -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - std::vector ret; - const auto iMax = num_weights(attrs); - for (size_t i = 0; i < iMax; ++i) { - ret.push_back(i * 5 + 2); - ret.push_back(i * 5 + 3); - ret.push_back(i * 5 + 4); - } - return ret; - }) - -.set_attr("FCompute", multiMPUpdate) -.add_argument("data", "NDArray-or-Symbol[]", "data") -.add_arguments(MultiAdaBeliefParam::__FIELDS__()); + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs) * 5 + 1; }) + .set_num_outputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs); }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", + MP_MultiAdaBelief_InferShape) + .set_attr("FInferType", + MP_MultiAdaBelief_InferType) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + const char* paramName[] = { + "weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"}; + return ParamToVector( + num_weights(attrs), paramName, sizeof(paramName) / sizeof(paramName[0])); + }) + // mutable: mean, var, weights32 + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 5 + 2); + ret.push_back(i * 5 + 3); + ret.push_back(i * 5 + 4); + } + return ret; + }) + + .set_attr("FCompute", multiMPUpdate) + .add_argument("data", "NDArray-or-Symbol[]", "data") + .add_arguments(MultiAdaBeliefParam::__FIELDS__()); } // namespace adabelief } // namespace op diff --git a/src/operator/contrib/adabelief.cu b/src/operator/contrib/adabelief.cu index e64dcb4ca006..aac63b9c6cf5 100644 --- a/src/operator/contrib/adabelief.cu +++ b/src/operator/contrib/adabelief.cu @@ -28,13 +28,13 @@ namespace mxnet { namespace op { namespace adabelief { -template<> -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { +template <> +void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef) { MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, { - DType scale = 0; + DType scale = 0; cudaStream_t stream = mshadow::Stream::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(&scale, scale_blob.dptr(), sizeof(DType), - cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaMemcpyAsync( + &scale, scale_blob.dptr(), sizeof(DType), cudaMemcpyDeviceToHost, stream)); CUDA_CALL(cudaStreamSynchronize(stream)); *pScalef = static_cast(scale); }) @@ -42,16 +42,17 @@ void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float } // namespace adabelief NNVM_REGISTER_OP(_adabelief_update) -.set_attr("FCompute", adabelief::MPUpdate>); + .set_attr("FCompute", adabelief::MPUpdate>); NNVM_REGISTER_OP(_mp_adabelief_update) -.set_attr("FCompute", adabelief::MPUpdate>); + .set_attr("FCompute", + adabelief::MPUpdate>); NNVM_REGISTER_OP(_multi_adabelief_update) -.set_attr("FCompute", adabelief::multiMPUpdate); + .set_attr("FCompute", 
adabelief::multiMPUpdate); NNVM_REGISTER_OP(_multi_mp_adabelief_update) -.set_attr("FCompute", adabelief::multiMPUpdate); + .set_attr("FCompute", adabelief::multiMPUpdate); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/adamw-inl.h b/src/operator/contrib/adamw-inl.h index 56c5ea227862..cbeb0ef5eefd 100644 --- a/src/operator/contrib/adamw-inl.h +++ b/src/operator/contrib/adamw-inl.h @@ -43,29 +43,24 @@ struct AdamWParam : public dmlc::Parameter { float eta; float clip_gradient; DMLC_DECLARE_PARAMETER(AdamWParam) { - DMLC_DECLARE_FIELD(lr) - .describe("Learning rate"); - DMLC_DECLARE_FIELD(beta1) - .set_default(0.9f) - .describe("The decay rate for the 1st moment estimates."); - DMLC_DECLARE_FIELD(beta2) - .set_default(0.999f) - .describe("The decay rate for the 2nd moment estimates."); - DMLC_DECLARE_FIELD(epsilon) - .set_default(1e-8f) - .describe("A small constant for numerical stability."); - DMLC_DECLARE_FIELD(wd) - .set_default(0.0f) - .describe("Weight decay augments the objective function with a " - "regularization term that penalizes large weights. " - "The penalty scales with the square of the magnitude of each weight."); - DMLC_DECLARE_FIELD(eta) - .describe("Learning rate schedule multiplier"); + DMLC_DECLARE_FIELD(lr).describe("Learning rate"); + DMLC_DECLARE_FIELD(beta1).set_default(0.9f).describe( + "The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2).set_default(0.999f).describe( + "The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f).describe( + "A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wd).set_default(0.0f).describe( + "Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(eta).describe("Learning rate schedule multiplier"); DMLC_DECLARE_FIELD(clip_gradient) - .set_default(-1.0f) - .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " - "If clip_gradient <= 0, gradient clipping is turned off. " - "grad = max(min(grad, clip_gradient), -clip_gradient)."); + .set_default(-1.0f) + .describe( + "Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. 
" + "grad = max(min(grad, clip_gradient), -clip_gradient)."); } }; @@ -73,10 +68,10 @@ struct AdamWParam : public dmlc::Parameter { // n_in = 2: weight, grad (fp16) // n_out = 1: weight (fp16) // total_in = 6: weight, grad, mean, var, weight32, rescale_grad (fp32) -template +template inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; SHAPE_ASSIGN_CHECK(*in_attrs, total_in - 1, mxnet::TShape()); @@ -85,10 +80,10 @@ inline bool MPUpdateInferShape(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, mxnet::TShape()); } -template +template inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), static_cast(total_in)) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; for (int i = n_in; i < total_in; ++i) { @@ -98,52 +93,75 @@ inline bool MPUpdateInferType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } -template +template struct MPAdamWKernel { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, float* mean_data, - float* var_data, const DType* weight_data, const DType* grad_data, float* weight32, - const float param_clip_gradient, const float param_beta1, const float param_beta2, - const float param_eta, const float param_lr, const float param_wd, - const float param_rescale_grad, const float param_epsilon) { - float w = weight32[i]; - float scaled_grad = param_rescale_grad*static_cast(grad_data[i]); + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + float* mean_data, + float* var_data, + const DType* weight_data, + const DType* grad_data, + float* weight32, + const float param_clip_gradient, + const float param_beta1, + const float param_beta2, + const float param_eta, + const float param_lr, + const float param_wd, + const float param_rescale_grad, + const float param_epsilon) { + float w = weight32[i]; + float scaled_grad = param_rescale_grad * static_cast(grad_data[i]); if (param_clip_gradient >= 0.0f) scaled_grad = mshadow_op::clip::Map(scaled_grad, param_clip_gradient); float mean = mean_data[i] = param_beta1 * mean_data[i] + (1.0f - param_beta1) * scaled_grad; - float var = var_data[i] = param_beta2 * var_data[i] + - (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad); + float var = var_data[i] = + param_beta2 * var_data[i] + (1.0f - param_beta2) * mshadow_op::square::Map(scaled_grad); - w -= param_eta * (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon) - + param_wd * w); + w -= param_eta * + (param_lr * mean / (mshadow_op::square_root::Map(var) + param_epsilon) + param_wd * w); weight32[i] = w; KERNEL_ASSIGN(out_data[i], req, w); } }; -template +template struct MPAdamWUpdate { static inline void Forward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, - const float rescale_grad) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const float rescale_grad) { using namespace mxnet_op; const auto& param = nnvm::get(attrs.parsed); - Stream* s = 
ctx.get_stream(); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Tensor weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor mean = inputs[2].FlatTo2D(s); - Tensor var = inputs[3].FlatTo2D(s); + Tensor weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); Tensor weight32 = inputs[4].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_, - var.dptr_, weight.dptr_, grad.dptr_, weight32.dptr_, param.clip_gradient, param.beta1, - param.beta2, param.eta, param.lr, param.wd, rescale_grad, param.epsilon); + Kernel, xpu>::Launch(s, + weight.shape_.Size(), + out.dptr_, + mean.dptr_, + var.dptr_, + weight.dptr_, + grad.dptr_, + weight32.dptr_, + param.clip_gradient, + param.beta1, + param.beta2, + param.eta, + param.lr, + param.wd, + rescale_grad, + param.epsilon); }); }); } @@ -152,38 +170,39 @@ struct MPAdamWUpdate { /* * \brief adam_w update. */ -template +template struct AdamWUpdate { static inline void Forward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, const float rescale_grad) { using namespace mshadow; using namespace mshadow::expr; using namespace mshadow_op; - const auto ¶m = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); + const auto& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { - const Tensor &weight = inputs[0].FlatTo2D(s); - Tensor grad = inputs[1].FlatTo2D(s); - Tensor mean = inputs[2].FlatTo2D(s); - Tensor var = inputs[3].FlatTo2D(s); - Tensor out = outputs[0].FlatTo2D(s); + const Tensor& weight = inputs[0].FlatTo2D(s); + Tensor grad = inputs[1].FlatTo2D(s); + Tensor mean = inputs[2].FlatTo2D(s); + Tensor var = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); grad = scalar(rescale_grad) * grad; if (param.clip_gradient >= 0.0f) grad = F(grad, DType(param.clip_gradient)); - mean = scalar(param.beta1) * mean + scalar(1.f-param.beta1) * grad; - var = scalar(param.beta2) * var + scalar(1.f-param.beta2) * F(grad); + mean = scalar(param.beta1) * mean + scalar(1.f - param.beta1) * grad; + var = scalar(param.beta2) * var + scalar(1.f - param.beta2) * F(grad); - Assign(out, req[0], - weight - - scalar(param.eta) * (scalar(param.lr) * - mean / (F(var) + scalar(param.epsilon)) + - (scalar(param.wd) * weight))); + Assign(out, + req[0], + weight - scalar(param.eta) * + (scalar(param.lr) * mean / + (F(var) + scalar(param.epsilon)) + + (scalar(param.wd) * weight))); }); } }; @@ -201,62 +220,55 @@ struct MultiAdamWParam : public dmlc::Parameter { float clip_gradient; int num_weights; DMLC_DECLARE_PARAMETER(MultiAdamWParam) { - DMLC_DECLARE_FIELD(lrs) - .describe("Learning rates"); - DMLC_DECLARE_FIELD(beta1) - .set_default(0.9f) - .describe("The decay rate for the 1st moment estimates."); - DMLC_DECLARE_FIELD(beta2) - .set_default(0.999f) - .describe("The decay rate for the 2nd moment estimates."); - DMLC_DECLARE_FIELD(epsilon) - .set_default(1e-8f) - .describe("A small constant for numerical stability."); - DMLC_DECLARE_FIELD(wds) - .describe("Weight decay augments the objective function with 
a " - "regularization term that penalizes large weights. " - "The penalty scales with the square of the magnitude of each weight."); - DMLC_DECLARE_FIELD(etas) - .describe("Learning rates schedule multiplier"); + DMLC_DECLARE_FIELD(lrs).describe("Learning rates"); + DMLC_DECLARE_FIELD(beta1).set_default(0.9f).describe( + "The decay rate for the 1st moment estimates."); + DMLC_DECLARE_FIELD(beta2).set_default(0.999f).describe( + "The decay rate for the 2nd moment estimates."); + DMLC_DECLARE_FIELD(epsilon).set_default(1e-8f).describe( + "A small constant for numerical stability."); + DMLC_DECLARE_FIELD(wds).describe( + "Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(etas).describe("Learning rates schedule multiplier"); DMLC_DECLARE_FIELD(clip_gradient) - .set_default(-1.0f) - .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " - "If clip_gradient <= 0, gradient clipping is turned off. " - "grad = max(min(grad, clip_gradient), -clip_gradient)."); - DMLC_DECLARE_FIELD(num_weights) - .set_default(1) - .describe("Number of updated weights."); + .set_default(-1.0f) + .describe( + "Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights).set_default(1).describe("Number of updated weights."); } }; - -template +template inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { const ParamType& param = dmlc::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1); CHECK_EQ(out_attrs->size(), param.num_weights); - bool all_inferred = true; - auto& input_shapes = *in_attrs; + bool all_inferred = true; + auto& input_shapes = *in_attrs; auto& output_shapes = *out_attrs; // Learning rates CHECK_EQ(param.lrs.ndim(), param.num_weights) - << "Number of learning rates is inconsistent with num_weights " - << "parameter passed. Expected number of learning rates: " - << param.num_weights << ", and got " << param.lrs.ndim(); + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " << param.num_weights + << ", and got " << param.lrs.ndim(); // Weight decays CHECK_EQ(param.wds.ndim(), param.num_weights) - << "Number of weight decays is inconsistent with num_weights " - << "parameter passed. Expected number of weight decays: " - << param.num_weights << ", and got " << param.wds.ndim(); + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. Expected number of weight decays: " << param.num_weights << ", and got " + << param.wds.ndim(); // Learning rates schedule multiplier CHECK_EQ(param.etas.ndim(), param.num_weights) - << "Number of learning rates schedule multiplier is inconsistent with num_weights " - << "parameter passed. Expected number of learning rates schedule multiplier: " - << param.num_weights << ", and got " << param.lrs.ndim(); + << "Number of learning rates schedule multiplier is inconsistent with num_weights " + << "parameter passed. 
Expected number of learning rates schedule multiplier: " + << param.num_weights << ", and got " << param.lrs.ndim(); // Weights, gradients, mean and variance for (int i = 0; i < param.num_weights; ++i) { @@ -268,20 +280,20 @@ inline bool MP_MultiAdamW_InferShape(const nnvm::NodeAttrs& attrs, all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); } - SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights*input_stride, mxnet::TShape()); + SHAPE_ASSIGN_CHECK(*in_attrs, param.num_weights * input_stride, mxnet::TShape()); return all_inferred; } template inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { const ParamType& param = dmlc::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), input_stride * param.num_weights +1); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights + 1); CHECK_EQ(out_attrs->size(), param.num_weights); - bool all_inferred = true; - auto& input_types = *in_attrs; + bool all_inferred = true; + auto& input_types = *in_attrs; auto& output_types = *out_attrs; // Weights, gradients, @@ -291,13 +303,13 @@ inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs, for (int j = 0; j < input_stride - 2 - num_fp32_inputs; ++j) { input_vec.push_back(input_types[i * input_stride + j]); } - all_inferred = all_inferred && - ElemwiseType(attrs, &input_vec, &output_vec); + all_inferred = all_inferred && ElemwiseType( + attrs, &input_vec, &output_vec); } // mean, var for (int i = 0; i < param.num_weights; ++i) { - TYPE_ASSIGN_CHECK(input_types, input_stride * i +2, mshadow::kFloat32); - TYPE_ASSIGN_CHECK(input_types, input_stride * i +3, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 2, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, input_stride * i + 3, mshadow::kFloat32); } // master copies of weights @@ -307,25 +319,23 @@ inline bool MP_MultiAdamW_InferType(const nnvm::NodeAttrs& attrs, } } - TYPE_ASSIGN_CHECK(input_types, param.num_weights*input_stride, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(input_types, param.num_weights * input_stride, mshadow::kFloat32); return all_inferred; } - -template +template class Adam_type_identity { public: using type = T; }; - -template +template class Adam_single_precision { public: using type = float; }; -template +template struct MultiAdamKernelParam { static const int N = 50; int count; @@ -346,30 +356,32 @@ struct MultiAdamKernelParam { MPDType epsilon; }; -template +template struct MultiMPAdamWKernel { - template - MSHADOW_XINLINE static void Map(int i, const MultiAdamKernelParam& param, - const OpReqType req, const float rescale_grad) { + template + MSHADOW_XINLINE static void Map(int i, + const MultiAdamKernelParam& param, + const OpReqType req, + const float rescale_grad) { for (int index = 0; index < param.count; ++index) { if ((size_t)i < param.sizes[index]) { - MPDType w = has_mixed_precision ? param.weights32[index][i]: - MPDType(param.weights[index][i]); - MPDType scaled_grad = static_cast(rescale_grad)* - static_cast(param.grad_data[index][i]); + MPDType w = + has_mixed_precision ? 
param.weights32[index][i] : MPDType(param.weights[index][i]); + MPDType scaled_grad = + static_cast(rescale_grad) * static_cast(param.grad_data[index][i]); if (param.clip_gradient >= 0.0f) scaled_grad = mshadow_op::clip::Map(scaled_grad, param.clip_gradient); - const auto mean = param.beta1 * (param.mean_data[index][i]- scaled_grad) + scaled_grad; - const auto adj = mshadow_op::square::Map(scaled_grad); - const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj; + const auto mean = param.beta1 * (param.mean_data[index][i] - scaled_grad) + scaled_grad; + const auto adj = mshadow_op::square::Map(scaled_grad); + const auto var = param.beta2 * (param.var_data[index][i] - adj) + adj; param.mean_data[index][i] = mean; - param.var_data[index][i] = var; - w = w - param.etas[index] * (param.lrs[index] * - mean / (mshadow_op::square_root::Map(var) + param.epsilon) - + param.wds[index] * w); + param.var_data[index][i] = var; + w = w - param.etas[index] * + (param.lrs[index] * mean / (mshadow_op::square_root::Map(var) + param.epsilon) + + param.wds[index] * w); if (has_mixed_precision) param.weights32[index][i] = w; @@ -379,34 +391,34 @@ struct MultiMPAdamWKernel { } }; -template +template void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - MultiAdamKernelParam *pParam) { - const ParamType& p = nnvm::get(attrs.parsed); + const OpContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + MultiAdamKernelParam* pParam) { + const ParamType& p = nnvm::get(attrs.parsed); mxnet_op::Stream* s = ctx.get_stream(); - pParam->clip_gradient = p.clip_gradient; - pParam->beta1 = p.beta1; - pParam->beta2 = p.beta2; + pParam->clip_gradient = p.clip_gradient; + pParam->beta1 = p.beta1; + pParam->beta2 = p.beta2; pParam->epsilon = p.epsilon; - pParam->count = p.num_weights; - pParam->max_size = 0; + pParam->count = p.num_weights; + pParam->max_size = 0; constexpr bool isSame = std::is_same::value; for (int i = 0; i < pParam->count; ++i) { - const auto idx = i * input_stride; + const auto idx = i * input_stride; pParam->sizes[i] = inputs[idx].shape_.Size(); if (pParam->max_size < pParam->sizes[i]) pParam->max_size = pParam->sizes[i]; - pParam->weights[i] = inputs[idx].FlatTo2D(s).dptr_; + pParam->weights[i] = inputs[idx].FlatTo2D(s).dptr_; pParam->grad_data[i] = inputs[idx + 1].FlatTo2D(s).dptr_; pParam->mean_data[i] = inputs[idx + 2].FlatTo2D(s).dptr_; pParam->var_data[i] = inputs[idx + 3].FlatTo2D(s).dptr_; @@ -422,34 +434,34 @@ void FillMultiAdamKernelParam(const nnvm::NodeAttrs& attrs, memcpy(pParam->wds, p.wds.begin(), pParam->count * sizeof(p.wds[0])); } -template class MPTypeChooser, int input_stride> +template class MPTypeChooser, int input_stride> static inline void MultiAdamWUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, const float rescale_grad) { using namespace mxnet_op; Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { using MPDType = typename MPTypeChooser::type; MultiAdamKernelParam param; - FillMultiAdamKernelParam - (attrs, ctx, inputs, outputs, ¶m); + FillMultiAdamKernelParam( + attrs, ctx, inputs, outputs, ¶m); - Kernel::value>, xpu>:: - Launch(s, param.max_size, param, req[0], rescale_grad); + Kernel::value>, xpu>::Launch( + s, 
param.max_size, param, req[0], rescale_grad); }); } -template -static void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef); +template +static void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef); -template -bool PrepareInputBlobs(const OpContext &ctx, - const std::vector &inputs, - std::vector *inputs_wo_scale, - float *pScalef) { +template +bool PrepareInputBlobs(const OpContext& ctx, + const std::vector& inputs, + std::vector* inputs_wo_scale, + float* pScalef) { const size_t num_in = inputs.size() - 1; adamw::GetScaleFloat(ctx.get_stream(), inputs[num_in], pScalef); if (!std::isfinite(*pScalef) || *pScalef == 0) @@ -462,12 +474,12 @@ bool PrepareInputBlobs(const OpContext &ctx, return true; } -template +template inline void MPUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { std::vector inputs_wo_scale; float scalef; if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) @@ -476,23 +488,22 @@ inline void MPUpdate(const nnvm::NodeAttrs& attrs, F::Forward(attrs, ctx, inputs_wo_scale, req, outputs, scalef); } -template +template inline void multiMPUpdate(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { std::vector inputs_wo_scale; float scalef; if (!PrepareInputBlobs(ctx, inputs, &inputs_wo_scale, &scalef)) return; if (!MP) - MultiAdamWUpdate - (attrs, ctx, inputs_wo_scale, req, outputs, scalef); + MultiAdamWUpdate(attrs, ctx, inputs_wo_scale, req, outputs, scalef); else - MultiAdamWUpdate - (attrs, ctx, inputs_wo_scale, req, outputs, scalef); + MultiAdamWUpdate( + attrs, ctx, inputs_wo_scale, req, outputs, scalef); } } // namespace adamw diff --git a/src/operator/contrib/adamw.cc b/src/operator/contrib/adamw.cc index e66250200e54..f375e1b6047e 100644 --- a/src/operator/contrib/adamw.cc +++ b/src/operator/contrib/adamw.cc @@ -33,7 +33,7 @@ DMLC_REGISTER_PARAMETER(AdamWParam); DMLC_REGISTER_PARAMETER(MultiAdamWParam); NNVM_REGISTER_OP(_mp_adamw_update) -.describe(R"code(Update function for multi-precision AdamW optimizer. + .describe(R"code(Update function for multi-precision AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. @@ -57,28 +57,29 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. )code" ADD_FILELINE) -.set_num_inputs(6) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MPUpdateInferShape<2, 1, 6>) -.set_attr("FInferType", MPUpdateInferType<2, 1, 6>) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3, 4}; - }) -.set_attr("FCompute", adamw::MPUpdate>) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_argument("mean", "NDArray-or-Symbol", "Moving mean") -.add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_argument("weight32", "NDArray-or-Symbol", "Weight32") -.add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. 
If NaN, Inf, or 0, " - "the update is skipped.") -.add_arguments(AdamWParam::__FIELDS__()); + .set_num_inputs(6) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MPUpdateInferShape<2, 1, 6>) + .set_attr("FInferType", MPUpdateInferType<2, 1, 6>) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3, 4}; + }) + .set_attr("FCompute", adamw::MPUpdate>) + .add_argument("weight", "NDArray-or-Symbol", "Weight") + .add_argument("grad", "NDArray-or-Symbol", "Gradient") + .add_argument("mean", "NDArray-or-Symbol", "Moving mean") + .add_argument("var", "NDArray-or-Symbol", "Moving variance") + .add_argument("weight32", "NDArray-or-Symbol", "Weight32") + .add_argument("rescale_grad", + "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " + "the update is skipped.") + .add_arguments(AdamWParam::__FIELDS__()); NNVM_REGISTER_OP(_adamw_update) -.describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of + .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. Adam update consists of the following steps, where g represents gradient and m, v @@ -100,34 +101,35 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. )code" ADD_FILELINE) -.set_num_inputs(5) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MPUpdateInferShape<4, 1, 5>) -.set_attr("FInferType", MPUpdateInferType<4, 1, 5>) -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - return std::vector{2, 3}; - }) -.set_attr("FCompute", adamw::MPUpdate>) -.add_argument("weight", "NDArray-or-Symbol", "Weight") -.add_argument("grad", "NDArray-or-Symbol", "Gradient") -.add_argument("mean", "NDArray-or-Symbol", "Moving mean") -.add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_argument("rescale_grad", "NDArray-or-Symbol", - "Rescale gradient to rescale_grad * grad. If NaN, Inf, or 0, " - "the update is skipped.") -.add_arguments(AdamWParam::__FIELDS__()); - -template<> -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { - MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, - *pScalef = static_cast(*scale_blob.dptr()); - ) + .set_num_inputs(5) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MPUpdateInferShape<4, 1, 5>) + .set_attr("FInferType", MPUpdateInferType<4, 1, 5>) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) + .set_attr("FCompute", adamw::MPUpdate>) + .add_argument("weight", "NDArray-or-Symbol", "Weight") + .add_argument("grad", "NDArray-or-Symbol", "Gradient") + .add_argument("mean", "NDArray-or-Symbol", "Moving mean") + .add_argument("var", "NDArray-or-Symbol", "Moving variance") + .add_argument("rescale_grad", + "NDArray-or-Symbol", + "Rescale gradient to rescale_grad * grad. 
If NaN, Inf, or 0, " + "the update is skipped.") + .add_arguments(AdamWParam::__FIELDS__()); + +template <> +void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef) { + MSHADOW_REAL_TYPE_SWITCH( + scale_blob.type_flag_, DType, *pScalef = static_cast(*scale_blob.dptr());) } -static std::vector -ParamToVector(uint32_t num_args, const char *pName[], size_t nParams) { +static std::vector ParamToVector(uint32_t num_args, + const char* pName[], + size_t nParams) { std::vector ret; for (uint32_t i = 0; i < num_args; ++i) { const auto idx = std::to_string(i); @@ -143,7 +145,7 @@ inline uint32_t num_weights(const nnvm::NodeAttrs& attrs) { } NNVM_REGISTER_OP(_multi_adamw_update) -.describe(R"code(Update function for AdamW optimizer. + .describe(R"code(Update function for AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. @@ -167,39 +169,36 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. If rescale_grad is NaN, Inf, or 0, the update is skipped. )code" ADD_FILELINE) -.set_num_inputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs) * 4 + 1; - }) -.set_num_outputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs); - }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MP_MultiAdamW_InferShape) -.set_attr("FInferType", ElemwiseType<-1, -1>) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "rescale_grad_"}; - return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); - }) -// mutable: mean, var -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - std::vector ret; - const auto iMax = num_weights(attrs); - for (size_t i = 0; i < iMax; ++i) { - ret.push_back(i * 4 + 2); - ret.push_back(i * 4 + 3); - } - return ret; - }) - -.set_attr("FCompute", adamw::multiMPUpdate) -.add_argument("data", "NDArray-or-Symbol[]", "data") -.add_arguments(MultiAdamWParam::__FIELDS__()); - + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs) * 4 + 1; }) + .set_num_outputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs); }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MP_MultiAdamW_InferShape) + .set_attr("FInferType", ElemwiseType<-1, -1>) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + const char* paramName[] = {"weight_", "grad_", "mean_", "var_", "rescale_grad_"}; + return ParamToVector( + num_weights(attrs), paramName, sizeof(paramName) / sizeof(paramName[0])); + }) + // mutable: mean, var + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) + + .set_attr("FCompute", adamw::multiMPUpdate) + .add_argument("data", "NDArray-or-Symbol[]", "data") + .add_arguments(MultiAdamWParam::__FIELDS__()); NNVM_REGISTER_OP(_multi_mp_adamw_update) -.describe(R"code(Update function for multi-precision AdamW optimizer. + .describe(R"code(Update function for multi-precision AdamW optimizer. AdamW is seen as a modification of Adam by decoupling the weight decay from the optimization steps taken w.r.t. the loss function. @@ -223,37 +222,35 @@ It updates the weights using:: Note that gradient is rescaled to grad = rescale_grad * grad. 
If rescale_grad is NaN, Inf, or 0, the update is skipped. )code" ADD_FILELINE) -.set_num_inputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs) * 5 + 1; - }) -.set_num_outputs([](const nnvm::NodeAttrs& attrs) { - return num_weights(attrs); - }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", MP_MultiAdamW_InferShape) -.set_attr("FInferType", MP_MultiAdamW_InferType) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const char *paramName[] = {"weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"}; - return ParamToVector(num_weights(attrs), paramName, sizeof(paramName)/sizeof(paramName[0])); - }) -// mutable: mean, var, weights32 -.set_attr("FMutateInputs", - [](const nnvm::NodeAttrs& attrs) { - std::vector ret; - const auto iMax = num_weights(attrs); - for (size_t i = 0; i < iMax; ++i) { - ret.push_back(i * 5 + 2); - ret.push_back(i * 5 + 3); - ret.push_back(i * 5 + 4); - } - return ret; - }) - -.set_attr("FCompute", adamw::multiMPUpdate) -.add_argument("data", "NDArray-or-Symbol[]", "data") -.add_arguments(MultiAdamWParam::__FIELDS__()); - + .set_num_inputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs) * 5 + 1; }) + .set_num_outputs([](const nnvm::NodeAttrs& attrs) { return num_weights(attrs); }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", MP_MultiAdamW_InferShape) + .set_attr("FInferType", MP_MultiAdamW_InferType) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + const char* paramName[] = { + "weight_", "grad_", "mean_", "var_", "weight32_", "rescale_grad_"}; + return ParamToVector( + num_weights(attrs), paramName, sizeof(paramName) / sizeof(paramName[0])); + }) + // mutable: mean, var, weights32 + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const auto iMax = num_weights(attrs); + for (size_t i = 0; i < iMax; ++i) { + ret.push_back(i * 5 + 2); + ret.push_back(i * 5 + 3); + ret.push_back(i * 5 + 4); + } + return ret; + }) + + .set_attr("FCompute", adamw::multiMPUpdate) + .add_argument("data", "NDArray-or-Symbol[]", "data") + .add_arguments(MultiAdamWParam::__FIELDS__()); } // namespace adamw } // namespace op diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu index 95fcffbd78e4..2ed0e92dd6c9 100644 --- a/src/operator/contrib/adamw.cu +++ b/src/operator/contrib/adamw.cu @@ -29,29 +29,31 @@ namespace mxnet { namespace op { namespace adamw { -template<> -void GetScaleFloat(mshadow::Stream *s, const TBlob &scale_blob, float *pScalef) { - MSHADOW_REAL_TYPE_SWITCH(scale_blob.type_flag_, DType, { - DType scale = 0; - cudaStream_t stream = mshadow::Stream::GetStream(s); - CUDA_CALL(cudaMemcpyAsync(&scale, scale_blob.dptr(), sizeof(DType), - cudaMemcpyDeviceToHost, stream)); - CUDA_CALL(cudaStreamSynchronize(stream)); - *pScalef = static_cast(scale); - }) -} +template <> +void GetScaleFloat(mshadow::Stream* s, const TBlob& scale_blob, float* pScalef) { + MSHADOW_REAL_TYPE_SWITCH( + scale_blob.type_flag_, + DType, + { + DType scale = 0; + cudaStream_t stream = mshadow::Stream::GetStream(s); + CUDA_CALL(cudaMemcpyAsync( + &scale, scale_blob.dptr(), sizeof(DType), cudaMemcpyDeviceToHost, stream)); + CUDA_CALL(cudaStreamSynchronize(stream)); + *pScalef = static_cast(scale); + })} NNVM_REGISTER_OP(_adamw_update) -.set_attr("FCompute", adamw::MPUpdate>); + .set_attr("FCompute", adamw::MPUpdate>); NNVM_REGISTER_OP(_mp_adamw_update) -.set_attr("FCompute", adamw::MPUpdate>); + .set_attr("FCompute", adamw::MPUpdate>); 
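The AdamW kernels registered above follow the same multi-tensor structure but keep the weight decay decoupled from the moment estimates. For comparison, here is a hypothetical scalar sketch of what MPAdamWKernel::Map computes per element; adamw_element_step is an illustrative name, not an MXNet symbol, and rescale_grad stands for the scalar that GetScaleFloat reads from the extra NDArray input.

#include <algorithm>
#include <cmath>

// Hypothetical scalar sketch of the per-element AdamW step (illustration only).
inline float adamw_element_step(float w, float grad, float* mean, float* var,
                                float lr, float eta, float wd,
                                float beta1, float beta2, float epsilon,
                                float rescale_grad, float clip_gradient) {
  float g = rescale_grad * grad;
  if (clip_gradient >= 0.f)
    g = std::max(std::min(g, clip_gradient), -clip_gradient);
  *mean = beta1 * (*mean) + (1.f - beta1) * g;
  *var  = beta2 * (*var) + (1.f - beta2) * g * g;  // Adam-style g^2, unlike AdaBelief
  // Decoupled weight decay: wd * w is applied in the step, not mixed into g.
  return w - eta * (lr * (*mean) / (std::sqrt(*var) + epsilon) + wd * w);
}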
NNVM_REGISTER_OP(_multi_adamw_update) -.set_attr("FCompute", adamw::multiMPUpdate); + .set_attr("FCompute", adamw::multiMPUpdate); NNVM_REGISTER_OP(_multi_mp_adamw_update) -.set_attr("FCompute", adamw::multiMPUpdate); + .set_attr("FCompute", adamw::multiMPUpdate); } // namespace adamw } // namespace op diff --git a/src/operator/contrib/adaptive_avg_pooling-inl.h b/src/operator/contrib/adaptive_avg_pooling-inl.h index eedab78db0c5..e2abd1dcb324 100644 --- a/src/operator/contrib/adaptive_avg_pooling-inl.h +++ b/src/operator/contrib/adaptive_avg_pooling-inl.h @@ -21,7 +21,7 @@ * \file adaptive_avg_pooling-inl.h * \brief adaptive average pooling operator * \author Hang Zhang -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_ADAPTIVE_AVG_POOLING_INL_H_ #define MXNET_OPERATOR_CONTRIB_ADAPTIVE_AVG_POOLING_INL_H_ @@ -50,8 +50,9 @@ namespace op { struct AdaptiveAvgPoolParam : public dmlc::Parameter { mxnet::Tuple output_size; DMLC_DECLARE_PARAMETER(AdaptiveAvgPoolParam) { - DMLC_DECLARE_FIELD(output_size).set_default(mxnet::Tuple()) - .describe("int (output size) or a tuple of int for output (height, width)."); + DMLC_DECLARE_FIELD(output_size) + .set_default(mxnet::Tuple()) + .describe("int (output size) or a tuple of int for output (height, width)."); } }; @@ -59,73 +60,70 @@ static inline bool IsWriting(const OpReqType ort) { return ort == kWriteTo || ort == kWriteInplace; } -template -void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output); +template +void AdaptiveAvgPoolUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output); -template -void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output); +template +void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output); #if MXNET_USE_CUDA -template -void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output); - -template -void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output); +template +void AdaptiveAvgPoolUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output); + +template +void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output); #endif // MXNET_USE_CUDA template inline void AdaptiveAvgPoolOpForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { AdaptiveAvgPoolUpdateOutput(s, inputs, outputs); }); } - template inline void AdaptiveAvgPoolOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); if (IsWriting(req[0])) { // zero grad before backwarding - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Fill(s, outputs[0], 
kWriteTo, 0); - }) + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { Fill(s, outputs[0], kWriteTo, 0); }) } MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { AdaptiveAvgPoolUpdateGradInput(s, inputs, outputs); }); } - static bool AdaptiveAvgPoolOpInferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { using namespace mshadow; CHECK_EQ(in_shape->size(), 1U) << "Input:[data]"; CHECK_EQ(out_shape->size(), 1U) << "Output:[data]"; const AdaptiveAvgPoolParam& param = nnvm::get(attrs.parsed); mxnet::TShape dshape(in_shape->at(0)); - if (mxnet::op::shape_is_none(dshape)) return false; + if (mxnet::op::shape_is_none(dshape)) + return false; if (param.output_size.ndim() == 0) { dshape[2] = 1; dshape[3] = 1; @@ -145,11 +143,11 @@ static bool AdaptiveAvgPoolOpInferShape(const nnvm::NodeAttrs& attrs, } using namespace mshadow; -template +template MSHADOW_XINLINE int get_stride(Tensor tensor, int idx) { int stride = 1; - for (int i = Dim-2; i >= idx; --i) { - stride *= tensor.size(i+1); + for (int i = Dim - 2; i >= idx; --i) { + stride *= tensor.size(i + 1); } return stride; } diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc index 42c39cc157c6..03ca4288f9fe 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cc +++ b/src/operator/contrib/adaptive_avg_pooling.cc @@ -21,60 +21,59 @@ * \file adaptive_avg_pooling.cc * \brief adaptive average pooling operator * \author Hang Zhang -*/ + */ #include "adaptive_avg_pooling-inl.h" // #include "elemwise_op_common.h" #include "../elemwise_op_common.h" #define START_IND(a, b, c) static_cast(std::floor(static_cast(a * c) / b)) -#define END_IND(a, b, c) static_cast(std::ceil(static_cast((a + 1) * c) / b)) +#define END_IND(a, b, c) static_cast(std::ceil(static_cast((a + 1) * c) / b)) namespace mxnet { namespace op { using namespace mshadow; -template -static void SpatialAdaptiveAveragePooling_updateOutput_frame( - real *input_p, - real *output_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW, - int64_t istrideD, - int64_t istrideH, - int64_t istrideW) { +template +static void SpatialAdaptiveAveragePooling_updateOutput_frame(real* input_p, + real* output_p, + int64_t sizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW, + int64_t istrideD, + int64_t istrideH, + int64_t istrideW) { int64_t d; #pragma omp parallel for private(d) \ -num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (d = 0; d < sizeD; d++) { /* loop over output */ int64_t oh, ow, ih, iw; - int outOffset = d*osizeH*osizeW; + int outOffset = d * osizeH * osizeW; for (oh = 0; oh < osizeH; oh++) { - int istartH = START_IND(oh, osizeH, isizeH); + int istartH = START_IND(oh, osizeH, isizeH); int startOffsetH = istartH * istrideH; - int outOffsetH = oh * osizeW; - int iendH = END_IND(oh, osizeH, isizeH); - int kH = iendH - istartH; + int outOffsetH = oh * osizeW; + int iendH = END_IND(oh, osizeH, isizeH); + int kH = iendH - istartH; for (ow = 0; ow < osizeW; ow++) { int istartW = START_IND(ow, osizeW, isizeW); int iendW = END_IND(ow, osizeW, isizeW); - int kW = iendW - istartW; + int kW = iendW - istartW; /* local pointers */ - real *ip = input_p + d*istrideD + startOffsetH + istartW*istrideW; - real *op = output_p + outOffset + outOffsetH + ow; + real* ip = 
input_p + d * istrideD + startOffsetH + istartW * istrideW; + real* op = output_p + outOffset + outOffsetH + ow; /* compute local average: */ real sum = 0; for (ih = 0; ih < kH; ih++) { - int ihOffset = ih*istrideH; + int ihOffset = ih * istrideH; for (iw = 0; iw < kW; iw++) { - real val = *(ip + ihOffset + iw*istrideW); + real val = *(ip + ihOffset + iw * istrideW); sum += val; } } @@ -86,41 +85,40 @@ num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) } } -template -static void SpatialAdaptiveAveragePooling_updateGradInput_frame( - real *gradInput_p, - real *gradOutput_p, - int64_t sizeD, - int64_t isizeH, - int64_t isizeW, - int64_t osizeH, - int64_t osizeW) { +template +static void SpatialAdaptiveAveragePooling_updateGradInput_frame(real* gradInput_p, + real* gradOutput_p, + int64_t sizeD, + int64_t isizeH, + int64_t isizeW, + int64_t osizeH, + int64_t osizeW) { int64_t d; #pragma omp parallel for private(d) \ -num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (d = 0; d < sizeD; d++) { - real *gradInput_p_d = gradInput_p + d*isizeW*isizeH; - real *gradOutput_p_d = gradOutput_p + d*osizeW*osizeH; + real* gradInput_p_d = gradInput_p + d * isizeW * isizeH; + real* gradOutput_p_d = gradOutput_p + d * osizeW * osizeH; /* calculate average */ int64_t oh, ow; for (oh = 0; oh < osizeH; oh++) { int istartH = START_IND(oh, osizeH, isizeH); int iendH = END_IND(oh, osizeH, isizeH); - int kH = iendH - istartH; + int kH = iendH - istartH; for (ow = 0; ow < osizeW; ow++) { int istartW = START_IND(ow, osizeW, isizeW); int iendW = END_IND(ow, osizeW, isizeW); - int kW = iendW - istartW; + int kW = iendW - istartW; - real grad_delta = gradOutput_p_d[oh*osizeW +ow] / kH / kW; + real grad_delta = gradOutput_p_d[oh * osizeW + ow] / kH / kW; int ih, iw; for (ih = istartH; ih < iendH; ih++) { for (iw = istartW; iw < iendW; iw++) { /* update gradient */ - gradInput_p_d[ih*isizeW + iw] += grad_delta; + gradInput_p_d[ih * isizeW + iw] += grad_delta; } } } @@ -128,16 +126,15 @@ num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) } } - -template -void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output) { +template +void AdaptiveAvgPoolUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output) { Tensor itensor = input[0].get(s); Tensor otensor = output[0].get(s); - DType *input_data = itensor.dptr_; - DType *output_data = otensor.dptr_; + DType* input_data = itensor.dptr_; + DType* output_data = otensor.dptr_; int64_t sizeB = itensor.size(0); int64_t sizeD = itensor.size(1); @@ -154,28 +151,31 @@ void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, int64_t b; #pragma omp parallel for private(b) \ -num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (b = 0; b < sizeB; b++) { SpatialAdaptiveAveragePooling_updateOutput_frame( - input_data+b*istrideB, output_data+b*sizeD*osizeH*osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW, - istrideD, - istrideH, istrideW); + input_data + b * istrideB, + output_data + b * sizeD * osizeH * osizeW, + sizeD, + isizeH, + isizeW, + osizeH, + osizeW, + istrideD, + istrideH, + istrideW); } } - -template -void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output) { +template +void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream* s, + const 
std::vector& input, + const std::vector& output) { Tensor gradOut = input[0].get(s); - Tensor gradIn = output[0].get(s); + Tensor gradIn = output[0].get(s); - DType *gradOutput_data = gradOut.dptr_; - DType *gradInput_data = gradIn.dptr_; + DType* gradOutput_data = gradOut.dptr_; + DType* gradInput_data = gradIn.dptr_; int64_t sizeB = gradIn.size(0); int64_t sizeD = gradIn.size(1); @@ -187,21 +187,23 @@ void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, int64_t b; #pragma omp parallel for private(b) \ -num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) + num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (b = 0; b < sizeB; b++) { SpatialAdaptiveAveragePooling_updateGradInput_frame( - gradInput_data+b*sizeD*isizeH*isizeW, gradOutput_data+b*sizeD*osizeH*osizeW, - sizeD, - isizeH, isizeW, - osizeH, osizeW); + gradInput_data + b * sizeD * isizeH * isizeW, + gradOutput_data + b * sizeD * osizeH * osizeW, + sizeD, + isizeH, + isizeW, + osizeH, + osizeW); } } - DMLC_REGISTER_PARAMETER(AdaptiveAvgPoolParam); NNVM_REGISTER_OP(_contrib_AdaptiveAvgPooling2D) -.describe(R"code( + .describe(R"code( Applies a 2D adaptive average pooling over a 4D input with the shape of (NCHW). The pooling kernel and stride sizes are automatically chosen for desired output sizes. @@ -212,23 +214,22 @@ The pooling kernel and stride sizes are automatically chosen for desired output (N x C x height x width) for any input (NCHW). )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FInferShape", AdaptiveAvgPoolOpInferShape) -.set_attr("FCompute", AdaptiveAvgPoolOpForward) -.set_attr("FGradient", - ElemwiseGradUseNone{"_backward_contrib_AdaptiveAvgPooling2D"}) -.add_argument("data", "NDArray-or-Symbol", "Input data") -.add_arguments(AdaptiveAvgPoolParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("FInferShape", AdaptiveAvgPoolOpInferShape) + .set_attr("FCompute", AdaptiveAvgPoolOpForward) + .set_attr("FGradient", + ElemwiseGradUseNone{"_backward_contrib_AdaptiveAvgPooling2D"}) + .add_argument("data", "NDArray-or-Symbol", "Input data") + .add_arguments(AdaptiveAvgPoolParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_AdaptiveAvgPooling2D) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("TIsBackward", true) -.set_attr("FCompute", AdaptiveAvgPoolOpBackward); - + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("TIsBackward", true) + .set_attr("FCompute", AdaptiveAvgPoolOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/adaptive_avg_pooling.cu b/src/operator/contrib/adaptive_avg_pooling.cu index 375c420a0440..3dd37fa91fbb 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cu +++ b/src/operator/contrib/adaptive_avg_pooling.cu @@ -21,23 +21,25 @@ * \file adaptive_avg_pooling.cu * \brief adaptive average pooling operator * \author Hang Zhang -*/ + */ #include #include #include "adaptive_avg_pooling-inl.h" #define START_IND(a, b, c) static_cast(floor(static_cast(a * c) / b)) -#define END_IND(a, b, c) static_cast(ceil(static_cast((a + 1) * c) / b)) -#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit +#define END_IND(a, b, c) static_cast(ceil(static_cast((a + 1) * c) / b)) +#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit namespace mxnet { namespace op { using namespace mshadow; -template +template struct ScalarConvert { - 
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; } + static __host__ __device__ __forceinline__ Out to(const In v) { + return (Out)v; + } }; /* @@ -46,10 +48,15 @@ struct ScalarConvert { * 4D input, 4D output */ template -__global__ void adaptiveaveragepool(T *input, T *output, - int isizeH, int isizeW, - int osizeH, int osizeW, - int64_t istrideD, int64_t istrideH, int64_t istrideW) { +__global__ void adaptiveaveragepool(T* input, + T* output, + int isizeH, + int isizeW, + int osizeH, + int osizeW, + int64_t istrideD, + int64_t istrideH, + int64_t istrideW) { // iterators on output pixels int oh, ow; @@ -57,36 +64,36 @@ __global__ void adaptiveaveragepool(T *input, T *output, int o_plane = blockIdx.x; int i_plane = o_plane; - output = output + o_plane*osizeH*osizeW; - input = input + i_plane*istrideD; + output = output + o_plane * osizeH * osizeW; + input = input + i_plane * istrideD; - int ostartH = blockDim.y*blockIdx.y + threadIdx.y; - int oendH = osizeH; - const int ostepH = blockDim.y*gridDim.y; + int ostartH = blockDim.y * blockIdx.y + threadIdx.y; + int oendH = osizeH; + const int ostepH = blockDim.y * gridDim.y; - int ostartW = threadIdx.x; - int oendW = osizeW; + int ostartW = threadIdx.x; + int oendW = osizeW; const int ostepW = blockDim.x; // For all output pixels... for (oh = ostartH; oh < oendH; oh += ostepH) { int istartH = START_IND(oh, osizeH, isizeH); int iendH = END_IND(oh, osizeH, isizeH); - int kH = iendH - istartH; + int kH = iendH - istartH; for (ow = ostartW; ow < oendW; ow += ostepW) { int istartW = START_IND(ow, osizeW, isizeW); int iendW = END_IND(ow, osizeW, isizeW); - int kW = iendW - istartW; + int kW = iendW - istartW; // Compute the average pooling over corresponding input pixels - T *ptr_input = input + istartH*istrideH + istartW*istrideW; - T *ptr_output = output + oh*osizeW + ow; - T sum = ScalarConvert::to(0); + T* ptr_input = input + istartH * istrideH + istartW * istrideW; + T* ptr_output = output + oh * osizeW + ow; + T sum = ScalarConvert::to(0); int ih, iw; for (ih = 0; ih < kH; ++ih) { for (iw = 0; iw < kW; ++iw) { - T val = ptr_input[iw*istrideW]; + T val = ptr_input[iw * istrideW]; sum += val; } ptr_input += istrideH; // next input line @@ -103,10 +110,12 @@ __global__ void adaptiveaveragepool(T *input, T *output, * (uses atomic add) */ template -__global__ void atomicadaptiveaveragegradinput( - T *gradInput, T *gradOutput, - int isizeH, int isizeW, int osizeH, int osizeW -) { +__global__ void atomicadaptiveaveragegradinput(T* gradInput, + T* gradOutput, + int isizeH, + int isizeW, + int osizeH, + int osizeW) { // iterators on output indices int oh, ow; @@ -114,32 +123,32 @@ __global__ void atomicadaptiveaveragegradinput( int o_plane = blockIdx.x; int i_plane = o_plane; - gradOutput = gradOutput + o_plane*osizeW*osizeH; - gradInput = gradInput + i_plane*isizeW*isizeH; + gradOutput = gradOutput + o_plane * osizeW * osizeH; + gradInput = gradInput + i_plane * isizeW * isizeH; - int ostartH = blockDim.y*blockIdx.y + threadIdx.y; - int oendH = osizeH; - int ostepH = blockDim.y*gridDim.y; + int ostartH = blockDim.y * blockIdx.y + threadIdx.y; + int oendH = osizeH; + int ostepH = blockDim.y * gridDim.y; int ostartW = threadIdx.x; - int oendW = osizeW; - int ostepW = blockDim.x; + int oendW = osizeW; + int ostepW = blockDim.x; // For all output pixels... 
for (oh = ostartH; oh < oendH; oh += ostepH) { int istartH = START_IND(oh, osizeH, isizeH); int iendH = END_IND(oh, osizeH, isizeH); - int kH = iendH - istartH; + int kH = iendH - istartH; for (ow = ostartW; ow < oendW; ow += ostepW) { int istartW = START_IND(ow, osizeW, isizeW); int iendW = END_IND(ow, osizeW, isizeW); - int kW = iendW - istartW; + int kW = iendW - istartW; // Compute the gradients for over corresponding input pixels - T *ptr_gradInput = gradInput + istartH*isizeW + istartW; - T *ptr_gradOutput = gradOutput + oh*osizeW + ow; - T grad_delta = *ptr_gradOutput / kW / kH; + T* ptr_gradInput = gradInput + istartH * isizeW + istartW; + T* ptr_gradOutput = gradOutput + oh * osizeW + ow; + T grad_delta = *ptr_gradOutput / kW / kH; int ih, iw; for (ih = 0; ih < kH; ++ih) { @@ -153,16 +162,15 @@ __global__ void atomicadaptiveaveragegradinput( } } - -template -void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output) { +template +void AdaptiveAvgPoolUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output) { Tensor itensor = input[0].get(s); Tensor otensor = output[0].get(s); - DType *input_data = itensor.dptr_; - DType *output_data = otensor.dptr_; + DType* input_data = itensor.dptr_; + DType* output_data = otensor.dptr_; int64_t sizeB = itensor.size(0); int64_t sizeD = itensor.size(1); @@ -183,21 +191,20 @@ void AdaptiveAvgPoolUpdateOutput(mshadow::Stream *s, cudaStream_t stream = mshadow::Stream::GetStream(s); // run averagepool kernel - adaptiveaveragepool <<>> ( - input_data, output_data, isizeH, isizeW, osizeH, osizeW, - istrideD, istrideH, istrideW); + adaptiveaveragepool<<>>( + input_data, output_data, isizeH, isizeW, osizeH, osizeW, istrideD, istrideH, istrideW); MSHADOW_CUDA_POST_KERNEL_CHECK(AdaptiveAvgPoolUpdateOutput); } -template -void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output) { +template +void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output) { Tensor gradOut = input[0].get(s); - Tensor gradIn = output[0].get(s); + Tensor gradIn = output[0].get(s); - DType *gradOutput_data = gradOut.dptr_; - DType *gradInput_data = gradIn.dptr_; + DType* gradOutput_data = gradOut.dptr_; + DType* gradInput_data = gradIn.dptr_; int64_t sizeB = gradIn.size(0); int64_t sizeD = gradIn.size(1); @@ -214,16 +221,16 @@ void AdaptiveAvgPoolUpdateGradInput(mshadow::Stream *s, cudaStream_t stream = mshadow::Stream::GetStream(s); // run updateGradInput kernel, accumulate gradients atomically - atomicadaptiveaveragegradinput <<>> ( - gradInput_data, gradOutput_data, isizeH, isizeW, osizeH, osizeW); + atomicadaptiveaveragegradinput<<>>( + gradInput_data, gradOutput_data, isizeH, isizeW, osizeH, osizeW); MSHADOW_CUDA_POST_KERNEL_CHECK(AdaptiveAvgPoolUpdateGradInput); } NNVM_REGISTER_OP(_contrib_AdaptiveAvgPooling2D) -.set_attr("FCompute", AdaptiveAvgPoolOpForward); + .set_attr("FCompute", AdaptiveAvgPoolOpForward); NNVM_REGISTER_OP(_backward_contrib_AdaptiveAvgPooling2D) -.set_attr("FCompute", AdaptiveAvgPoolOpBackward); + .set_attr("FCompute", AdaptiveAvgPoolOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/allclose_op-inl.h b/src/operator/contrib/allclose_op-inl.h index c54b2630924c..a3e3a8ff98ba 100644 --- a/src/operator/contrib/allclose_op-inl.h +++ b/src/operator/contrib/allclose_op-inl.h @@ -37,29 +37,24 @@ namespace mxnet { namespace op { // Intermediate 
and Output data types could be integers OR unsigned characters -#define USE_INTEGER 0 +#define USE_INTEGER 0 #if USE_INTEGER - #define INTERM_DATA_TYPE int32_t - #define OUT_DATA_TYPE mshadow::kInt32 +#define INTERM_DATA_TYPE int32_t +#define OUT_DATA_TYPE mshadow::kInt32 #else - #define INTERM_DATA_TYPE uint8_t - #define OUT_DATA_TYPE mshadow::kUint8 +#define INTERM_DATA_TYPE uint8_t +#define OUT_DATA_TYPE mshadow::kUint8 #endif struct AllCloseParam : public dmlc::Parameter { float rtol, atol; bool equal_nan; DMLC_DECLARE_PARAMETER(AllCloseParam) { - DMLC_DECLARE_FIELD(rtol) - .set_default(1e-05) - .describe("Relative tolerance."); - DMLC_DECLARE_FIELD(atol) - .set_default(1e-08) - .describe("Absolute tolerance."); - DMLC_DECLARE_FIELD(equal_nan) - .set_default(true) - .describe("Whether to compare NaN's as equal. If True, NaN's in A will be considered equal " - "to NaN's in B in the output array."); + DMLC_DECLARE_FIELD(rtol).set_default(1e-05).describe("Relative tolerance."); + DMLC_DECLARE_FIELD(atol).set_default(1e-08).describe("Absolute tolerance."); + DMLC_DECLARE_FIELD(equal_nan).set_default(true).describe( + "Whether to compare NaN's as equal. If True, NaN's in A will be considered equal " + "to NaN's in B in the output array."); } }; @@ -74,8 +69,8 @@ inline bool AllCloseShape(const nnvm::NodeAttrs& attrs, } inline bool AllCloseType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); @@ -86,44 +81,51 @@ inline bool AllCloseType(const nnvm::NodeAttrs& attrs, using mshadow::isnan_typed::IsNan; -template +template struct allclose_forward { - template - MSHADOW_XINLINE static void Map(int i, INTERM_DATA_TYPE *out_data, - const DType *in_a, const DType *in_b, - const float rtol, const float atol, bool equal_nan) { - const DType a = in_a[i], b = in_b[i]; - bool val; - if (IsNan(a) || IsNan(b)) - val = equal_nan && IsNan(a) == IsNan(b); - else - val = a == b || (a > b? a - b : b - a) <= atol + (b > 0? rtol * b : (-rtol) * b); - - KERNEL_ASSIGN(out_data[i], req, val? 1 : 0); + template + MSHADOW_XINLINE static void Map(int i, + INTERM_DATA_TYPE* out_data, + const DType* in_a, + const DType* in_b, + const float rtol, + const float atol, + bool equal_nan) { + const DType a = in_a[i], b = in_b[i]; + bool val; + if (IsNan(a) || IsNan(b)) + val = equal_nan && IsNan(a) == IsNan(b); + else + val = a == b || (a > b ? a - b : b - a) <= atol + (b > 0 ? rtol * b : (-rtol) * b); + + KERNEL_ASSIGN(out_data[i], req, val ? 1 : 0); } }; -template -size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items); +template +size_t GetAdditionalMemoryLogical(mshadow::Stream* s, const int num_items); -template -INTERM_DATA_TYPE *GetAdditionalMemoryLogical(const OpContext& ctx, - int num_items, size_t *pExtraStorageBytes) { -// Get length of the additional memory (which is used only by DeviceReduce::Min(...) on gpu) +template +INTERM_DATA_TYPE* GetAdditionalMemoryLogical(const OpContext& ctx, + int num_items, + size_t* pExtraStorageBytes) { + // Get length of the additional memory (which is used only by DeviceReduce::Min(...) 
on gpu) *pExtraStorageBytes = GetAdditionalMemoryLogical(ctx.get_stream(), num_items); const size_t workspace_total_bytes_ = num_items * sizeof(INTERM_DATA_TYPE) + *pExtraStorageBytes; - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed( + mshadow::Tensor workspace = ctx.requested[0].get_space_typed( mshadow::Shape1(workspace_total_bytes_), ctx.get_stream()); - return reinterpret_cast(workspace.dptr_); + return reinterpret_cast(workspace.dptr_); } -template -void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, size_t extraStorageBytes, - int num_items, INTERM_DATA_TYPE *outPntr); +template +void GetResultLogical(mshadow::Stream* s, + INTERM_DATA_TYPE* workMem, + size_t extraStorageBytes, + int num_items, + INTERM_DATA_TYPE* outPntr); -template +template void AllClose(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -133,24 +135,29 @@ void AllClose(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const TBlob& in0 = inputs[0]; - const TBlob& in1 = inputs[1]; + const TBlob& in0 = inputs[0]; + const TBlob& in1 = inputs[1]; const int num_items = in0.Size(); size_t extraStorageBytes; - auto workspaceMem = GetAdditionalMemoryLogical(ctx, num_items, &extraStorageBytes); - auto s = ctx.get_stream(); + auto workspaceMem = GetAdditionalMemoryLogical(ctx, num_items, &extraStorageBytes); + auto s = ctx.get_stream(); const AllCloseParam& param = nnvm::get(attrs.parsed); using namespace mxnet_op; MSHADOW_TYPE_SWITCH(in0.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, xpu>::Launch( - s, num_items, workspaceMem, in0.dptr(), in1.dptr(), - param.rtol, param.atol, param.equal_nan); + Kernel, xpu>::Launch(s, + num_items, + workspaceMem, + in0.dptr(), + in1.dptr(), + param.rtol, + param.atol, + param.equal_nan); }); }); - auto *pOut = outputs[0].dptr(); + auto* pOut = outputs[0].dptr(); GetResultLogical(s, workspaceMem, extraStorageBytes, num_items, pOut); } diff --git a/src/operator/contrib/allclose_op.cc b/src/operator/contrib/allclose_op.cc index 0f69aba523ae..e3eeb4cd1745 100644 --- a/src/operator/contrib/allclose_op.cc +++ b/src/operator/contrib/allclose_op.cc @@ -30,7 +30,8 @@ namespace op { DMLC_REGISTER_PARAMETER(AllCloseParam); NNVM_REGISTER_OP(_contrib_allclose) -.describe(R"code(This operators implements the numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) + .describe( + R"code(This operators implements the numpy.allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False) .. 
math:: @@ -53,33 +54,38 @@ Examples:: y = True )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"a", "b"}; - }) -.set_attr("FInferShape", AllCloseShape) -.set_attr("FInferType", AllCloseType) -.set_attr("FCompute", AllClose) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) -.add_argument("a", "NDArray-or-Symbol", "Input array a") -.add_argument("b", "NDArray-or-Symbol", "Input array b") -.add_arguments(AllCloseParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) + .set_attr("FInferShape", AllCloseShape) + .set_attr("FInferType", AllCloseType) + .set_attr("FCompute", AllClose) + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .add_argument("a", "NDArray-or-Symbol", "Input array a") + .add_argument("b", "NDArray-or-Symbol", "Input array b") + .add_arguments(AllCloseParam::__FIELDS__()); -template<> -size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items) { +template <> +size_t GetAdditionalMemoryLogical(mshadow::Stream* s, const int num_items) { return 0; } -template<> -void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, - size_t extraStorageBytes, int num_items, INTERM_DATA_TYPE *outPntr) { - while (num_items-- > 0 && workMem[num_items]) {} - outPntr[0] = num_items >= 0? 0 : 1; +template <> +void GetResultLogical(mshadow::Stream* s, + INTERM_DATA_TYPE* workMem, + size_t extraStorageBytes, + int num_items, + INTERM_DATA_TYPE* outPntr) { + while (num_items-- > 0 && workMem[num_items]) { + } + outPntr[0] = num_items >= 0 ? 
0 : 1; } } // namespace op diff --git a/src/operator/contrib/allclose_op.cu b/src/operator/contrib/allclose_op.cu index f923ab060813..0bb40770d10a 100644 --- a/src/operator/contrib/allclose_op.cu +++ b/src/operator/contrib/allclose_op.cu @@ -28,31 +28,33 @@ namespace mxnet { namespace op { -template -size_t GetAdditionalMemory(mshadow::Stream *s, const int num_items) { - T *d_in = nullptr; - T *d_out = nullptr; +template +size_t GetAdditionalMemory(mshadow::Stream* s, const int num_items) { + T* d_in = nullptr; + T* d_out = nullptr; size_t temp_storage_bytes = 0; - cudaStream_t stream = mshadow::Stream::GetStream(s); + cudaStream_t stream = mshadow::Stream::GetStream(s); cub::DeviceReduce::Min(nullptr, temp_storage_bytes, d_in, d_out, num_items, stream); return temp_storage_bytes; } -template<> -size_t GetAdditionalMemoryLogical(mshadow::Stream *s, const int num_items) { +template <> +size_t GetAdditionalMemoryLogical(mshadow::Stream* s, const int num_items) { return GetAdditionalMemory(s, num_items); } -template<> -void GetResultLogical(mshadow::Stream *s, INTERM_DATA_TYPE *workMem, - size_t extraStorageBytes, int num_items, INTERM_DATA_TYPE *outPntr) { +template <> +void GetResultLogical(mshadow::Stream* s, + INTERM_DATA_TYPE* workMem, + size_t extraStorageBytes, + int num_items, + INTERM_DATA_TYPE* outPntr) { cudaStream_t stream = mshadow::Stream::GetStream(s); - cub::DeviceReduce::Min(workMem + num_items, extraStorageBytes, - workMem, outPntr, num_items, stream); + cub::DeviceReduce::Min( + workMem + num_items, extraStorageBytes, workMem, outPntr, num_items, stream); } -NNVM_REGISTER_OP(_contrib_allclose) -.set_attr("FCompute", AllClose); +NNVM_REGISTER_OP(_contrib_allclose).set_attr("FCompute", AllClose); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index c35f6c9c9ad8..51193c6433e0 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -22,7 +22,7 @@ * \file batch_norm_relu.cc * \brief * \author Xinyu Chen -*/ + */ #include "../nn/batch_norm-inl.h" #include @@ -37,39 +37,43 @@ namespace op { namespace batchnormrelu { -enum BatchNormWithReLUOpInputs {kData, kGamma, kBeta, kInMovingMean, - kInMovingVar}; // kGamma: weights, kBeta: biases -enum BatchNormWithReLUOpOutputs {kOut, kMean, kVar, kWorkspace}; // req, out_data -enum BatchNormWithReLUOpResource {kTempSpace}; -enum BatchNormWithReLUOpAuxiliary {kMovingMean, kMovingVar}; // aux_states +enum BatchNormWithReLUOpInputs { + kData, + kGamma, + kBeta, + kInMovingMean, + kInMovingVar +}; // kGamma: weights, kBeta: biases +enum BatchNormWithReLUOpOutputs { kOut, kMean, kVar, kWorkspace }; // req, out_data +enum BatchNormWithReLUOpResource { kTempSpace }; +enum BatchNormWithReLUOpAuxiliary { kMovingMean, kMovingVar }; // aux_states /*! 
\brief Default channel axis if none specified in the params */ constexpr int DEFAULT_AXIS = 1; } // namespace batchnormrelu static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { const BatchNormParam& param = nnvm::get(attrs.parsed); using namespace mshadow; CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]"; CHECK_EQ(out_shape->size(), 4U); - const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData); + const mxnet::TShape& dshape = in_shape->at(batchnormrelu::kData); if (!mxnet::ndim_is_known(dshape)) { return false; } - const size_t channelAxis = static_cast(param.axis < 0 - ? static_cast(dshape.ndim()) + param.axis - : param.axis); + const size_t channelAxis = static_cast( + param.axis < 0 ? static_cast(dshape.ndim()) + param.axis : param.axis); CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; const int channelCount = dshape[channelAxis]; - in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount)); - in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount)); + in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount)); + in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean - in_shape->at(batchnormrelu::kInMovingVar) = mxnet::TShape(Shape1(channelCount)); // kMovingVar + in_shape->at(batchnormrelu::kInMovingVar) = mxnet::TShape(Shape1(channelCount)); // kMovingVar out_shape->clear(); out_shape->push_back(dshape); // kOut @@ -80,7 +84,8 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs, } static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, - std::vector *in_type, std::vector *out_type) { + std::vector* in_type, + std::vector* out_type) { using namespace mshadow; CHECK_GE(in_type->size(), 1U); const size_t n_out = 4; @@ -93,20 +98,20 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, if (type_is_none(dtype)) { // Input type is undefined, we try backward inference if (out_type->size() == 0 || type_is_none((*out_type)[0])) { - // Neither the input nor the output are defined, - // types cannot be infered for this op - return false; + // Neither the input nor the output are defined, + // types cannot be infered for this op + return false; } else { - // Input type is undefined but output type is: backward inference - dtype = (*out_type)[0]; - (*in_type)[0] = dtype; - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { - dtype_param = mshadow::DataType::kFlag; }); + // Input type is undefined but output type is: backward inference + dtype = (*out_type)[0]; + (*in_type)[0] = dtype; + MSHADOW_REAL_TYPE_SWITCH_EX( + dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); } } else { // Input type is defined but output type is not: forward inference - MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, { - dtype_param = mshadow::DataType::kFlag; }); + MSHADOW_REAL_TYPE_SWITCH_EX( + dtype, DTypeX, AccRealX, { dtype_param = mshadow::DataType::kFlag; }); out_type->clear(); out_type->push_back(dtype); for (size_t i = 1; i < n_out; ++i) { @@ -126,25 +131,26 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs, } #if MXNET_USE_ONEDNN == 1 -static inline bool SupportMKLDNNBNReLU(const NDArray &input, const BatchNormParam ¶m) { - if 
(mxnet::op::batchnorm::disable_mkl) return false; +static inline bool SupportMKLDNNBNReLU(const NDArray& input, const BatchNormParam& param) { + if (mxnet::op::batchnorm::disable_mkl) + return false; const mxnet::TShape shape = input.shape(); - const int ndim = shape.ndim(); - if (ndim == 0 || shape.Size() == 0) return false; + const int ndim = shape.ndim(); + if (ndim == 0 || shape.Size() == 0) + return false; const int dtype = input.dtype(); - return (dtype == mshadow::kFloat32 || - dtype == mshadow::kBfloat16) && - SupportStorageMKLDNN(input.storage_type()); + return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) && + SupportStorageMKLDNN(input.storage_type()); } -void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 5U); - const BatchNormParam ¶m = nnvm::get(attrs.parsed); - bool fuse_relu = true; + const BatchNormParam& param = nnvm::get(attrs.parsed); + bool fuse_relu = true; if (SupportMKLDNNBNReLU(inputs[0], param)) { CHECK_GT(outputs.size(), 3U); MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs); @@ -156,45 +162,45 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs, LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend."; } -void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - const BatchNormParam ¶m = nnvm::get(attrs.parsed); - bool fuse_relu = true; +void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const BatchNormParam& param = nnvm::get(attrs.parsed); + bool fuse_relu = true; if (SupportMKLDNNBNReLU(inputs[0], param)) { - CHECK_EQ(inputs.size(), 9U); - MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - MKLDNNBatchNormBackward(attrs, ctx, inputs, req, outputs, fuse_relu); - return; + CHECK_EQ(inputs.size(), 9U); + MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs); + MKLDNNBatchNormBackward(attrs, ctx, inputs, req, outputs, fuse_relu); + return; } LOG(FATAL) << "BatchNormWithReLU operator only supports MKL-DNN Backend."; } #endif -static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs &attrs, +static inline bool BatchNormWithReLUStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, - DispatchMode *dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const BatchNormParam ¶m = nnvm::get(attrs.parsed); + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const BatchNormParam& param = nnvm::get(attrs.parsed); bool dispatched = false; #if MXNET_USE_ONEDNN == 1 if (!dispatched) { - dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, - in_attrs, out_attrs); + dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); } if (!MKLDNNEnvSet()) { *dispatch_mode = DispatchMode::kFComputeFallback; } #else for (int& v : *in_attrs) - if (v == - 1) v = kDefaultStorage; + if (v == -1) + v = kDefaultStorage; if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { - dispatched = storage_type_assign(out_attrs, 
kDefaultStorage, - dispatch_mode, DispatchMode::kFCompute); + dispatched = + storage_type_assign(out_attrs, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); @@ -225,10 +231,10 @@ std::vector BatchNormWithReLUGrad(const nnvm::ObjectPtr& n, heads.emplace_back(out_data.at(batchnormrelu::kWorkspace)); nnvm::ObjectPtr gnode = nnvm::Node::Create(); - gnode->inputs = std::move(heads); + gnode->inputs = std::move(heads); gnode->control_deps.emplace_back(n); - gnode->attrs = n->attrs; - gnode->attrs.op = nnvm::Op::Get("_backward_contrib_BatchNormWithReLU"); + gnode->attrs = n->attrs; + gnode->attrs.op = nnvm::Op::Get("_backward_contrib_BatchNormWithReLU"); gnode->attrs.name = n->attrs.name + "_backward"; // The input of batchnorm std::vector in_grad; @@ -237,8 +243,8 @@ std::vector BatchNormWithReLUGrad(const nnvm::ObjectPtr& n, in_grad.emplace_back(gnode, i, 0); // attach no gradient node to forbid gradient on aux_state nnvm::ObjectPtr ng = nnvm::Node::Create(); - ng->attrs.op = Op::Get("_NoGradient"); - ng->attrs.name = "NoGradient"; + ng->attrs.op = Op::Get("_NoGradient"); + ng->attrs.name = "NoGradient"; // the aux state of batchnorm for (size_t i = 3; i < 5; ++i) in_grad.emplace_back(ng); @@ -246,74 +252,81 @@ std::vector BatchNormWithReLUGrad(const nnvm::ObjectPtr& n, } NNVM_REGISTER_OP(_contrib_BatchNormWithReLU) -.add_alias("_npx_batch_norm_with_relu") -.describe(R"code(Batch normalization with ReLU fusion. + .add_alias("_npx_batch_norm_with_relu") + .describe(R"code(Batch normalization with ReLU fusion. An extented operator of Batch normalization which can fuse ReLU activation. )code" ADD_FILELINE) -.set_num_inputs(5) -.set_num_outputs(4) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "gamma", "beta", "moving_mean", "moving_var"}; -}) -.set_attr("FListOutputNames", - [](const NodeAttrs& attrs) { - return std::vector{"output", "mean", "var", "workspace"}; -}) -.set_attr("FNumVisibleOutputs", - [](const NodeAttrs& attrs) { - const BatchNormParam& param = nnvm::get(attrs.parsed); - return param.output_mean_var ? 3 : 1; -}) -.set_attr("FMutateInputs", [](const nnvm::NodeAttrs& attrs) { - return std::vector{3, 4}; -}) -.set_attr("FInferShape", BatchNormWithReLUShape) -.set_attr("FInferType", BatchNormWithReLUType) -.set_attr("FInferStorageType", BatchNormWithReLUStorageType) + .set_num_inputs(5) + .set_num_outputs(4) + .set_attr_parser(ParamParser) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "gamma", "beta", "moving_mean", "moving_var"}; + }) + .set_attr( + "FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "mean", "var", "workspace"}; + }) + .set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + const BatchNormParam& param = + nnvm::get(attrs.parsed); + return param.output_mean_var ? 
3 : 1; + }) + .set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{3, 4}; + }) + .set_attr("FInferShape", BatchNormWithReLUShape) + .set_attr("FInferType", BatchNormWithReLUType) + .set_attr("FInferStorageType", BatchNormWithReLUStorageType) #if MXNET_USE_ONEDNN == 1 -.set_attr("FComputeEx", BatchNormWithReLUComputeExCPU) + .set_attr("FComputeEx", BatchNormWithReLUComputeExCPU) #endif -.set_attr("FGradient", BatchNormWithReLUGrad) + .set_attr("FGradient", BatchNormWithReLUGrad) #if MXNET_USE_ONEDNN == 1 -.set_attr("TIsMKLDNN", true) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) + .set_attr("TIsMKLDNN", true) + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; + }) #endif -.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") -.add_argument("gamma", "NDArray-or-Symbol", "gamma array") -.add_argument("beta", "NDArray-or-Symbol", "beta array") -.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input") -.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input") -.add_arguments(BatchNormParam::__FIELDS__()) -.set_attr( - "FSetInputVarAttrOnCompose", - [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) { - if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return; - if (index == 3) { - var->attrs.dict["__init__"] = "[\"zero\", {}]"; - } else if (index == 4) { - var->attrs.dict["__init__"] = "[\"one\", {}]"; - } - }); + .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") + .add_argument("gamma", "NDArray-or-Symbol", "gamma array") + .add_argument("beta", "NDArray-or-Symbol", "beta array") + .add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input") + .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input") + .add_arguments(BatchNormParam::__FIELDS__()) + .set_attr( + "FSetInputVarAttrOnCompose", + [](const nnvm::NodeAttrs& attrs, nnvm::ObjectPtr var, const int index) { + if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) + return; + if (index == 3) { + var->attrs.dict["__init__"] = "[\"zero\", {}]"; + } else if (index == 4) { + var->attrs.dict["__init__"] = "[\"one\", {}]"; + } + }); NNVM_REGISTER_OP(_backward_contrib_BatchNormWithReLU) -.set_num_inputs(9) -.set_num_outputs(3) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", BatchNormWithReLUStorageType) + .set_num_inputs(9) + .set_num_outputs(3) + .set_attr("TIsBackward", true) + .set_attr("FInferStorageType", BatchNormWithReLUStorageType) #if MXNET_USE_ONEDNN == 1 -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) -.set_attr("TIsMKLDNN", true) -.set_attr("FComputeEx", BatchNormWithReLUGradComputeExCPU) + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("TIsMKLDNN", true) + .set_attr("FComputeEx", BatchNormWithReLUGradComputeExCPU) #endif -.set_attr_parser(ParamParser); + .set_attr_parser(ParamParser); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h index 6032c881ed46..580d63868f7f 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -21,7 +21,7 @@ * \file bilinear_resize-inl.h * \brief bilinear resize operator * \author Hang Zhang 
-*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_ #define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_ @@ -45,11 +45,17 @@ #include "../mshadow_op.h" namespace bilinear_resize { -enum BilinearResizeOpMode{simple, odd_scale, like, to_even_down, to_even_up, to_odd_down, - to_odd_up}; +enum BilinearResizeOpMode { + simple, + odd_scale, + like, + to_even_down, + to_even_up, + to_odd_down, + to_odd_up +}; } // namespace bilinear_resize - namespace mxnet { namespace op { @@ -61,54 +67,60 @@ struct BilinearSampleParam : public dmlc::Parameter { int mode; bool align_corners; DMLC_DECLARE_PARAMETER(BilinearSampleParam) { - DMLC_DECLARE_FIELD(height).set_default(1).set_lower_bound(1) - .describe("output height (required, but ignored if scale_height is defined or mode is not " - "\"size\")"); - DMLC_DECLARE_FIELD(width).set_default(1).set_lower_bound(1) - .describe("output width (required, but ignored if scale_width is defined or mode is not " - "\"size\")"); - DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional()) - .describe("sampling scale of the height (optional, used in modes \"scale\" and \"odd_scale\")"); - DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional()) - .describe("sampling scale of the width (optional, used in modes \"scale\" and \"odd_scale\")"); + DMLC_DECLARE_FIELD(height).set_default(1).set_lower_bound(1).describe( + "output height (required, but ignored if scale_height is defined or mode is not " + "\"size\")"); + DMLC_DECLARE_FIELD(width).set_default(1).set_lower_bound(1).describe( + "output width (required, but ignored if scale_width is defined or mode is not " + "\"size\")"); + DMLC_DECLARE_FIELD(scale_height) + .set_default(dmlc::optional()) + .describe( + "sampling scale of the height (optional, used in modes \"scale\" and \"odd_scale\")"); + DMLC_DECLARE_FIELD(scale_width) + .set_default(dmlc::optional()) + .describe( + "sampling scale of the width (optional, used in modes \"scale\" and \"odd_scale\")"); DMLC_DECLARE_FIELD(mode) - .add_enum("size", bilinear_resize::simple) - .add_enum("odd_scale", bilinear_resize::odd_scale) - .add_enum("like", bilinear_resize::like) - .add_enum("to_even_down", bilinear_resize::to_even_down) - .add_enum("to_even_up", bilinear_resize::to_even_up) - .add_enum("to_odd_down", bilinear_resize::to_odd_down) - .add_enum("to_odd_up", bilinear_resize::to_odd_up) - .set_default(bilinear_resize::simple) - .describe("resizing mode. \"simple\" - output height equals parameter \"height\" if " - "\"scale_height\" parameter is not defined or input height multiplied by " - "\"scale_height\" otherwise. 
Same for width;" - "\"odd_scale\" - if original height or width is odd, then result height is " - "calculated like result_h = (original_h - 1) * scale + 1; " - "for scale > 1 the result shape would be like if we did deconvolution with kernel " - "= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape " - "would be like we did convolution with kernel = (1, 1) and " - "stride = (int(1 / height_scale), int( 1/ width_scale);" - "\"like\" - resize first input to the height and width of second input; " - "\"to_even_down\" - resize input to nearest lower even height and width " - "(if original height is odd then result height = original height - 1);" - "\"to_even_up\" - resize input to nearest bigger even height and width " - "(if original height is odd then result height = original height + 1);" - "\"to_odd_down\" - resize input to nearest odd height and width " - "(if original height is odd then result height = original height - 1);" - "\"to_odd_up\" - resize input to nearest odd height and width " - "(if original height is odd then result height = original height + 1);"); - DMLC_DECLARE_FIELD(align_corners).set_default(true) - .describe("With align_corners = True, the interpolating doesn't proportionally align the" - "output and input pixels, and thus the output values can depend on the input size."); + .add_enum("size", bilinear_resize::simple) + .add_enum("odd_scale", bilinear_resize::odd_scale) + .add_enum("like", bilinear_resize::like) + .add_enum("to_even_down", bilinear_resize::to_even_down) + .add_enum("to_even_up", bilinear_resize::to_even_up) + .add_enum("to_odd_down", bilinear_resize::to_odd_down) + .add_enum("to_odd_up", bilinear_resize::to_odd_up) + .set_default(bilinear_resize::simple) + .describe( + "resizing mode. \"simple\" - output height equals parameter \"height\" if " + "\"scale_height\" parameter is not defined or input height multiplied by " + "\"scale_height\" otherwise. 
Same for width;" + "\"odd_scale\" - if original height or width is odd, then result height is " + "calculated like result_h = (original_h - 1) * scale + 1; " + "for scale > 1 the result shape would be like if we did deconvolution with kernel " + "= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape " + "would be like we did convolution with kernel = (1, 1) and " + "stride = (int(1 / height_scale), int( 1/ width_scale);" + "\"like\" - resize first input to the height and width of second input; " + "\"to_even_down\" - resize input to nearest lower even height and width " + "(if original height is odd then result height = original height - 1);" + "\"to_even_up\" - resize input to nearest bigger even height and width " + "(if original height is odd then result height = original height + 1);" + "\"to_odd_down\" - resize input to nearest odd height and width " + "(if original height is odd then result height = original height - 1);" + "\"to_odd_up\" - resize input to nearest odd height and width " + "(if original height is odd then result height = original height + 1);"); + DMLC_DECLARE_FIELD(align_corners) + .set_default(true) + .describe( + "With align_corners = True, the interpolating doesn't proportionally align the" + "output and input pixels, and thus the output values can depend on the input size."); } }; template -static inline DType area_pixel_compute_scale( - int64_t input_size, - int64_t output_size, - bool align_corners) { +static inline DType area_pixel_compute_scale(int64_t input_size, + int64_t output_size, + bool align_corners) { /* We view each pixel as an area, idx + 0.5 as its center index. * Here is an example formula in 1D case. * if align_corners: center of two corner pixel areas are preserved, @@ -121,20 +133,18 @@ static inline DType area_pixel_compute_scale( * src_idx + 0.5 = scale * (dst_index + 0.5) */ if (output_size > 1) { - return align_corners - ? static_cast(input_size - 1) / (output_size - 1) - : static_cast(input_size) / output_size; + return align_corners ? 
static_cast(input_size - 1) / (output_size - 1) + : static_cast(input_size) / output_size; } else { return DType(0); } } template -static inline DType area_pixel_compute_source_index( - DType scale, - int64_t dst_index, - bool align_corners, - bool cubic) { +static inline DType area_pixel_compute_source_index(DType scale, + int64_t dst_index, + bool align_corners, + bool cubic) { if (align_corners) { return scale * dst_index; } else { @@ -159,99 +169,95 @@ static inline bool IsWriting(const OpReqType ort) { return ort == kWriteTo || ort == kWriteInplace; } -template -void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool align_corners); -template -void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool modeLike, bool align_corners); #if MXNET_USE_CUDA -template -void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool align_corners); -template -void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool modeLike, bool align_corners); #endif // MXNET_USE_CUDA template inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const BilinearSampleParam& param = nnvm::get(attrs.parsed); - size_t expected = param.mode == bilinear_resize::like ? 2 : 1; + size_t expected = param.mode == bilinear_resize::like ? 2 : 1; CHECK_EQ(inputs.size(), expected); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(inputs[0].CheckContiguous(), true); if (expected == 2) { - CHECK_EQ(inputs[1].CheckContiguous(), true); + CHECK_EQ(inputs[1].CheckContiguous(), true); } CHECK_EQ(outputs[0].CheckContiguous(), true); - bool align_corners = param.align_corners; - mshadow::Stream *s = ctx.get_stream(); + bool align_corners = param.align_corners; + mshadow::Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { SpatialUpSamplingBilinearUpdateOutput(s, inputs, outputs, align_corners); }); } - template inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const BilinearSampleParam& param = nnvm::get(attrs.parsed); CHECK_EQ(inputs.size(), 1U); - bool modeLike = param.mode == bilinear_resize::like; + bool modeLike = param.mode == bilinear_resize::like; bool align_corners = param.align_corners; - size_t expected = modeLike ? 2 : 1; + size_t expected = modeLike ? 
2 : 1; CHECK_EQ(outputs.size(), expected); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); if (IsWriting(req[0])) { // zero grad before backwarding - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - Fill(s, outputs[0], kWriteTo, 0); - }) + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { Fill(s, outputs[0], kWriteTo, 0); }) } MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { - SpatialUpSamplingBilinearUpdateGradInput(s, inputs, outputs - , modeLike, align_corners); + SpatialUpSamplingBilinearUpdateGradInput( + s, inputs, outputs, modeLike, align_corners); }); } - static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { using namespace mshadow; CHECK_EQ(out_shape->size(), 1U) << "Output:[data]"; const BilinearSampleParam& param = nnvm::get(attrs.parsed); - size_t expected = param.mode == bilinear_resize::like ? 2 : 1; + size_t expected = param.mode == bilinear_resize::like ? 2 : 1; CHECK_EQ(in_shape->size(), expected); mxnet::TShape dshape(in_shape->at(0)); - if (mxnet::op::shape_is_none(dshape)) return false; + if (mxnet::op::shape_is_none(dshape)) + return false; int16_t new_height = -1; - int16_t new_width = -1; + int16_t new_width = -1; switch (param.mode) { - case bilinear_resize::simple: - { + case bilinear_resize::simple: { if (param.scale_height.has_value()) { new_height = static_cast(param.scale_height.value() * in_shape->at(0)[2]); } else { @@ -264,48 +270,44 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, } break; } - case bilinear_resize::odd_scale: - { - new_height = ((dshape[2] % 2) == 0) ? (int16_t) (dshape[2] * param.scale_height.value()) : - (int16_t) ((dshape[2] - 1) * param.scale_height.value()) + 1; - new_width = ((dshape[3] % 2) == 0) ? (int16_t) (dshape[3] * param.scale_width.value()) : - (int16_t) ((dshape[3] - 1) * param.scale_width.value()) + 1; + case bilinear_resize::odd_scale: { + new_height = ((dshape[2] % 2) == 0) + ? (int16_t)(dshape[2] * param.scale_height.value()) + : (int16_t)((dshape[2] - 1) * param.scale_height.value()) + 1; + new_width = ((dshape[3] % 2) == 0) + ? (int16_t)(dshape[3] * param.scale_width.value()) + : (int16_t)((dshape[3] - 1) * param.scale_width.value()) + 1; break; } - case bilinear_resize::like: - { + case bilinear_resize::like: { TShape like_shape(in_shape->at(1)); - if (dshape.ndim() == 0) return false; + if (dshape.ndim() == 0) + return false; new_height = like_shape[2]; - new_width = like_shape[3]; + new_width = like_shape[3]; break; } - case bilinear_resize::to_even_down: - { + case bilinear_resize::to_even_down: { new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] - 1; - new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1; + new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1; break; } - case bilinear_resize::to_even_up: - { + case bilinear_resize::to_even_up: { new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] + 1; - new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1; + new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1; break; } - case bilinear_resize::to_odd_down: - { + case bilinear_resize::to_odd_down: { new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] - 1; - new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] - 1; + new_width = ((dshape[3] % 2) == 1) ? 
dshape[3] : dshape[3] - 1; break; } - case bilinear_resize::to_odd_up: - { + case bilinear_resize::to_odd_up: { new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] + 1; - new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1; + new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1; break; } - default: - { + default: { LOG(FATAL) << "Invalid mode " << param.mode; } } @@ -318,7 +320,6 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, return true; } - inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) { auto& param = nnvm::get(attrs.parsed); if (param.mode == bilinear_resize::like) { diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc index 399a5a79bd56..61b3446bd2ab 100644 --- a/src/operator/contrib/bilinear_resize.cc +++ b/src/operator/contrib/bilinear_resize.cc @@ -21,7 +21,7 @@ * \file bilinear_resize.cc * \brief bilinear resize operator * \author Hang Zhang -*/ + */ #include "bilinear_resize-inl.h" #include "../elemwise_op_common.h" @@ -30,38 +30,38 @@ namespace op { using namespace mshadow; -template -void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool align_corners) { Tensor itensor = input[0].get(s); Tensor otensor = output[0].get(s); - int nbatch = otensor.size(0); - int channels = otensor.size(1); - int outputHeight = otensor.size(2); - int outputWidth = otensor.size(3); - int inputHeight = itensor.size(2); - int inputWidth = itensor.size(3); + int nbatch = otensor.size(0); + int channels = otensor.size(1); + int outputHeight = otensor.size(2); + int outputWidth = otensor.size(3); + int inputHeight = itensor.size(2); + int inputWidth = itensor.size(3); const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType *idata = itensor.dptr_; - DType *odata = otensor.dptr_; - channels = nbatch * channels; - const int input_elems_per_channel = inputWidth * inputHeight; + DType* idata = itensor.dptr_; + DType* odata = otensor.dptr_; + channels = nbatch * channels; + const int input_elems_per_channel = inputWidth * inputHeight; const int output_elems_per_channel = outputWidth * outputHeight; // special case: just copy if (inputHeight == outputHeight && inputWidth == outputWidth) { #pragma omp parallel for num_threads(nthreads) for (int index = 0; index < output_elems_per_channel; index++) { - const int h2 = index / outputWidth; - const int h1 = h2; - const int w2 = index % outputWidth; - const int w1 = w2; + const int h2 = index / outputWidth; + const int h1 = h2; + const int w2 = index % outputWidth; + const int w1 = w2; const DType* pos1 = &idata[h1 * inputWidth + w1]; - DType* pos2 = &odata[index]; + DType* pos2 = &odata[index]; for (int c = 0; c < channels; ++c) { *pos2 = *pos1; pos1 += input_elems_per_channel; @@ -70,75 +70,71 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, } return; } - const float rheight = area_pixel_compute_scale( - inputHeight, outputHeight, align_corners); - const float rwidth = area_pixel_compute_scale( - inputWidth, outputWidth, align_corners); + const float rheight = area_pixel_compute_scale(inputHeight, outputHeight, align_corners); + const float rwidth = area_pixel_compute_scale(inputWidth, outputWidth, align_corners); #pragma omp parallel for num_threads(nthreads) for (int index = 0; index < 
output_elems_per_channel; index++) { const int h2 = index / outputWidth; const int w2 = index % outputWidth; - const float h1r = area_pixel_compute_source_index( - rheight, h2, align_corners, false); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; + const float h1r = area_pixel_compute_source_index(rheight, h2, align_corners, false); + const int h1 = h1r; + const int h1p = (h1 < inputHeight - 1) ? 1 : 0; const DType h1lambda = h1r - h1; const DType h0lambda = (DType)1. - h1lambda; - const float w1r = area_pixel_compute_source_index( - rwidth, w2, align_corners, false); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; + const float w1r = area_pixel_compute_source_index(rwidth, w2, align_corners, false); + const int w1 = w1r; + const int w1p = (w1 < inputWidth - 1) ? 1 : 0; const DType w1lambda = w1r - w1; const DType w0lambda = (DType)1. - w1lambda; - const DType* pos1 = &idata[h1 * inputWidth + w1]; - DType* pos2 = &odata[index]; + const DType* pos1 = &idata[h1 * inputWidth + w1]; + DType* pos2 = &odata[index]; for (int c = 0; c < channels; ++c) { - *pos2 = h0lambda * (w0lambda * (*pos1) + w1lambda * *(pos1 + w1p)) - + h1lambda * (w0lambda * *(pos1 + h1p * inputWidth) - + w1lambda * *(pos1 + h1p * inputWidth + w1p)); + *pos2 = h0lambda * (w0lambda * (*pos1) + w1lambda * *(pos1 + w1p)) + + h1lambda * (w0lambda * *(pos1 + h1p * inputWidth) + + w1lambda * *(pos1 + h1p * inputWidth + w1p)); pos1 += input_elems_per_channel; pos2 += output_elems_per_channel; } } } -template -void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool modeLike, bool align_corners) { Tensor gradOutput = input[0].get(s); - Tensor gradInput = output[0].get(s); + Tensor gradInput = output[0].get(s); - int nbatch = gradInput.size(0); - int channels = gradInput.size(1); + int nbatch = gradInput.size(0); + int channels = gradInput.size(1); int outputHeight = gradOutput.size(2); - int outputWidth = gradOutput.size(3); - int inputHeight = gradInput.size(2); - int inputWidth = gradInput.size(3); + int outputWidth = gradOutput.size(3); + int inputHeight = gradInput.size(2); + int inputWidth = gradInput.size(3); const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); - DType *dataInput = gradInput.dptr_; - DType *dataOutput = gradOutput.dptr_; - channels = nbatch * channels; - const int input_elems_per_channel = inputWidth * inputHeight; + DType* dataInput = gradInput.dptr_; + DType* dataOutput = gradOutput.dptr_; + channels = nbatch * channels; + const int input_elems_per_channel = inputWidth * inputHeight; const int output_elems_per_channel = outputWidth * outputHeight; // special case: same-size matching grids if (inputHeight == outputHeight && inputWidth == outputWidth) { #pragma omp parallel for num_threads(nthreads) for (int index = 0; index < output_elems_per_channel; index++) { - const int h2 = index / outputWidth; - const int h1 = h2; - const int w2 = index % outputWidth; - const int w1 = w2; - DType* pos1 = &dataInput[h1 * inputWidth + w1]; + const int h2 = index / outputWidth; + const int h1 = h2; + const int w2 = index % outputWidth; + const int w1 = w2; + DType* pos1 = &dataInput[h1 * inputWidth + w1]; const DType* pos2 = &dataOutput[index]; for (int c = 0; c < channels; ++c) { *pos1 += *pos2; @@ -148,33 +144,29 @@ void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, } return; } - const float rheight = area_pixel_compute_scale( - inputHeight, outputHeight, align_corners); - const float rwidth = area_pixel_compute_scale( - inputWidth, outputWidth, align_corners); + const float rheight = area_pixel_compute_scale(inputHeight, outputHeight, align_corners); + const float rwidth = area_pixel_compute_scale(inputWidth, outputWidth, align_corners); #pragma omp parallel for num_threads(nthreads) for (int index = 0; index < output_elems_per_channel; index++) { const int h2 = index / outputWidth; const int w2 = index % outputWidth; - const float h1r = area_pixel_compute_source_index( - rheight, h2, align_corners, false); - const int h1 = h1r; - const int h1p = (h1 < inputHeight - 1) ? 1 : 0; + const float h1r = area_pixel_compute_source_index(rheight, h2, align_corners, false); + const int h1 = h1r; + const int h1p = (h1 < inputHeight - 1) ? 1 : 0; const DType h1lambda = h1r - h1; const DType h0lambda = (DType)1. - h1lambda; - const float w1r = area_pixel_compute_source_index( - rwidth, w2, align_corners, false); - const int w1 = w1r; - const int w1p = (w1 < inputWidth - 1) ? 1 : 0; + const float w1r = area_pixel_compute_source_index(rwidth, w2, align_corners, false); + const int w1 = w1r; + const int w1p = (w1 < inputWidth - 1) ? 1 : 0; const DType w1lambda = w1r - w1; const DType w0lambda = (DType)1. - w1lambda; - DType* posInput = &dataInput[h1 * inputWidth + w1]; + DType* posInput = &dataInput[h1 * inputWidth + w1]; const DType* posOutput = &dataOutput[index]; for (int c = 0; c < channels; ++c) { - #pragma omp critical +#pragma omp critical { *posInput += h0lambda * w0lambda * (*posOutput); *(posInput + w1p) += h0lambda * w1lambda * (*posOutput); @@ -188,15 +180,15 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, if (modeLike) { Tensor gradInputLike = output[1].get(s); - int inputHeightLike = gradInputLike.size(2); - int inputWidthLike = gradInputLike.size(3); - DType *dataInputLike = gradInputLike.dptr_; - int channelsLike = nbatch * gradInputLike.size(1); + int inputHeightLike = gradInputLike.size(2); + int inputWidthLike = gradInputLike.size(3); + DType* dataInputLike = gradInputLike.dptr_; + int channelsLike = nbatch * gradInputLike.size(1); const int inputLike_elems_per_channel = inputHeightLike * inputWidthLike; #pragma omp parallel for num_threads(nthreads) for (int index = 0; index < inputLike_elems_per_channel; index++) { - DType *posInput = &dataInputLike[index]; + DType* posInput = &dataInputLike[index]; for (int c = 0; c < channelsLike; ++c) { *posInput = 0; posInput += inputLike_elems_per_channel; @@ -208,7 +200,7 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, DMLC_REGISTER_PARAMETER(BilinearSampleParam); NNVM_REGISTER_OP(_contrib_BilinearResize2D) -.describe(R"code( + .describe(R"code( Perform 2D resizing (upsampling or downsampling) for 4D input using bilinear interpolation. Expected input is a 4 dimensional NDArray (NCHW) and the output @@ -218,25 +210,24 @@ first in one direction, and then again in the other direction. See the wikipedia `Bilinear interpolation `_ for more details. 
)code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs(BilinearSampleOpNumInputs) -.set_num_outputs(1) -.set_attr("FListInputNames", BilinearSampleOpInputNames) -.set_attr("FInferShape", BilinearSampleOpInferShape) -.set_attr("FCompute", BilinearSampleOpForward) -.set_attr("FGradient", - ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"}) -.add_argument("data", "NDArray-or-Symbol", "Input data") -.add_argument("like", "NDArray-or-Symbol", "Resize data to it's shape") -.add_arguments(BilinearSampleParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs(BilinearSampleOpNumInputs) + .set_num_outputs(1) + .set_attr("FListInputNames", BilinearSampleOpInputNames) + .set_attr("FInferShape", BilinearSampleOpInferShape) + .set_attr("FCompute", BilinearSampleOpForward) + .set_attr("FGradient", + ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"}) + .add_argument("data", "NDArray-or-Symbol", "Input data") + .add_argument("like", "NDArray-or-Symbol", "Resize data to it's shape") + .add_arguments(BilinearSampleParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D) -.set_attr_parser(ParamParser) -.set_num_inputs(1) -.set_num_outputs(BilinearSampleOpNumBackwardOutputs) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BilinearSampleOpBackward); - + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(BilinearSampleOpNumBackwardOutputs) + .set_attr("TIsBackward", true) + .set_attr("FCompute", BilinearSampleOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu index 1a6e6964ddfb..cbae235e7c01 100644 --- a/src/operator/contrib/bilinear_resize.cu +++ b/src/operator/contrib/bilinear_resize.cu @@ -21,7 +21,7 @@ * \file bilinear_resize.cu * \brief bilinear resize operator * \author Hang Zhang -*/ + */ #include #include #include "bilinear_resize-inl.h" @@ -35,23 +35,18 @@ using namespace mshadow; // fastSpecializedAtomicAdd adapted from Torch // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/KernelUtils.cuh template < - typename Dtype, - typename std::enable_if::value>::type* = - nullptr> - __device__ MSHADOW_FORCE_INLINE void fastSpecializedAtomicAdd( - Dtype* tensor, - size_t index, - const size_t numel, - Dtype value) { -#if ( \ - (CUDA_VERSION < 10000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) - atomicAdd( - reinterpret_cast(tensor) + index, - static_cast(value)); + typename Dtype, + typename std::enable_if::value>::type* = nullptr> +__device__ MSHADOW_FORCE_INLINE void fastSpecializedAtomicAdd(Dtype* tensor, + size_t index, + const size_t numel, + Dtype value) { +#if ((CUDA_VERSION < 10000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + atomicAdd(reinterpret_cast(tensor) + index, + static_cast(value)); #else - bool low_bit = (index % 2 == 0) && - (reinterpret_cast(tensor) % sizeof(__half2) == 0); + bool low_bit = + (index % 2 == 0) && (reinterpret_cast(tensor) % sizeof(__half2) == 0); if (low_bit && index < (numel - 1)) { __half2 value2; @@ -65,31 +60,27 @@ template < atomicAdd(reinterpret_cast<__half2*>(tensor) + index / 2, value2); } else { - atomicAdd( - reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value)); + atomicAdd(reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value)); } #endif } template < - typename Dtype, - typename std::enable_if::value>::type* = - nullptr> - __device__ MSHADOW_FORCE_INLINE void fastSpecializedAtomicAdd( - Dtype* 
tensor, - size_t index, - const size_t numel, - Dtype value) { + typename Dtype, + typename std::enable_if::value>::type* = nullptr> +__device__ MSHADOW_FORCE_INLINE void fastSpecializedAtomicAdd(Dtype* tensor, + size_t index, + const size_t numel, + Dtype value) { atomicAdd(tensor + index, value); } template -__device__ MSHADOW_FORCE_INLINE void fastAtomicAdd( - Dtype* tensor, - size_t index, - const size_t numel, - Dtype value, - bool fast_atomics) { +__device__ MSHADOW_FORCE_INLINE void fastAtomicAdd(Dtype* tensor, + size_t index, + const size_t numel, + Dtype value, + bool fast_atomics) { if (fast_atomics) { fastSpecializedAtomicAdd(tensor, index, numel, value); } else { @@ -97,19 +88,17 @@ __device__ MSHADOW_FORCE_INLINE void fastAtomicAdd( } } - -template -__global__ void like_mode_kernel_backward(const int n, - Tensor dataLike) { - int index = threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ void like_mode_kernel_backward(const int n, Tensor dataLike) { + int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = dataLike.size(0); - const int channels = dataLike.size(1); - const int height = dataLike.size(2); - const int width = dataLike.size(3); + const int channels = dataLike.size(1); + const int height = dataLike.size(2); + const int width = dataLike.size(3); if (index < n) { const int w = index % width; const int h = index / width; - for (int n = 0; n < batchsize ; n++) { + for (int n = 0; n < batchsize; n++) { for (int c = 0; c < channels; ++c) { dataLike[n][c][h][w] = 0; } @@ -121,166 +110,157 @@ __global__ void like_mode_kernel_backward(const int n, // caffe_gpu_interp2_kernel_backward adapted from Torch // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu // Backward (adjoint) operation 1 <- 2 (accumulates) -template -__global__ void caffe_gpu_interp2_kernel_backward( - const size_t nc, - const int height1, - const int width1, - const int height2, - const int width2, - const Acctype rheight, - const Acctype rwidth, - const bool align_corners, - Dtype* __restrict__ idata, - const Dtype* __restrict__ odata) { +template +__global__ void caffe_gpu_interp2_kernel_backward(const size_t nc, + const int height1, + const int width1, + const int height2, + const int width2, + const Acctype rheight, + const Acctype rwidth, + const bool align_corners, + Dtype* __restrict__ idata, + const Dtype* __restrict__ odata) { const size_t o_numel = nc * width2 * height2; const size_t i_numel = nc * width1 * height1; for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; - index += blockDim.x * gridDim.x) { + index += blockDim.x * gridDim.x) { size_t index_temp = index; - const int w2 = index_temp % width2; // 0:width2-1 + const int w2 = index_temp % width2; // 0:width2-1 index_temp /= width2; - const int h2 = index_temp % height2; // 0:height2-1 + const int h2 = index_temp % height2; // 0:height2-1 const size_t nc = index_temp / height2; // - const Acctype h1r = cu_area_pixel_compute_source_index( - rheight, h2, align_corners, false); - const int h1 = h1r; - const int h1p = (h1 < height1 - 1) ? 1 : 0; + const Acctype h1r = + cu_area_pixel_compute_source_index(rheight, h2, align_corners, false); + const int h1 = h1r; + const int h1p = (h1 < height1 - 1) ? 1 : 0; const Acctype h1lambda = h1r - h1; const Acctype h0lambda = static_cast(1) - h1lambda; // - const Acctype w1r = cu_area_pixel_compute_source_index( - rwidth, w2, align_corners, false); - const int w1 = w1r; - const int w1p = (w1 < width1 - 1) ? 
1 : 0; + const Acctype w1r = + cu_area_pixel_compute_source_index(rwidth, w2, align_corners, false); + const int w1 = w1r; + const int w1p = (w1 < width1 - 1) ? 1 : 0; const Acctype w1lambda = w1r - w1; const Acctype w0lambda = static_cast(1) - w1lambda; const Dtype d2val = odata[index]; - fastAtomicAdd( - idata, - idx(nc, height1, width1, h1, w1), - i_numel, - ScalarConvert::to(h0lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx(nc, height1, width1, h1, w1 + w1p), - i_numel, - ScalarConvert::to(h0lambda * w1lambda * d2val), - true); - fastAtomicAdd( - idata, - idx(nc, height1, width1, h1 + h1p, w1), - i_numel, - ScalarConvert::to(h1lambda * w0lambda * d2val), - true); - fastAtomicAdd( - idata, - idx(nc, height1, width1, h1 + h1p, w1 + w1p), - i_numel, - ScalarConvert::to(h1lambda * w1lambda * d2val), - true); + fastAtomicAdd(idata, + idx(nc, height1, width1, h1, w1), + i_numel, + ScalarConvert::to(h0lambda * w0lambda * d2val), + true); + fastAtomicAdd(idata, + idx(nc, height1, width1, h1, w1 + w1p), + i_numel, + ScalarConvert::to(h0lambda * w1lambda * d2val), + true); + fastAtomicAdd(idata, + idx(nc, height1, width1, h1 + h1p, w1), + i_numel, + ScalarConvert::to(h1lambda * w0lambda * d2val), + true); + fastAtomicAdd(idata, + idx(nc, height1, width1, h1 + h1p, w1 + w1p), + i_numel, + ScalarConvert::to(h1lambda * w1lambda * d2val), + true); } } -template -void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, - bool align_corners) { +template +void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, + bool align_corners) { Tensor idata = input[0].get(s); Tensor odata = output[0].get(s); - int outputHeight = odata.size(2); - int outputWidth = odata.size(3); - int nbatch = idata.size(0); - int channels = idata.size(1); - int inputHeight = idata.size(2); - int inputWidth = idata.size(3); + int outputHeight = odata.size(2); + int outputWidth = odata.size(3); + int nbatch = idata.size(0); + int channels = idata.size(1); + int inputHeight = idata.size(2); + int inputWidth = idata.size(3); - const AccReal rheight = cu_area_pixel_compute_scale( - inputHeight, outputHeight, align_corners); - const AccReal rwidth = cu_area_pixel_compute_scale( - inputWidth, outputWidth, align_corners); + const AccReal rheight = + cu_area_pixel_compute_scale(inputHeight, outputHeight, align_corners); + const AccReal rwidth = + cu_area_pixel_compute_scale(inputWidth, outputWidth, align_corners); const int num_kernels = nbatch * channels * outputHeight * outputWidth; - const int num_threads = getNumThreads(inputHeight*inputWidth, false); + const int num_threads = getNumThreads(inputHeight * inputWidth, false); dim3 blocks(static_cast(num_kernels / num_threads) + 1); dim3 threads(num_threads); cudaStream_t stream = mshadow::Stream::GetStream(s); - caffe_gpu_interp2_kernel - <<>>( - nbatch * channels, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - rheight, - rwidth, - align_corners, - idata.dptr_, - odata.dptr_); + caffe_gpu_interp2_kernel<<>>(nbatch * channels, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + rheight, + rwidth, + align_corners, + idata.dptr_, + odata.dptr_); MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput); } -template -void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream *s, - const std::vector &input, - const std::vector &output, +template +void 
SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream* s, + const std::vector& input, + const std::vector& output, bool modeLike, - bool align_corners) { + bool align_corners) { Tensor gradOutput = input[0].get(s); - Tensor gradInput = output[0].get(s); - int outputHeight = gradOutput.size(2); - int outputWidth = gradOutput.size(3); - int nbatch = gradInput.size(0); - int channels = gradInput.size(1); - int inputHeight = gradInput.size(2); - int inputWidth = gradInput.size(3); + Tensor gradInput = output[0].get(s); + int outputHeight = gradOutput.size(2); + int outputWidth = gradOutput.size(3); + int nbatch = gradInput.size(0); + int channels = gradInput.size(1); + int inputHeight = gradInput.size(2); + int inputWidth = gradInput.size(3); - const AccReal rheight = cu_area_pixel_compute_scale( - inputHeight, outputHeight, align_corners); - const AccReal rwidth = cu_area_pixel_compute_scale( - inputWidth, outputWidth, align_corners); + const AccReal rheight = + cu_area_pixel_compute_scale(inputHeight, outputHeight, align_corners); + const AccReal rwidth = + cu_area_pixel_compute_scale(inputWidth, outputWidth, align_corners); const int num_kernels = nbatch * channels * outputHeight * outputWidth; - const int num_threads = getNumThreads(inputHeight*inputWidth, false); + const int num_threads = getNumThreads(inputHeight * inputWidth, false); dim3 blocks(static_cast(num_kernels / num_threads) + 1); dim3 threads(num_threads); cudaStream_t stream = mshadow::Stream::GetStream(s); caffe_gpu_interp2_kernel_backward - <<>>( - nbatch * channels, - inputHeight, - inputWidth, - outputHeight, - outputWidth, - rheight, - rwidth, - align_corners, - gradInput.dptr_, - gradOutput.dptr_); + <<>>(nbatch * channels, + inputHeight, + inputWidth, + outputHeight, + outputWidth, + rheight, + rwidth, + align_corners, + gradInput.dptr_, + gradOutput.dptr_); if (modeLike) { Tensor dataLike = output[1].get(s); - int heightLike = dataLike.size(2); - int widthLike = dataLike.size(3); - const int num_kernels_like = heightLike * widthLike; - const int num_threads_like = getNumThreads(num_kernels_like, false); + int heightLike = dataLike.size(2); + int widthLike = dataLike.size(3); + const int num_kernels_like = heightLike * widthLike; + const int num_threads_like = getNumThreads(num_kernels_like, false); dim3 blocksLike(static_cast(num_kernels_like / num_threads_like) + 1); dim3 threadsLike(num_threads_like); like_mode_kernel_backward - <<>>( - num_kernels_like, dataLike); + <<>>(num_kernels_like, dataLike); } MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateGradInput); } NNVM_REGISTER_OP(_contrib_BilinearResize2D) -.set_attr("FCompute", BilinearSampleOpForward); + .set_attr("FCompute", BilinearSampleOpForward); NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D) -.set_attr("FCompute", BilinearSampleOpBackward); + .set_attr("FCompute", BilinearSampleOpBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index 12a69444c681..2a29204804f6 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -19,7 +19,7 @@ /*! 
* Copyright (c) 2018 by Contributors * \file boolean_mask-inl.h -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ #define MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ @@ -45,18 +45,14 @@ namespace op { struct BooleanMaskParam : public dmlc::Parameter { int axis; DMLC_DECLARE_PARAMETER(BooleanMaskParam) { - DMLC_DECLARE_FIELD(axis).set_default(0) - .describe("An integer that represents the axis in NDArray to mask from."); + DMLC_DECLARE_FIELD(axis).set_default(0).describe( + "An integer that represents the axis in NDArray to mask from."); } }; struct BooleanMaskForwardCPUKernel { - template - static void Map(int i, - DType* out, - const DType* data, - const int32_t* idx, - const size_t col_size) { + template + static void Map(int i, DType* out, const DType* data, const int32_t* idx, const size_t col_size) { // i is row id already int32_t prev = (i == 0) ? 0 : idx[i - 1]; int32_t curr = idx[i]; @@ -72,14 +68,11 @@ struct BooleanMaskForwardCPUKernel { }; struct BooleanMaskForwardKernel { - template - static void MSHADOW_XINLINE Map(int i, - DType* out, - const DType* data, - const int32_t* idx, - const size_t col_size) { - int row_id = i / col_size; - int col_id = i % col_size; + template + static void MSHADOW_XINLINE + Map(int i, DType* out, const DType* data, const int32_t* idx, const size_t col_size) { + int row_id = i / col_size; + int col_id = i % col_size; int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1]; int32_t curr = idx[row_id]; if (prev != curr) { @@ -89,15 +82,15 @@ struct BooleanMaskForwardKernel { }; struct BooleanMaskBackwardKernel { - template + template static void MSHADOW_XINLINE Map(int i, DType* igrad, const OpReqType req, const DType* ograd, const int32_t* idx, const size_t col_size) { - int row_id = i / col_size; - int col_id = i % col_size; + int row_id = i / col_size; + int col_id = i % col_size; int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1]; int32_t curr = idx[row_id]; if (prev != curr) { @@ -112,19 +105,19 @@ struct BooleanMaskBackwardKernel { } }; -template +template inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); -template +template inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 882984430d52..774c228b8957 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -19,7 +19,7 @@ /*! 
* Copyright (c) 2018 by Contributors * \file boolean_mask.cc -*/ + */ #include "./boolean_mask-inl.h" @@ -29,8 +29,8 @@ namespace op { DMLC_REGISTER_PARAMETER(BooleanMaskParam); bool BooleanMaskType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); @@ -41,14 +41,14 @@ bool BooleanMaskType(const nnvm::NodeAttrs& attrs, bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); - for (int &attr : *in_attrs) { + for (int& attr : *in_attrs) { CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; } - for (int &attr : *out_attrs) { + for (int& attr : *out_attrs) { attr = kDefaultStorage; } *dispatch_mode = DispatchMode::kFComputeEx; @@ -58,24 +58,24 @@ bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 3); CHECK_EQ(out_attrs->size(), 2); - for (int &attr : *in_attrs) { + for (int& attr : *in_attrs) { CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; } - for (int &attr : *out_attrs) { + for (int& attr : *out_attrs) { attr = kDefaultStorage; } - for (int & out_attr : *out_attrs) + for (int& out_attr : *out_attrs) out_attr = kDefaultStorage; *dispatch_mode = DispatchMode::kFComputeEx; return true; } struct BooleanMaskBackwardCPUWriteKernel { - template + template static void Map(int i, DType* igrad, const OpReqType /*req*/, @@ -98,20 +98,20 @@ struct BooleanMaskBackwardCPUWriteKernel { } }; -template<> +template <> inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); CHECK(req[0] == kWriteTo || req[0] == kWriteInplace); const BooleanMaskParam& param = nnvm::get(attrs.parsed); - const int axis = param.axis; - const NDArray &data = inputs[0]; - const NDArray &idx = inputs[1]; - const NDArray &out = outputs[0]; + const int axis = param.axis; + const NDArray& data = inputs[0]; + const NDArray& idx = inputs[1]; + const NDArray& out = outputs[0]; CHECK_EQ(axis, 0) << "Not supported yet"; CHECK_EQ(data.shape()[axis], idx.shape()[0]); CHECK_EQ(idx.shape().ndim(), 1U); // idx is required to be 1-d. 
@@ -130,62 +130,75 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, }); // set the output shape forcefully mxnet::TShape s = data.shape(); - s[axis] = valid_num; + s[axis] = valid_num; - const_cast(out).Init(s); + const_cast(out).Init(s); // do the copy MSHADOW_TYPE_SWITCH_WITH_BOOL(data.dtype(), DType, { - size_t input_size = data.shape().Size(); - size_t col_size = input_size / idx_size; - mshadow::Stream *stream = ctx.get_stream(); - mxnet_op::Kernel::Launch( - stream, idx_size, out.data().dptr(), data.data().dptr(), - prefix_sum.data(), col_size); + size_t input_size = data.shape().Size(); + size_t col_size = input_size / idx_size; + mshadow::Stream* stream = ctx.get_stream(); + mxnet_op::Kernel::Launch(stream, + idx_size, + out.data().dptr(), + data.data().dptr(), + prefix_sum.data(), + col_size); }); } -template<> +template <> inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); - if (req[0] == kNullOp) return; + if (req[0] == kNullOp) + return; // inputs: {ograd, data, idx} // outputs: {igrad_data, igrad_idx} - const NDArray& ograd = inputs[0]; - const NDArray& idx = inputs[2]; + const NDArray& ograd = inputs[0]; + const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { size_t input_size = igrad_data.shape().Size(); - size_t idx_size = idx.shape()[0]; - size_t col_size = input_size / idx_size; + size_t idx_size = idx.shape()[0]; + size_t col_size = input_size / idx_size; std::vector prefix_sum(idx_size, 0); IType* idx_dptr = idx.data().dptr(); for (size_t i = 0; i < idx_size; i++) { prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1]; prefix_sum[i] += (idx_dptr[i]) ? 1 : 0; } - mshadow::Stream *stream = ctx.get_stream(); + mshadow::Stream* stream = ctx.get_stream(); if (req[0] == kAddTo) { - mxnet_op::Kernel::Launch( - stream, idx_size, igrad_data.data().dptr(), req[0], - ograd.data().dptr(), prefix_sum.data(), col_size); + mxnet_op::Kernel::Launch(stream, + idx_size, + igrad_data.data().dptr(), + req[0], + ograd.data().dptr(), + prefix_sum.data(), + col_size); } else { mxnet_op::Kernel::Launch( - stream, idx_size, igrad_data.data().dptr(), req[0], - ograd.data().dptr(), prefix_sum.data(), col_size); + stream, + idx_size, + igrad_data.data().dptr(), + req[0], + ograd.data().dptr(), + prefix_sum.data(), + col_size); } }); }); } NNVM_REGISTER_OP(_contrib_boolean_mask) -.add_alias("_npi_boolean_mask") -.describe(R"code( + .add_alias("_npi_boolean_mask") + .describe(R"code( Given an n-d NDArray data, and a 1-d NDArray index, the operator produces an un-predeterminable shaped n-d NDArray out, which stands for the rows in x where the corresonding element in index is non-zero. 
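The forward path in the hunk above hinges on an inclusive prefix sum over the 0/1 index array: for a selected row i, prefix[i - 1] is its destination row in the compacted output, and the final prefix value is the row count used to set the output shape. A minimal CPU sketch of that scheme for row-major data (plain C++; boolean_mask here is an illustrative helper, not the MXNet API):

#include <algorithm>
#include <cstdio>
#include <vector>

// Compacts the rows of a row-major (rows x col_size) matrix selected by a 0/1
// mask, using the same inclusive-prefix-sum bookkeeping as the kernels above.
// (Illustrative helper, not the MXNet API.)
std::vector<float> boolean_mask(const std::vector<float>& data,
                                const std::vector<int>& mask,
                                size_t col_size) {
  std::vector<int> prefix(mask.size(), 0);
  for (size_t i = 0; i < mask.size(); ++i)
    prefix[i] = (i == 0 ? 0 : prefix[i - 1]) + (mask[i] ? 1 : 0);
  const int valid_num = mask.empty() ? 0 : prefix.back();   // number of kept rows
  std::vector<float> out(valid_num * col_size);
  for (size_t i = 0; i < mask.size(); ++i) {
    const int prev = (i == 0) ? 0 : prefix[i - 1];
    if (prefix[i] != prev)                                   // row i is selected
      std::copy(data.begin() + i * col_size,
                data.begin() + (i + 1) * col_size,
                out.begin() + prev * col_size);
  }
  return out;
}

int main() {
  std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9};    // 3 rows x 3 columns
  std::vector<int> mask   = {0, 1, 1};
  for (float v : boolean_mask(data, mask, 3)) std::printf("%g ", v);  // 4 5 6 7 8 9
  std::printf("\n");
  return 0;
}

The GPU path in the later hunks computes the same prefix sum with cub::DeviceScan and copies the last element back to the host to learn the output size before launching the copy kernel.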
@@ -199,28 +212,28 @@ which stands for the rows in x where the corresonding element in index is non-ze )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "index"}; - }) -.set_attr("FInferType", BooleanMaskType) -.set_attr("FComputeEx", BooleanMaskForward) -.set_attr("FInferStorageType", BooleanMaskStorageType) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"}) -.add_argument("data", "NDArray-or-Symbol", "Data") -.add_argument("index", "NDArray-or-Symbol", "Mask") -.add_arguments(BooleanMaskParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "index"}; + }) + .set_attr("FInferType", BooleanMaskType) + .set_attr("FComputeEx", BooleanMaskForward) + .set_attr("FInferStorageType", BooleanMaskStorageType) + .set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"}) + .add_argument("data", "NDArray-or-Symbol", "Data") + .add_argument("index", "NDArray-or-Symbol", "Mask") + .add_arguments(BooleanMaskParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_boolean_mask) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", BooleanMaskBackStorageType) -.set_attr("FComputeEx", BooleanMaskBackward) -.add_arguments(BooleanMaskParam::__FIELDS__()); + .set_num_inputs(3) + .set_num_outputs(2) + .set_attr("TIsBackward", true) + .set_attr("FInferStorageType", BooleanMaskBackStorageType) + .set_attr("FComputeEx", BooleanMaskBackward) + .add_arguments(BooleanMaskParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu index 95f5614ba44d..ebfe4ea88a89 100644 --- a/src/operator/contrib/boolean_mask.cu +++ b/src/operator/contrib/boolean_mask.cu @@ -19,7 +19,7 @@ /*! 
* Copyright (c) 2018 by Contributors * \file boolean_mask.cu -*/ + */ #include "./boolean_mask-inl.h" #include @@ -27,149 +27,138 @@ namespace mxnet { namespace op { -template<> +template <> inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); CHECK(req[0] == kWriteTo || req[0] == kWriteInplace); const BooleanMaskParam& param = nnvm::get(attrs.parsed); - const int axis = param.axis; - const NDArray &data = inputs[0]; - const NDArray &idx = inputs[1]; - const NDArray &out = outputs[0]; + const int axis = param.axis; + const NDArray& data = inputs[0]; + const NDArray& idx = inputs[1]; + const NDArray& out = outputs[0]; CHECK_EQ(axis, 0) << "Not supported yet"; CHECK_EQ(data.shape()[axis], idx.shape()[0]); CHECK_EQ(idx.shape().ndim(), 1U); - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); cudaStream_t stream = Stream::GetStream(s); // count the number of 1s in `idx`, so that we could know the output dimension - size_t idx_size = idx.shape()[0]; - int32_t valid_num = 0; - int32_t* prefix_sum = nullptr; - void* d_temp_storage = nullptr; + size_t idx_size = idx.shape()[0]; + int32_t valid_num = 0; + int32_t* prefix_sum = nullptr; + void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; // Calculate total temporary memory size - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - prefix_sum, - prefix_sum, - idx_size, - stream); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, prefix_sum, prefix_sum, idx_size, stream); size_t buffer_size = idx_size * sizeof(int32_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); - prefix_sum = reinterpret_cast(workspace.dptr_); + ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); + prefix_sum = reinterpret_cast(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel::Launch( - s, idx.shape()[0], prefix_sum, idx.data().dptr()); + s, idx.shape()[0], prefix_sum, idx.data().dptr()); }); // Calculate prefix sum - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - prefix_sum, - prefix_sum, - idx_size, - stream); - CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), - cudaMemcpyDeviceToHost, stream)); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, prefix_sum, prefix_sum, idx_size, stream); + CUDA_CALL(cudaMemcpyAsync( + &valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); CUDA_CALL(cudaStreamSynchronize(stream)); // Set the output shape forcefully mxnet::TShape data_shape = data.shape(); - data_shape[axis] = valid_num; - const_cast(out).Init(data_shape); + data_shape[axis] = valid_num; + const_cast(out).Init(data_shape); size_t input_size = data.shape().Size(); - size_t col_size = input_size / idx.shape()[0]; + size_t col_size = input_size / idx.shape()[0]; // Do the copy MSHADOW_TYPE_SWITCH_WITH_BOOL(out.dtype(), DType, { if (valid_num > 0) { mxnet_op::Kernel::Launch( - s, input_size, out.data().dptr(), data.data().dptr(), prefix_sum, col_size); + s, input_size, out.data().dptr(), 
data.data().dptr(), prefix_sum, col_size); } }); } -template<> +template <> inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); - if (req[0] == kNullOp) return; + if (req[0] == kNullOp) + return; // inputs: {ograd, data, idx} // outputs: {igrad_data, igrad_idx} - const NDArray& ograd = inputs[0]; - const NDArray& idx = inputs[2]; + const NDArray& ograd = inputs[0]; + const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; - Stream* s = ctx.get_stream(); - cudaStream_t stream = Stream::GetStream(s); + Stream* s = ctx.get_stream(); + cudaStream_t stream = Stream::GetStream(s); // Count the number of 1s in `idx`, so that we could know the output dimension - size_t idx_size = idx.shape()[0]; - int32_t* prefix_sum = nullptr; - void* d_temp_storage = nullptr; + size_t idx_size = idx.shape()[0]; + int32_t* prefix_sum = nullptr; + void* d_temp_storage = nullptr; size_t temp_storage_bytes = 0; // Calculate total temporary memory size - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - prefix_sum, - prefix_sum, - idx_size, - stream); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, prefix_sum, prefix_sum, idx_size, stream); size_t buffer_size = idx_size * sizeof(int32_t); temp_storage_bytes += buffer_size; // Allocate memory on GPU and allocate pointer Tensor workspace = - ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); - prefix_sum = reinterpret_cast(workspace.dptr_); + ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); + prefix_sum = reinterpret_cast(workspace.dptr_); d_temp_storage = workspace.dptr_ + buffer_size; MSHADOW_TYPE_SWITCH_WITH_BOOL(idx.dtype(), IType, { mxnet_op::Kernel::Launch( - s, idx.shape()[0], prefix_sum, idx.data().dptr()); + s, idx.shape()[0], prefix_sum, idx.data().dptr()); }); // Calculate prefix sum - cub::DeviceScan::InclusiveSum(d_temp_storage, - temp_storage_bytes, - prefix_sum, - prefix_sum, - idx_size, - stream); + cub::DeviceScan::InclusiveSum( + d_temp_storage, temp_storage_bytes, prefix_sum, prefix_sum, idx_size, stream); size_t input_size = igrad_data.shape().Size(); - size_t col_size = input_size / idx_size; + size_t col_size = input_size / idx_size; // Backward pass MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, { if (input_size > 0) { - mxnet_op::Kernel::Launch( - s, input_size, igrad_data.data().dptr(), req[0], ograd.data().dptr(), - prefix_sum, col_size); + mxnet_op::Kernel::Launch(s, + input_size, + igrad_data.data().dptr(), + req[0], + ograd.data().dptr(), + prefix_sum, + col_size); } }); } NNVM_REGISTER_OP(_contrib_boolean_mask) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("THasDeterministicOutput", true) -.set_attr("FComputeEx", BooleanMaskForward); + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("THasDeterministicOutput", true) + .set_attr("FComputeEx", BooleanMaskForward); NNVM_REGISTER_OP(_backward_contrib_boolean_mask) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FComputeEx", BooleanMaskBackward); + 
.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("FComputeEx", BooleanMaskBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box-common.h b/src/operator/contrib/bounding_box-common.h index 69a96c60569a..ecbeafa70ab1 100644 --- a/src/operator/contrib/bounding_box-common.h +++ b/src/operator/contrib/bounding_box-common.h @@ -21,7 +21,7 @@ * \file bounding_box-common.h * \brief bounding box util functions and operators commonly used by CPU and GPU implementations * \author Joshua Zhang -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_ #define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_COMMON_H_ #include "../mshadow_op.h" @@ -31,81 +31,95 @@ namespace mxnet { namespace op { namespace box_common_enum { -enum BoxType {kCorner, kCenter}; +enum BoxType { kCorner, kCenter }; } // compute line intersect along either height or width -template -MSHADOW_XINLINE DType Intersect(const DType *a, const DType *b, int encode) { +template +MSHADOW_XINLINE DType Intersect(const DType* a, const DType* b, int encode) { DType a1 = a[0]; DType a2 = a[2]; DType b1 = b[0]; DType b2 = b[2]; DType w; if (box_common_enum::kCorner == encode) { - DType left = a1 > b1 ? a1 : b1; + DType left = a1 > b1 ? a1 : b1; DType right = a2 < b2 ? a2 : b2; - w = right - left; + w = right - left; } else { - DType aw = a2 / 2; - DType bw = b2 / 2; - DType al = a1 - aw; - DType ar = a1 + aw; - DType bl = b1 - bw; - DType br = b1 + bw; - DType left = bl > al ? bl : al; + DType aw = a2 / 2; + DType bw = b2 / 2; + DType al = a1 - aw; + DType ar = a1 + aw; + DType bl = b1 - bw; + DType br = b1 + bw; + DType left = bl > al ? bl : al; DType right = br < ar ? br : ar; - w = right - left; + w = right - left; } return w > 0 ? w : DType(0); } /*! 
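Intersect above returns the overlap length along one axis for either box encoding; the nms_impl kernel that follows multiplies the x and y overlaps into an intersection area and divides by the union of the two box areas to obtain IoU, suppressing the lower-scored box whenever IoU exceeds the threshold. A minimal greedy sketch of that test for corner-encoded boxes (plain C++; Box, iou and nms_cpu are illustrative names, and the boxes are assumed already sorted by descending score, as the operator's sorted_index guarantees):

#include <algorithm>
#include <cstdio>
#include <vector>

struct Box { float x1, y1, x2, y2; };  // corner encoding [xmin, ymin, xmax, ymax]

// IoU of two corner-encoded boxes: per-axis overlaps multiplied into an
// intersection area, divided by the union. (Illustrative, not part of MXNet.)
float iou(const Box& a, const Box& b) {
  float iw = std::min(a.x2, b.x2) - std::max(a.x1, b.x1);  // overlap along x
  float ih = std::min(a.y2, b.y2) - std::max(a.y1, b.y1);  // overlap along y
  if (iw <= 0 || ih <= 0) return 0.f;
  float inter  = iw * ih;
  float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
  float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
  return inter / (area_a + area_b - inter);
}

// Greedy suppression over score-sorted boxes; keep[i] == false marks a
// suppressed box, mirroring the index[pos] = -1 convention above.
std::vector<bool> nms_cpu(const std::vector<Box>& boxes, float thresh) {
  std::vector<bool> keep(boxes.size(), true);
  for (size_t ref = 0; ref < boxes.size(); ++ref) {
    if (!keep[ref]) continue;                     // reference already suppressed
    for (size_t pos = ref + 1; pos < boxes.size(); ++pos)
      if (keep[pos] && iou(boxes[ref], boxes[pos]) > thresh)
        keep[pos] = false;                        // overlaps too much with a better box
  }
  return keep;
}

int main() {
  std::vector<Box> boxes = {{0, 0, 10, 10}, {1, 1, 10, 10}, {20, 20, 30, 30}};
  for (bool k : nms_cpu(boxes, 0.5f)) std::printf("%d ", static_cast<int>(k));  // 1 0 1
  std::printf("\n");
  return 0;
}

The kernel in this patch parallelises the inner loop over (batch, pos) pairs instead, and additionally skips pairs from different classes when force_suppress is false.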
- * \brief Implementation of the non-maximum suppression operation - * - * \param i the launched thread index - * \param index sorted index in descending order - * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k] - * \param input the input of nms op - * \param areas pre-computed box areas - * \param k nms topk number - * \param ref compare reference position - * \param num number of input boxes in each batch - * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2) - * \param offset_box box offset, usually 2 - * \param thresh nms threshold - * \param force force suppress regardless of class id - * \param offset_id class id offset, used when force == false, usually 0 - * \param encode box encoding type, corner(0) or center(1) - * \param DType the data type - */ + * \brief Implementation of the non-maximum suppression operation + * + * \param i the launched thread index + * \param index sorted index in descending order + * \param batch_start map (b, k) to compact index by indices[batch_start[b] + k] + * \param input the input of nms op + * \param areas pre-computed box areas + * \param k nms topk number + * \param ref compare reference position + * \param num number of input boxes in each batch + * \param stride input stride, usually 6 (id-score-x1-y1-x2-y2) + * \param offset_box box offset, usually 2 + * \param thresh nms threshold + * \param force force suppress regardless of class id + * \param offset_id class id offset, used when force == false, usually 0 + * \param encode box encoding type, corner(0) or center(1) + * \param DType the data type + */ struct nms_impl { - template - MSHADOW_XINLINE static void Map(int i, int32_t *index, const int32_t *batch_start, - const DType *input, const DType *areas, - int k, int ref, int num, - int stride, int offset_box, int offset_id, - float thresh, bool force, int encode) { - int b = i / k; // batch + template + MSHADOW_XINLINE static void Map(int i, + int32_t* index, + const int32_t* batch_start, + const DType* input, + const DType* areas, + int k, + int ref, + int num, + int stride, + int offset_box, + int offset_id, + float thresh, + bool force, + int encode) { + int b = i / k; // batch int pos = i % k + ref + 1; // position - ref = static_cast(batch_start[b]) + ref; - pos = static_cast(batch_start[b]) + pos; - if (ref >= static_cast(batch_start[b + 1])) return; - if (pos >= static_cast(batch_start[b + 1])) return; - if (index[ref] < 0) return; // reference has been suppressed - if (index[pos] < 0) return; // self been suppressed + ref = static_cast(batch_start[b]) + ref; + pos = static_cast(batch_start[b]) + pos; + if (ref >= static_cast(batch_start[b + 1])) + return; + if (pos >= static_cast(batch_start[b + 1])) + return; + if (index[ref] < 0) + return; // reference has been suppressed + if (index[pos] < 0) + return; // self been suppressed int ref_offset = static_cast(index[ref]) * stride + offset_box; int pos_offset = static_cast(index[pos]) * stride + offset_box; - if (!force && offset_id >=0) { + if (!force && offset_id >= 0) { int ref_id = static_cast(input[ref_offset - offset_box + offset_id]); int pos_id = static_cast(input[pos_offset - offset_box + offset_id]); - if (ref_id != pos_id) return; // different class + if (ref_id != pos_id) + return; // different class } DType intersect = Intersect(input + ref_offset, input + pos_offset, encode); intersect *= Intersect(input + ref_offset + 1, input + pos_offset + 1, encode); int ref_area_offset = static_cast(index[ref]); int pos_area_offset = 
static_cast(index[pos]); - DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect); + DType iou = intersect / (areas[ref_area_offset] + areas[pos_area_offset] - intersect); if (iou > thresh) { index[pos] = -1; } @@ -114,33 +128,33 @@ struct nms_impl { namespace mshadow_op { struct less_than : public mxnet_op::tunable { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { return static_cast(a < b); } }; struct greater_than : public mxnet_op::tunable { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { return static_cast(a > b); } }; struct not_equal : public mxnet_op::tunable { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { return static_cast(a != b); } }; struct bool_and : public mxnet_op::tunable { - template + template MSHADOW_XINLINE static DType Map(DType a, DType b) { return static_cast(a && b); } }; -} // namespace mshadow_op +} // namespace mshadow_op } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index c94591eddfdc..192605316fb7 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -21,7 +21,7 @@ * \file bounding_box-inl.h * \brief bounding box util functions and operators * \author Joshua Zhang -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_H_ #define MXNET_OPERATOR_CONTRIB_BOUNDING_BOX_INL_H_ #include @@ -39,10 +39,10 @@ namespace mxnet { namespace op { namespace box_nms_enum { -enum BoxNMSOpInputs {kData}; -enum BoxNMSOpOutputs {kOut, kTemp}; -enum BoxNMSOpResource {kTempSpace}; -} // box_nms_enum +enum BoxNMSOpInputs { kData }; +enum BoxNMSOpOutputs { kOut, kTemp }; +enum BoxNMSOpResource { kTempSpace }; +} // namespace box_nms_enum struct BoxNMSParam : public dmlc::Parameter { float overlap_thresh; @@ -56,90 +56,96 @@ struct BoxNMSParam : public dmlc::Parameter { int in_format; int out_format; DMLC_DECLARE_PARAMETER(BoxNMSParam) { - DMLC_DECLARE_FIELD(overlap_thresh).set_default(0.5) - .describe("Overlapping(IoU) threshold to suppress object with smaller score."); - DMLC_DECLARE_FIELD(valid_thresh).set_default(0) - .describe("Filter input boxes to those whose scores greater than valid_thresh."); - DMLC_DECLARE_FIELD(topk).set_default(-1) - .describe("Apply nms to topk boxes with descending scores, -1 to no restriction."); - DMLC_DECLARE_FIELD(coord_start).set_default(2) - .describe("Start index of the consecutive 4 coordinates."); - DMLC_DECLARE_FIELD(score_index).set_default(1) - .describe("Index of the scores/confidence of boxes."); - DMLC_DECLARE_FIELD(id_index).set_default(-1) - .describe("Optional, index of the class categories, -1 to disable."); - DMLC_DECLARE_FIELD(background_id).set_default(-1) - .describe("Optional, id of the background class which will be ignored in nms."); - DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Optional, if set false and id_index is provided, nms will only apply" - " to boxes belongs to the same category"); - DMLC_DECLARE_FIELD(in_format).set_default(box_common_enum::kCorner) - .add_enum("corner", box_common_enum::kCorner) - .add_enum("center", box_common_enum::kCenter) - .describe("The input box encoding type. 
\n" - " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," - " \"center\" means boxes are encodes as [x, y, width, height]."); - DMLC_DECLARE_FIELD(out_format).set_default(box_common_enum::kCorner) - .add_enum("corner", box_common_enum::kCorner) - .add_enum("center", box_common_enum::kCenter) - .describe("The output box encoding type. \n" - " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," - " \"center\" means boxes are encodes as [x, y, width, height]."); + DMLC_DECLARE_FIELD(overlap_thresh) + .set_default(0.5) + .describe("Overlapping(IoU) threshold to suppress object with smaller score."); + DMLC_DECLARE_FIELD(valid_thresh) + .set_default(0) + .describe("Filter input boxes to those whose scores greater than valid_thresh."); + DMLC_DECLARE_FIELD(topk).set_default(-1).describe( + "Apply nms to topk boxes with descending scores, -1 to no restriction."); + DMLC_DECLARE_FIELD(coord_start) + .set_default(2) + .describe("Start index of the consecutive 4 coordinates."); + DMLC_DECLARE_FIELD(score_index) + .set_default(1) + .describe("Index of the scores/confidence of boxes."); + DMLC_DECLARE_FIELD(id_index).set_default(-1).describe( + "Optional, index of the class categories, -1 to disable."); + DMLC_DECLARE_FIELD(background_id) + .set_default(-1) + .describe("Optional, id of the background class which will be ignored in nms."); + DMLC_DECLARE_FIELD(force_suppress) + .set_default(false) + .describe( + "Optional, if set false and id_index is provided, nms will only apply" + " to boxes belongs to the same category"); + DMLC_DECLARE_FIELD(in_format) + .set_default(box_common_enum::kCorner) + .add_enum("corner", box_common_enum::kCorner) + .add_enum("center", box_common_enum::kCenter) + .describe( + "The input box encoding type. \n" + " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," + " \"center\" means boxes are encodes as [x, y, width, height]."); + DMLC_DECLARE_FIELD(out_format) + .set_default(box_common_enum::kCorner) + .add_enum("corner", box_common_enum::kCorner) + .add_enum("center", box_common_enum::kCenter) + .describe( + "The output box encoding type. 
\n" + " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," + " \"center\" means boxes are encodes as [x, y, width, height]."); } }; // BoxNMSParam inline bool BoxNMSShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { const BoxNMSParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 2U); - if (mxnet::op::shape_is_none(in_attrs->at(0)) - && mxnet::op::shape_is_none(out_attrs->at(0))) { + if (mxnet::op::shape_is_none(in_attrs->at(0)) && mxnet::op::shape_is_none(out_attrs->at(0))) { return false; } mxnet::TShape& ishape = (*in_attrs)[0]; - int indim = ishape.ndim(); - CHECK(indim >= 2) - << "input must have dim >= 2" - << " the last two dimensions are num_box and box_width " - << ishape << " provided"; + int indim = ishape.ndim(); + CHECK(indim >= 2) << "input must have dim >= 2" + << " the last two dimensions are num_box and box_width " << ishape + << " provided"; int width_elem = ishape[indim - 1]; - int expected = 5; + int expected = 5; if (param.id_index >= 0) { expected += 1; } - CHECK_GE(width_elem, expected) - << "the last dimension must have at least 5 elements" - << " namely (score, coordinates x 4) " - << width_elem << " provided, " << expected << " expected."; + CHECK_GE(width_elem, expected) << "the last dimension must have at least 5 elements" + << " namely (score, coordinates x 4) " << width_elem + << " provided, " << expected << " expected."; // check indices int coord_start = param.coord_start; - int coord_end = param.coord_start + 3; + int coord_end = param.coord_start + 3; int score_index = param.score_index; CHECK(score_index >= 0 && score_index < width_elem) - << "score_index: " << score_index << " out of range: (0, " - << width_elem << ")"; + << "score_index: " << score_index << " out of range: (0, " << width_elem << ")"; CHECK(score_index < coord_start || score_index > coord_end) - << "score_index: " << score_index << " conflict with coordinates: (" - << coord_start << ", " << coord_end << ")."; + << "score_index: " << score_index << " conflict with coordinates: (" << coord_start << ", " + << coord_end << ")."; CHECK(coord_start >= 0 && coord_end < width_elem) - << "coordinates: (" << coord_start << ", " << coord_end - << ") out of range:: (0, " << width_elem << ")"; + << "coordinates: (" << coord_start << ", " << coord_end << ") out of range:: (0, " + << width_elem << ")"; if (param.id_index >= 0) { int id_index = param.id_index; CHECK(id_index >= 0 && id_index < width_elem) - << "id_index: " << id_index << " out of range: (0, " - << width_elem << ")"; + << "id_index: " << id_index << " out of range: (0, " << width_elem << ")"; CHECK(id_index < coord_start || id_index > coord_end) - << "id_index: " << id_index << " conflict with coordinates: (" - << coord_start << ", " << coord_end << ")."; + << "id_index: " << id_index << " conflict with coordinates: (" << coord_start << ", " + << coord_end << ")."; CHECK_NE(id_index, score_index) - << "id_index: " << id_index << " conflict with score_index: " << score_index; + << "id_index: " << id_index << " conflict with score_index: " << score_index; } mxnet::TShape oshape = ishape; - oshape[indim - 1] = 1; + oshape[indim - 1] = 1; SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape); // out_shape[0] == in_shape SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape); // out_shape[1] return true; @@ -149,7 +155,7 @@ inline uint32_t BoxNMSNumVisibleOutputs(const NodeAttrs& attrs) { 
return static_cast(1); } -template +template int CopyIf(mshadow::Tensor out, mshadow::Tensor value, mshadow::Tensor flag) { @@ -164,49 +170,51 @@ int CopyIf(mshadow::Tensor out, } struct corner_to_center { - template - MSHADOW_XINLINE static void Map(int i, DType *data, int stride) { - int index = i * stride; + template + MSHADOW_XINLINE static void Map(int i, DType* data, int stride) { + int index = i * stride; DType left = data[index]; - if (left < 0) return; - DType top = data[index+1]; - DType right = data[index+2]; - DType bot = data[index+3]; - data[index] = (left + right) / 2; - data[index+1] = (top + bot) / 2; - data[index+2] = right - left; - data[index+3] = bot - top; + if (left < 0) + return; + DType top = data[index + 1]; + DType right = data[index + 2]; + DType bot = data[index + 3]; + data[index] = (left + right) / 2; + data[index + 1] = (top + bot) / 2; + data[index + 2] = right - left; + data[index + 3] = bot - top; } }; struct center_to_corner { - template - MSHADOW_XINLINE static void Map(int i, DType *data, int stride) { + template + MSHADOW_XINLINE static void Map(int i, DType* data, int stride) { int index = i * stride; - DType x = data[index]; - if (x < 0) return; - DType y = data[index+1]; - DType width = data[index+2] / 2; - DType height = data[index+3] / 2; - data[index] = x - width; - data[index+1] = y - height; - data[index+2] = x + width; - data[index+3] = y + height; + DType x = data[index]; + if (x < 0) + return; + DType y = data[index + 1]; + DType width = data[index + 2] / 2; + DType height = data[index + 3] / 2; + data[index] = x - width; + data[index + 1] = y - height; + data[index + 2] = x + width; + data[index + 3] = y + height; } }; -template -MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) { +template +MSHADOW_XINLINE DType BoxArea(const DType* box, int encode) { DType a1 = box[0]; DType a2 = box[1]; DType a3 = box[2]; DType a4 = box[3]; DType width, height; if (box_common_enum::kCorner == encode) { - width = a3 - a1; + width = a3 - a1; height = a4 - a2; } else { - width = a3; + width = a3; height = a4; } if (width < 0 || height < 0) { @@ -229,46 +237,68 @@ MSHADOW_XINLINE DType BoxArea(const DType *box, int encode) { * \param encode passed to BoxArea to compute area */ struct compute_area { - template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, - const int32_t *indices, const int32_t *batch_start, - int topk, int num_elem, int stride, int encode) { - int b = i / topk; - int k = i % topk; + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const DType* in, + const int32_t* indices, + const int32_t* batch_start, + int topk, + int num_elem, + int stride, + int encode) { + int b = i / topk; + int k = i % topk; int pos = static_cast(batch_start[b]) + k; - if (pos >= static_cast(batch_start[b + 1])) return; - int index = static_cast(indices[pos]); + if (pos >= static_cast(batch_start[b + 1])) + return; + int index = static_cast(indices[pos]); int in_index = index * stride; - out[index] = BoxArea(in + in_index, encode); + out[index] = BoxArea(in + in_index, encode); } }; -template -void NMSApply(mshadow::Stream *s, - int num_batch, int topk, +template +void NMSApply(mshadow::Stream* s, + int num_batch, + int topk, mshadow::Tensor* sorted_index, mshadow::Tensor* batch_start, mshadow::Tensor* buffer, mshadow::Tensor* areas, - int num_elem, int width_elem, - int coord_start, int id_index, - float threshold, bool force_suppress, + int num_elem, + int width_elem, + int coord_start, + int id_index, + float 
threshold, + bool force_suppress, int in_format) { using namespace mxnet_op; // go through each box as reference, suppress if overlap > threshold // sorted_index with -1 is marked as suppressed for (int ref = 0; ref < topk; ++ref) { int num_worker = topk - ref - 1; - if (num_worker < 1) continue; - Kernel::Launch(s, num_batch * num_worker, - sorted_index->dptr_, batch_start->dptr_, buffer->dptr_, areas->dptr_, - num_worker, ref, num_elem, - width_elem, coord_start, id_index, - threshold, force_suppress, in_format); + if (num_worker < 1) + continue; + Kernel::Launch(s, + num_batch * num_worker, + sorted_index->dptr_, + batch_start->dptr_, + buffer->dptr_, + areas->dptr_, + num_worker, + ref, + num_elem, + width_elem, + coord_start, + id_index, + threshold, + force_suppress, + in_format); } } -inline void NMSCalculateBatchStart(mshadow::Stream *s, +inline void NMSCalculateBatchStart(mshadow::Stream* s, mshadow::Tensor* batch_start, mshadow::Tensor* valid_batch_id, int num_batch) { @@ -282,31 +312,38 @@ inline void NMSCalculateBatchStart(mshadow::Stream *s, } /*! - * \brief Assign output of nms by indexing input - * - * \param i the launched thread index (total num_batch) - * \param out output array [cls, conf, b0, b1, b2, b3] - * \param record book keeping the selected index for backward - * \param index compact sorted_index, use batch_start to access - * \param batch_start map(b, k) to compact index by index[batch_start[b] + k] - * \param k nms topk number - * \param num number of input boxes in each batch - * \param stride input stride, usually 6 (id-score-x1-y2-x2-y2) - */ + * \brief Assign output of nms by indexing input + * + * \param i the launched thread index (total num_batch) + * \param out output array [cls, conf, b0, b1, b2, b3] + * \param record book keeping the selected index for backward + * \param index compact sorted_index, use batch_start to access + * \param batch_start map(b, k) to compact index by index[batch_start[b] + k] + * \param k nms topk number + * \param num number of input boxes in each batch + * \param stride input stride, usually 6 (id-score-x1-y2-x2-y2) + */ struct nms_assign { - template - MSHADOW_XINLINE static void Map(int i, DType *out, DType *record, const DType *input, - const int32_t *index, const int32_t *batch_start, - int k, int num, int stride) { + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + DType* record, + const DType* input, + const int32_t* index, + const int32_t* batch_start, + int k, + int num, + int stride) { int count = 0; for (int j = 0; j < k; ++j) { int pos = static_cast(batch_start[i]) + j; - if (pos >= static_cast(batch_start[i + 1])) return; + if (pos >= static_cast(batch_start[i + 1])) + return; int location = static_cast(index[pos]); if (location >= 0) { // copy to output int out_location = (i * num + count) * stride; - int in_location = location * stride; + int in_location = location * stride; for (int s = 0; s < stride; ++s) { out[out_location + s] = input[in_location + s]; } @@ -318,14 +355,18 @@ struct nms_assign { } }; - struct nms_backward { - template - MSHADOW_XINLINE static void Map(int i, DType *in_grad, const DType *out_grad, - const DType *record, int num, int stride) { + template + MSHADOW_XINLINE static void Map(int i, + DType* in_grad, + const DType* out_grad, + const DType* record, + int num, + int stride) { int index = static_cast(record[i]); - if (index < 0) return; - int loc = index * stride; + if (index < 0) + return; + int loc = index * stride; int from_loc = i * stride; for (int j = 0; j < 
stride; ++j) { in_grad[loc + j] = out_grad[from_loc + j]; @@ -333,36 +374,36 @@ struct nms_backward { } }; -template +template void BoxNMSForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]"; const BoxNMSParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_; - int indim = in_shape.ndim(); - int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2); - int num_elem = in_shape[indim - 2]; - int width_elem = in_shape[indim - 1]; - bool class_exist = param.id_index >= 0; + Stream* s = ctx.get_stream(); + mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_; + int indim = in_shape.ndim(); + int num_batch = indim <= 2 ? 1 : in_shape.ProdShape(0, indim - 2); + int num_elem = in_shape[indim - 2]; + int width_elem = in_shape[indim - 1]; + bool class_exist = param.id_index >= 0; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor data = inputs[box_nms_enum::kData] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); - Tensor out = outputs[box_nms_enum::kOut] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); - Tensor record = outputs[box_nms_enum::kTemp] - .get_with_shape(Shape3(num_batch, num_elem, 1), s); + Tensor data = inputs[box_nms_enum::kData].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); + Tensor out = outputs[box_nms_enum::kOut].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); + Tensor record = outputs[box_nms_enum::kTemp].get_with_shape( + Shape3(num_batch, num_elem, 1), s); // prepare workspace - Shape<1> sort_index_shape = Shape1(num_batch * num_elem); - Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem); + Shape<1> sort_index_shape = Shape1(num_batch * num_elem); + Shape<3> buffer_shape = Shape3(num_batch, num_elem, width_elem); Shape<1> batch_start_shape = Shape1(num_batch + 1); // index @@ -372,19 +413,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, dtype_size += buffer_shape.Size(); } // ceil up when sizeof(DType) is larger than sizeof(DType) - index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1; + index_t int32_offset = (int32_size * sizeof(int32_t) - 1) / sizeof(DType) + 1; index_t workspace_size = int32_offset + dtype_size; - Tensor workspace = ctx.requested[box_nms_enum::kTempSpace] - .get_space_typed(Shape1(workspace_size), s); + Tensor workspace = + ctx.requested[box_nms_enum::kTempSpace].get_space_typed( + Shape1(workspace_size), s); Tensor sorted_index( - reinterpret_cast(workspace.dptr_), sort_index_shape, s); - Tensor all_sorted_index(sorted_index.dptr_ + sorted_index.MSize(), - sort_index_shape, s); + reinterpret_cast(workspace.dptr_), sort_index_shape, s); + Tensor all_sorted_index( + sorted_index.dptr_ + sorted_index.MSize(), sort_index_shape, s); Tensor batch_id( - all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s); + all_sorted_index.dptr_ + all_sorted_index.MSize(), sort_index_shape, s); Tensor batch_start(batch_id.dptr_ + batch_id.MSize(), batch_start_shape, s); - Tensor scores(workspace.dptr_ + int32_offset, - sort_index_shape, s); + Tensor scores(workspace.dptr_ + 
int32_offset, sort_index_shape, s); Tensor areas(scores.dptr_ + scores.MSize(), sort_index_shape, s); Tensor classes(areas.dptr_ + areas.MSize(), sort_index_shape, s); Tensor buffer = data; @@ -397,19 +438,19 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, // indecies int score_index = param.score_index; int coord_start = param.coord_start; - int id_index = param.id_index; + int id_index = param.id_index; // sort topk - int topk = param.topk < 0? num_elem : std::min(num_elem, param.topk); + int topk = param.topk < 0 ? num_elem : std::min(num_elem, param.topk); if (topk < 1) { - out = F(buffer); + out = F(buffer); record = reshape(range(0, num_batch * num_elem), record.shape_); return; } // use classes, areas and scores as temporary storage Tensor all_scores = areas; - all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_); + all_scores = reshape(slice<2>(buffer, score_index, score_index + 1), all_scores.shape_); all_sorted_index = range(0, num_batch * num_elem); Tensor all_classes = classes; if (class_exist) { @@ -421,20 +462,20 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, Tensor valid_box = scores; if (class_exist) { valid_box = F( - F(all_scores, ScalarExp(param.valid_thresh)), - F(all_classes, ScalarExp(param.background_id))); + F(all_scores, ScalarExp(param.valid_thresh)), + F(all_classes, ScalarExp(param.background_id))); } else { valid_box = F(all_scores, ScalarExp(param.valid_thresh)); } - classes = F(valid_box); - valid_box = classes; + classes = F(valid_box); + valid_box = classes; int num_valid = mxnet::op::CopyIf(scores, all_scores, valid_box); mxnet::op::CopyIf(sorted_index, all_sorted_index, valid_box); // if everything is filtered, output -1 if (num_valid == 0) { record = -1; - out = -1; + out = -1; return; } // mark the invalid boxes before nms @@ -459,102 +500,119 @@ void BoxNMSForward(const nnvm::NodeAttrs& attrs, // pre-compute areas of candidates areas = 0; - Kernel::Launch(s, num_batch * topk, - areas.dptr_, buffer.dptr_ + coord_start, sorted_index.dptr_, batch_start.dptr_, - topk, num_elem, width_elem, param.in_format); + Kernel::Launch(s, + num_batch * topk, + areas.dptr_, + buffer.dptr_ + coord_start, + sorted_index.dptr_, + batch_start.dptr_, + topk, + num_elem, + width_elem, + param.in_format); // apply nms - mxnet::op::NMSApply(s, num_batch, topk, &sorted_index, - &batch_start, &buffer, &areas, - num_elem, width_elem, coord_start, - id_index, param.overlap_thresh, - param.force_suppress, param.in_format); + mxnet::op::NMSApply(s, + num_batch, + topk, + &sorted_index, + &batch_start, + &buffer, + &areas, + num_elem, + width_elem, + coord_start, + id_index, + param.overlap_thresh, + param.force_suppress, + param.in_format); // store the results to output, keep a record for backward record = -1; - out = -1; - Kernel::Launch(s, num_batch, - out.dptr_, record.dptr_, buffer.dptr_, sorted_index.dptr_, batch_start.dptr_, - topk, num_elem, width_elem); + out = -1; + Kernel::Launch(s, + num_batch, + out.dptr_, + record.dptr_, + buffer.dptr_, + sorted_index.dptr_, + batch_start.dptr_, + topk, + num_elem, + width_elem); // convert encoding if (param.in_format != param.out_format) { if (box_common_enum::kCenter == param.out_format) { - Kernel::Launch(s, num_batch * num_elem, - out.dptr_ + coord_start, width_elem); + Kernel::Launch( + s, num_batch * num_elem, out.dptr_ + coord_start, width_elem); } else { - Kernel::Launch(s, num_batch * num_elem, - out.dptr_ + coord_start, width_elem); + Kernel::Launch( + s, num_batch * num_elem, 
out.dptr_ + coord_start, width_elem); } } }); } -template +template void BoxNMSBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; CHECK_EQ(inputs.size(), 4U); CHECK_EQ(outputs.size(), 1U); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); mxnet::TShape in_shape = outputs[box_nms_enum::kData].shape_; - int indim = in_shape.ndim(); - int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2); - int num_elem = in_shape[indim - 2]; - int width_elem = in_shape[indim - 1]; + int indim = in_shape.ndim(); + int num_batch = indim <= 2 ? 1 : in_shape.ProdShape(0, indim - 2); + int num_elem = in_shape[indim - 2]; + int width_elem = in_shape[indim - 1]; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor out_grad = inputs[box_nms_enum::kOut] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); - Tensor in_grad = outputs[box_nms_enum::kData] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); - Tensor record = inputs[box_nms_enum::kTemp + 2] - .get_with_shape(Shape3(num_batch, num_elem, 1), s); + Tensor out_grad = inputs[box_nms_enum::kOut].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); + Tensor in_grad = outputs[box_nms_enum::kData].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); + Tensor record = inputs[box_nms_enum::kTemp + 2].get_with_shape( + Shape3(num_batch, num_elem, 1), s); in_grad = 0; - Kernel::Launch(s, num_batch * num_elem, in_grad.dptr_, - out_grad.dptr_, record.dptr_, num_elem, width_elem); + Kernel::Launch( + s, num_batch * num_elem, in_grad.dptr_, out_grad.dptr_, record.dptr_, num_elem, width_elem); }); } struct BoxOverlapParam : public dmlc::Parameter { int format; DMLC_DECLARE_PARAMETER(BoxOverlapParam) { - DMLC_DECLARE_FIELD(format).set_default(box_common_enum::kCorner) - .add_enum("corner", box_common_enum::kCorner) - .add_enum("center", box_common_enum::kCenter) - .describe("The box encoding type. \n" - " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," - " \"center\" means boxes are encodes as [x, y, width, height]."); + DMLC_DECLARE_FIELD(format) + .set_default(box_common_enum::kCorner) + .add_enum("corner", box_common_enum::kCorner) + .add_enum("center", box_common_enum::kCenter) + .describe( + "The box encoding type. 
\n" + " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," + " \"center\" means boxes are encodes as [x, y, width, height]."); } }; // BoxOverlapParam inline bool BoxOverlapShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& lshape = (*in_attrs)[0]; mxnet::TShape& rshape = (*in_attrs)[1]; - CHECK_GE(lshape.ndim(), 2) - << "lhs must have dim >= 2 " - << lshape.ndim() << " provided"; + CHECK_GE(lshape.ndim(), 2) << "lhs must have dim >= 2 " << lshape.ndim() << " provided"; int ldim = lshape[lshape.ndim() - 1]; - CHECK_EQ(ldim, 4) - << "last dimension of lhs must be 4 " - << ldim << " provided"; - CHECK_GE(rshape.ndim(), 2) - << "rhs must have dim >= 2 " - << rshape.ndim() << " provided"; + CHECK_EQ(ldim, 4) << "last dimension of lhs must be 4 " << ldim << " provided"; + CHECK_GE(rshape.ndim(), 2) << "rhs must have dim >= 2 " << rshape.ndim() << " provided"; int rdim = rshape[rshape.ndim() - 1]; - CHECK_EQ(rdim, 4) - << "last dimension of rhs must be 4 " - << rdim << " provided"; + CHECK_EQ(rdim, 4) << "last dimension of rhs must be 4 " << rdim << " provided"; // assign output shape mxnet::TShape oshape(lshape.ndim() + rshape.ndim() - 2, -1); @@ -570,14 +628,19 @@ inline bool BoxOverlapShape(const nnvm::NodeAttrs& attrs, } struct compute_overlap { - template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, - const DType *rhs, int num, - int begin, int stride, int encode) { - int l = i / num; - int r = i % num; - int l_index = l * stride + begin; - int r_index = r * stride + begin; + template + MSHADOW_XINLINE static void Map(int i, + DType* out, + const DType* lhs, + const DType* rhs, + int num, + int begin, + int stride, + int encode) { + int l = i / num; + int r = i % num; + int l_index = l * stride + begin; + int r_index = r * stride + begin; DType intersect = Intersect(lhs + l_index, rhs + r_index, encode); intersect *= Intersect(lhs + l_index + 1, rhs + r_index + 1, encode); if (intersect <= 0) { @@ -586,41 +649,38 @@ struct compute_overlap { } DType l_area = BoxArea(lhs + l_index, encode); DType r_area = BoxArea(rhs + r_index, encode); - out[i] = intersect / (l_area + r_area - intersect); + out[i] = intersect / (l_area + r_area - intersect); } }; -template +template void BoxOverlapForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); const BoxOverlapParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - mxnet::TShape lshape = inputs[0].shape_; - mxnet::TShape rshape = inputs[1].shape_; - int lsize = lshape.ProdShape(0, lshape.ndim() - 1); - int rsize = rshape.ProdShape(0, rshape.ndim() - 1); + Stream* s = ctx.get_stream(); + mxnet::TShape lshape = inputs[0].shape_; + mxnet::TShape rshape = inputs[1].shape_; + int lsize = lshape.ProdShape(0, lshape.ndim() - 1); + int rsize = rshape.ProdShape(0, rshape.ndim() - 1); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor lhs = inputs[0] - .get_with_shape(Shape1(lsize * 4), s); - Tensor rhs = inputs[1] - .get_with_shape(Shape1(rsize * 4), s); 
- Tensor out = outputs[0] - .get_with_shape(Shape1(lsize * rsize), s); - - Kernel::Launch(s, lsize * rsize, out.dptr_, - lhs.dptr_, rhs.dptr_, rsize, 0, 4, param.format); + Tensor lhs = inputs[0].get_with_shape(Shape1(lsize * 4), s); + Tensor rhs = inputs[1].get_with_shape(Shape1(rsize * 4), s); + Tensor out = outputs[0].get_with_shape(Shape1(lsize * rsize), s); + + Kernel::Launch( + s, lsize * rsize, out.dptr_, lhs.dptr_, rhs.dptr_, rsize, 0, 4, param.format); }); } -template +template void BoxOverlapBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -631,7 +691,7 @@ void BoxOverlapBackward(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { Tensor in_grad_lhs = outputs[0].FlatTo2D(s); Tensor in_grad_rhs = outputs[1].FlatTo2D(s); @@ -646,28 +706,26 @@ struct BipartiteMatchingParam : public dmlc::Parameter { float threshold; int topk; DMLC_DECLARE_PARAMETER(BipartiteMatchingParam) { - DMLC_DECLARE_FIELD(is_ascend).set_default(false) - .describe("Use ascend order for scores instead of descending. " - "Please set threshold accordingly."); - DMLC_DECLARE_FIELD(threshold) - .describe("Ignore matching when score < thresh, if is_ascend=false, " - "or ignore score > thresh, if is_ascend=true."); - DMLC_DECLARE_FIELD(topk).set_default(-1) - .describe("Limit the number of matches to topk, set -1 for no limit"); + DMLC_DECLARE_FIELD(is_ascend).set_default(false).describe( + "Use ascend order for scores instead of descending. " + "Please set threshold accordingly."); + DMLC_DECLARE_FIELD(threshold).describe( + "Ignore matching when score < thresh, if is_ascend=false, " + "or ignore score > thresh, if is_ascend=true."); + DMLC_DECLARE_FIELD(topk).set_default(-1).describe( + "Limit the number of matches to topk, set -1 for no limit"); } }; // BipartiteMatchingParam inline bool MatchingShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { // const BipartiteMatchingParam& param = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 2U); mxnet::TShape& dshape = (*in_attrs)[0]; - CHECK_GE(dshape.ndim(), 2) - << "score matrix must have dim >= 2 " - << dshape.ndim() << " provided"; + CHECK_GE(dshape.ndim(), 2) << "score matrix must have dim >= 2 " << dshape.ndim() << " provided"; // assign output shape mxnet::TShape oshape(dshape.ndim() - 1, -1); @@ -681,24 +739,30 @@ inline bool MatchingShape(const nnvm::NodeAttrs& attrs, } struct bipartite_matching { - template - MSHADOW_XINLINE static void Map(int i, DType *row_marker, DType *col_marker, - const DType *scores, const int32_t *sorted_index, - int num_batch, int num_row, int num_col, - float threshold, bool is_ascend, int topk) { - int stride = num_row * num_col; - const int32_t *index = sorted_index + i * stride; - const DType *score = scores + i * stride; - DType *rmarker = row_marker + i * num_row; - DType *cmarker = col_marker + i * num_col; - int count = 0; + template + MSHADOW_XINLINE static void Map(int i, + DType* row_marker, + DType* col_marker, + const DType* scores, + const int32_t* sorted_index, + int num_batch, + int num_row, + int num_col, + float threshold, + bool is_ascend, + int topk) { + int stride = num_row * num_col; + const int32_t* index = sorted_index + i * stride; + 
const DType* score = scores + i * stride; + DType* rmarker = row_marker + i * num_row; + DType* cmarker = col_marker + i * num_col; + int count = 0; for (int j = 0; j < stride; ++j) { int idx = static_cast(index[j]) % stride; - int r = idx / num_col; - int c = idx % num_col; + int r = idx / num_col; + int c = idx % num_col; if (rmarker[r] == -1 && cmarker[c] == -1) { - if ((!is_ascend && score[j] > threshold) || - (is_ascend && score[j] < threshold)) { + if ((!is_ascend && score[j] > threshold) || (is_ascend && score[j] < threshold)) { rmarker[r] = c; cmarker[c] = r; ++count; @@ -714,49 +778,48 @@ struct bipartite_matching { } }; -template +template void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U); const BipartiteMatchingParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - mxnet::TShape dshape = inputs[0].shape_; + Stream* s = ctx.get_stream(); + mxnet::TShape dshape = inputs[0].shape_; CHECK(mxnet::shape_is_known(dshape)) << "Found unknown shape in BipartiteMatchingForward, " << "received: " << dshape; if (dshape.Size() == 0) { return; // noop for unknown shape or empty array } - int row = dshape[dshape.ndim() - 2]; - int col = dshape[dshape.ndim() - 1]; + int row = dshape[dshape.ndim() - 2]; + int col = dshape[dshape.ndim() - 1]; int batch_size = dshape.Size() / row / col; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor scores = inputs[0] - .get_with_shape(Shape1(dshape.Size()), s); - Tensor row_marker = outputs[0] - .get_with_shape(Shape2(batch_size, row), s); - Tensor col_marker = outputs[1] - .get_with_shape(Shape2(batch_size, col), s); + Tensor scores = + inputs[0].get_with_shape(Shape1(dshape.Size()), s); + Tensor row_marker = + outputs[0].get_with_shape(Shape2(batch_size, row), s); + Tensor col_marker = + outputs[1].get_with_shape(Shape2(batch_size, col), s); Shape<1> sort_index_shape = Shape1(dshape.Size()); - index_t workspace_size = sort_index_shape.Size(); + index_t workspace_size = sort_index_shape.Size(); workspace_size += (sort_index_shape.Size() * 2 * sizeof(int32_t) - 1) / sizeof(DType) + 1; - Tensor workspace = ctx.requested[0] - .get_space_typed(Shape1(workspace_size), s); - Tensor scores_copy(workspace.dptr_, - sort_index_shape, s); - Tensor sorted_index(reinterpret_cast( - scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s); - Tensor batch_id(sorted_index.dptr_ + sorted_index.MSize(), - sort_index_shape, s); + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + Tensor scores_copy(workspace.dptr_, sort_index_shape, s); + Tensor sorted_index( + reinterpret_cast(scores_copy.dptr_ + scores_copy.MSize()), sort_index_shape, s); + Tensor batch_id( + sorted_index.dptr_ + sorted_index.MSize(), sort_index_shape, s); // sort according to score - scores_copy = F(scores); + scores_copy = F(scores); sorted_index = range(0, dshape.Size()); mxnet::op::SortByKey(scores_copy, sorted_index, param.is_ascend); batch_id = (sorted_index / ScalarExp(row * col)); @@ -767,24 +830,33 @@ void BipartiteMatchingForward(const nnvm::NodeAttrs& attrs, // bipartite matching, parallelization is limited to batch_size row_marker = -1; col_marker = -1; - 
Kernel::Launch(s, batch_size, row_marker.dptr_, - col_marker.dptr_, scores_copy.dptr_, sorted_index.dptr_, batch_size, row, col, - param.threshold, param.is_ascend, param.topk); + Kernel::Launch(s, + batch_size, + row_marker.dptr_, + col_marker.dptr_, + scores_copy.dptr_, + sorted_index.dptr_, + batch_size, + row, + col, + param.threshold, + param.is_ascend, + param.topk); }); } -template +template void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mshadow::expr; using namespace mxnet_op; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); mxnet::TShape dshape = outputs[0].shape_; CHECK(mxnet::shape_is_known(dshape)) << "Found unknown shape in BipartiteMatchingBackward, " << "received: " << dshape; @@ -798,10 +870,9 @@ void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs, }); } - inline bool BoxEncodeShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 6U); CHECK_EQ(out_attrs->size(), 2U); mxnet::TShape& sshape = (*in_attrs)[0]; @@ -809,29 +880,20 @@ inline bool BoxEncodeShape(const nnvm::NodeAttrs& attrs, mxnet::TShape& ashape = (*in_attrs)[2]; mxnet::TShape& rshape = (*in_attrs)[3]; - CHECK_EQ(sshape.ndim(), 2) - << "samples shape must have dim == 2, " - << sshape.ndim() << " provided"; + CHECK_EQ(sshape.ndim(), 2) << "samples shape must have dim == 2, " << sshape.ndim() + << " provided"; - CHECK_GE(mshape.ndim(), 2) - << "matches shape must have dim == 2, " - << mshape.ndim() << " provided"; + CHECK_GE(mshape.ndim(), 2) << "matches shape must have dim == 2, " << mshape.ndim() + << " provided"; - CHECK_GE(ashape.ndim(), 3) - << "matches shape must have dim == 3, " - << ashape.ndim() << " provided"; + CHECK_GE(ashape.ndim(), 3) << "matches shape must have dim == 3, " << ashape.ndim() + << " provided"; int ldim = ashape[ashape.ndim() - 1]; - CHECK_EQ(ldim, 4) - << "last dimension of anchors must be 4, " - << ldim << " provided"; + CHECK_EQ(ldim, 4) << "last dimension of anchors must be 4, " << ldim << " provided"; - CHECK_GE(rshape.ndim(), 3) - << "refs shape must have dim == 3, " - << ashape.ndim() << " provided"; + CHECK_GE(rshape.ndim(), 3) << "refs shape must have dim == 3, " << ashape.ndim() << " provided"; ldim = rshape[rshape.ndim() - 1]; - CHECK_EQ(ldim, 4) - << "last dimension of anchors must be 4, " - << ldim << " provided"; + CHECK_EQ(ldim, 4) << "last dimension of anchors must be 4, " << ldim << " provided"; // asign input shape SHAPE_ASSIGN_CHECK(*in_attrs, 4, mshadow::Shape1(4)); @@ -845,51 +907,61 @@ inline bool BoxEncodeShape(const nnvm::NodeAttrs& attrs, } struct box_encode { - template - MSHADOW_XINLINE static void Map(index_t i, DType *out_targets, DType *out_masks, - const DType *samples, const DType *matches, - const DType *anchors, const DType *refs, - const DType *means, const DType *stds, - const int m, const int n) { - index_t j = i / n; + template + MSHADOW_XINLINE static void Map(index_t i, + DType* out_targets, + DType* out_masks, + const DType* samples, + const DType* matches, + const DType* anchors, + const DType* refs, + const DType* means, + const DType* stds, + const int 
m, + const int n) { + index_t j = i / n; index_t match = matches[i]; // xmin: 0, ymin:1, xmax: 2, ymax: 3 // x:0, y:1, w:2, h:3 - index_t ref_index = (j * m + match) * 4; - DType ref_xmin = refs[ref_index + 0]; - DType ref_ymin = refs[ref_index + 1]; - DType ref_width = refs[ref_index + 2] - ref_xmin; - DType ref_height = refs[ref_index + 3] - ref_ymin; - DType ref_x = ref_xmin + ref_width * 0.5; - DType ref_y = ref_ymin + ref_height * 0.5; - index_t a_index = i * 4; - DType a_xmin = anchors[a_index + 0]; - DType a_ymin = anchors[a_index + 1]; - DType a_width = anchors[a_index + 2] - a_xmin; - DType a_height = anchors[a_index + 3] - a_ymin; - DType a_x = a_xmin + a_width * 0.5; - DType a_y = a_ymin + a_height * 0.5; - DType valid = samples[i] > 0.5 ? 1.0 : 0.0; + index_t ref_index = (j * m + match) * 4; + DType ref_xmin = refs[ref_index + 0]; + DType ref_ymin = refs[ref_index + 1]; + DType ref_width = refs[ref_index + 2] - ref_xmin; + DType ref_height = refs[ref_index + 3] - ref_ymin; + DType ref_x = ref_xmin + ref_width * 0.5; + DType ref_y = ref_ymin + ref_height * 0.5; + index_t a_index = i * 4; + DType a_xmin = anchors[a_index + 0]; + DType a_ymin = anchors[a_index + 1]; + DType a_width = anchors[a_index + 2] - a_xmin; + DType a_height = anchors[a_index + 3] - a_ymin; + DType a_x = a_xmin + a_width * 0.5; + DType a_y = a_ymin + a_height * 0.5; + DType valid = samples[i] > 0.5 ? 1.0 : 0.0; out_masks[a_index + 0] = valid; out_masks[a_index + 1] = valid; out_masks[a_index + 2] = valid; out_masks[a_index + 3] = valid; - out_targets[a_index + 0] = valid > static_cast(0.5) ? - ((ref_x - a_x) / a_width - static_cast(means[0])) / - static_cast(stds[0]) : static_cast(0.0); - out_targets[a_index + 1] = valid > static_cast(0.5) ? - ((ref_y - a_y) / a_height - static_cast(means[1])) / - static_cast(stds[1]) : static_cast(0.0); - out_targets[a_index + 2] = valid > static_cast(0.5) ? - (log(ref_width / a_width) - static_cast(means[2])) / - static_cast(stds[2]) : static_cast(0.0); - out_targets[a_index + 3] = valid > static_cast(0.5) ? - (log(ref_height / a_height) - static_cast(means[3])) / - static_cast(stds[3]) : static_cast(0.0); + out_targets[a_index + 0] = + valid > static_cast(0.5) + ? ((ref_x - a_x) / a_width - static_cast(means[0])) / static_cast(stds[0]) + : static_cast(0.0); + out_targets[a_index + 1] = valid > static_cast(0.5) + ? ((ref_y - a_y) / a_height - static_cast(means[1])) / + static_cast(stds[1]) + : static_cast(0.0); + out_targets[a_index + 2] = valid > static_cast(0.5) + ? (log(ref_width / a_width) - static_cast(means[2])) / + static_cast(stds[2]) + : static_cast(0.0); + out_targets[a_index + 3] = valid > static_cast(0.5) + ? 
(log(ref_height / a_height) - static_cast(means[3])) / + static_cast(stds[3]) + : static_cast(0.0); } }; -template +template void BoxEncodeForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -900,34 +972,36 @@ void BoxEncodeForward(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; CHECK_EQ(inputs.size(), 6U); CHECK_EQ(outputs.size(), 2U); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); // samples, matches, anchors, refs, means, stds mxnet::TShape anchor_shape = inputs[2].shape_; - int loop_size = anchor_shape.ProdShape(0, 2); - int b = anchor_shape[0]; - int n = anchor_shape[1]; - int m = inputs[3].shape_[1]; + int loop_size = anchor_shape.ProdShape(0, 2); + int b = anchor_shape[0]; + int n = anchor_shape[1]; + int m = inputs[3].shape_[1]; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor samples = inputs[0] - .get_with_shape(Shape2(b, n), s); - Tensor matches = inputs[1] - .get_with_shape(Shape2(b, n), s); - Tensor anchors = inputs[2] - .get_with_shape(Shape3(b, n, 4), s); - Tensor refs = inputs[3] - .get_with_shape(Shape3(b, m, 4), s); - Tensor means = inputs[4] - .get_with_shape(Shape1(4), s); - Tensor stds = inputs[5] - .get_with_shape(Shape1(4), s); - Tensor out_targets = outputs[0] - .get_with_shape(Shape3(b, n, 4), s); - Tensor out_masks = outputs[1] - .get_with_shape(Shape3(b, n, 4), s); - - Kernel::Launch(s, loop_size, out_targets.dptr_, - out_masks.dptr_, samples.dptr_, matches.dptr_, anchors.dptr_, - refs.dptr_, means.dptr_, stds.dptr_, m, n); + Tensor samples = inputs[0].get_with_shape(Shape2(b, n), s); + Tensor matches = inputs[1].get_with_shape(Shape2(b, n), s); + Tensor anchors = inputs[2].get_with_shape(Shape3(b, n, 4), s); + Tensor refs = inputs[3].get_with_shape(Shape3(b, m, 4), s); + Tensor means = inputs[4].get_with_shape(Shape1(4), s); + Tensor stds = inputs[5].get_with_shape(Shape1(4), s); + Tensor out_targets = + outputs[0].get_with_shape(Shape3(b, n, 4), s); + Tensor out_masks = outputs[1].get_with_shape(Shape3(b, n, 4), s); + + Kernel::Launch(s, + loop_size, + out_targets.dptr_, + out_masks.dptr_, + samples.dptr_, + matches.dptr_, + anchors.dptr_, + refs.dptr_, + means.dptr_, + stds.dptr_, + m, + n); }); } @@ -939,48 +1013,43 @@ struct BoxDecodeParam : public dmlc::Parameter { float clip; int format; DMLC_DECLARE_PARAMETER(BoxDecodeParam) { - DMLC_DECLARE_FIELD(std0).set_default(1.0) - .describe("value to be divided from the 1st encoded values"); - DMLC_DECLARE_FIELD(std1).set_default(1.0) - .describe("value to be divided from the 2nd encoded values"); - DMLC_DECLARE_FIELD(std2).set_default(1.0) - .describe("value to be divided from the 3rd encoded values"); - DMLC_DECLARE_FIELD(std3).set_default(1.0) - .describe("value to be divided from the 4th encoded values"); - DMLC_DECLARE_FIELD(clip).set_default(-1.0) - .describe("If larger than 0, bounding box target will be clipped to this value."); - DMLC_DECLARE_FIELD(format).set_default(box_common_enum::kCenter) - .add_enum("corner", box_common_enum::kCorner) - .add_enum("center", box_common_enum::kCenter) - .describe("The box encoding type. 
\n" - " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," - " \"center\" means boxes are encodes as [x, y, width, height]."); + DMLC_DECLARE_FIELD(std0).set_default(1.0).describe( + "value to be divided from the 1st encoded values"); + DMLC_DECLARE_FIELD(std1).set_default(1.0).describe( + "value to be divided from the 2nd encoded values"); + DMLC_DECLARE_FIELD(std2).set_default(1.0).describe( + "value to be divided from the 3rd encoded values"); + DMLC_DECLARE_FIELD(std3).set_default(1.0).describe( + "value to be divided from the 4th encoded values"); + DMLC_DECLARE_FIELD(clip).set_default(-1.0).describe( + "If larger than 0, bounding box target will be clipped to this value."); + DMLC_DECLARE_FIELD(format) + .set_default(box_common_enum::kCenter) + .add_enum("corner", box_common_enum::kCorner) + .add_enum("center", box_common_enum::kCenter) + .describe( + "The box encoding type. \n" + " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," + " \"center\" means boxes are encodes as [x, y, width, height]."); } }; // BoxDecodeParam inline bool BoxDecodeShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& dshape = (*in_attrs)[0]; mxnet::TShape& ashape = (*in_attrs)[1]; - CHECK_EQ(dshape.ndim(), 3) - << "data shape must have dim == 3, " - << dshape.ndim() << " provided"; + CHECK_EQ(dshape.ndim(), 3) << "data shape must have dim == 3, " << dshape.ndim() << " provided"; int ldim = dshape[dshape.ndim() - 1]; - CHECK_EQ(ldim, 4) - << "last dimension of data must be 4, " - << ldim << " provided"; + CHECK_EQ(ldim, 4) << "last dimension of data must be 4, " << ldim << " provided"; - CHECK_GE(ashape.ndim(), 3) - << "anchors shape must have dim == 3, " - << ashape.ndim() << " provided"; + CHECK_GE(ashape.ndim(), 3) << "anchors shape must have dim == 3, " << ashape.ndim() + << " provided"; ldim = ashape[ashape.ndim() - 1]; - CHECK_EQ(ldim, 4) - << "last dimension of anchors must be 4, " - << ldim << " provided"; + CHECK_EQ(ldim, 4) << "last dimension of anchors must be 4, " << ldim << " provided"; // assign output shape mxnet::TShape oshape = dshape; @@ -988,39 +1057,44 @@ inline bool BoxDecodeShape(const nnvm::NodeAttrs& attrs, return shape_is_known(oshape); } -template +template struct box_decode { - template - MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *x, - const DType *anchors, const DType std0, - const DType std1, const DType std2, - const DType std3, const DType clip, + template + MSHADOW_XINLINE static void Map(index_t i, + DType* out, + const DType* x, + const DType* anchors, + const DType std0, + const DType std1, + const DType std2, + const DType std3, + const DType clip, const int n) { - index_t index = i * 4; + index_t index = i * 4; index_t a_index = (i % n) * 4; - DType a_x = anchors[a_index + 0]; - DType a_y = anchors[a_index + 1]; - DType a_width = anchors[a_index + 2]; - DType a_height = anchors[a_index + 3]; + DType a_x = anchors[a_index + 0]; + DType a_y = anchors[a_index + 1]; + DType a_width = anchors[a_index + 2]; + DType a_height = anchors[a_index + 3]; if (box_common_enum::kCorner == anchor_encode) { // a_x = xmin, a_y = ymin, a_width = xmax, a_height = ymax - a_width = a_width - a_x; + a_width = a_width - a_x; a_height = a_height - a_y; - a_x = a_x + a_width * 0.5; - a_y = a_y + a_height * 0.5; + a_x = a_x + a_width * 0.5; + a_y = a_y 
+ a_height * 0.5; } DType ox = x[index + 0] * std0 * a_width + a_x; DType oy = x[index + 1] * std1 * a_height + a_y; DType dw = x[index + 2] * std2; DType dh = x[index + 3] * std3; if (has_clip) { - dw = dw < clip ? dw : clip; - dh = dh < clip ? dh : clip; + dw = dw < clip ? dw : clip; + dh = dh < clip ? dh : clip; } - dw = exp(dw); - dh = exp(dh); - DType ow = dw * a_width * 0.5; - DType oh = dh * a_height * 0.5; + dw = exp(dw); + dh = exp(dh); + DType ow = dw * a_width * 0.5; + DType oh = dh * a_height * 0.5; out[index + 0] = ox - ow; out[index + 1] = oy - oh; out[index + 2] = ox + ow; @@ -1028,7 +1102,7 @@ struct box_decode { } }; -template +template void BoxDecodeForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -1039,39 +1113,68 @@ void BoxDecodeForward(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); - Stream *s = ctx.get_stream(); - mxnet::TShape x_shape = inputs[0].shape_; - int b = x_shape[0]; - int n = x_shape[1]; - int loop_size = b * n; + Stream* s = ctx.get_stream(); + mxnet::TShape x_shape = inputs[0].shape_; + int b = x_shape[0]; + int n = x_shape[1]; + int loop_size = b * n; const BoxDecodeParam& param = nnvm::get(attrs.parsed); MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor data = inputs[0] - .get_with_shape(Shape3(b, n, 4), s); - Tensor anchors = inputs[1] - .get_with_shape(Shape3(1, n, 4), s); - Tensor out = outputs[0] - .get_with_shape(Shape3(b, n, 4), s); + Tensor data = inputs[0].get_with_shape(Shape3(b, n, 4), s); + Tensor anchors = inputs[1].get_with_shape(Shape3(1, n, 4), s); + Tensor out = outputs[0].get_with_shape(Shape3(b, n, 4), s); if (box_common_enum::kCorner == param.format && param.clip > 0.0) { - Kernel, xpu>::Launch(s, loop_size, - out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), - static_cast(param.std1), static_cast(param.std2), - static_cast(param.std3), static_cast(param.clip), n); + Kernel, xpu>::Launch( + s, + loop_size, + out.dptr_, + data.dptr_, + anchors.dptr_, + static_cast(param.std0), + static_cast(param.std1), + static_cast(param.std2), + static_cast(param.std3), + static_cast(param.clip), + n); } else if (box_common_enum::kCenter == param.format && param.clip > 0.0) { - Kernel, xpu>::Launch(s, loop_size, - out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), - static_cast(param.std1), static_cast(param.std2), - static_cast(param.std3), static_cast(param.clip), n); + Kernel, xpu>::Launch( + s, + loop_size, + out.dptr_, + data.dptr_, + anchors.dptr_, + static_cast(param.std0), + static_cast(param.std1), + static_cast(param.std2), + static_cast(param.std3), + static_cast(param.clip), + n); } else if (box_common_enum::kCorner == param.format && param.clip <= 0.0) { - Kernel, xpu>::Launch(s, loop_size, - out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), - static_cast(param.std1), static_cast(param.std2), - static_cast(param.std3), static_cast(param.clip), n); + Kernel, xpu>::Launch( + s, + loop_size, + out.dptr_, + data.dptr_, + anchors.dptr_, + static_cast(param.std0), + static_cast(param.std1), + static_cast(param.std2), + static_cast(param.std3), + static_cast(param.clip), + n); } else { - Kernel, xpu>::Launch(s, loop_size, - out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), - static_cast(param.std1), static_cast(param.std2), - static_cast(param.std3), static_cast(param.clip), n); + Kernel, xpu>::Launch( + s, + loop_size, + out.dptr_, + data.dptr_, + anchors.dptr_, + 
static_cast(param.std0), + static_cast(param.std1), + static_cast(param.std2), + static_cast(param.std3), + static_cast(param.clip), + n); } }); } diff --git a/src/operator/contrib/bounding_box.cc b/src/operator/contrib/bounding_box.cc index d4a599275e88..0d6a29b98c43 100644 --- a/src/operator/contrib/bounding_box.cc +++ b/src/operator/contrib/bounding_box.cc @@ -17,12 +17,12 @@ * under the License. */ - /*! - * Copyright (c) 2017 by Contributors - * \file bounding_box.cc - * \brief Bounding box util functions and operators - * \author Joshua Zhang - */ +/*! + * Copyright (c) 2017 by Contributors + * \file bounding_box.cc + * \brief Bounding box util functions and operators + * \author Joshua Zhang + */ #include "./bounding_box-inl.h" #include "../elemwise_op_common.h" @@ -34,11 +34,10 @@ DMLC_REGISTER_PARAMETER(BoxOverlapParam); DMLC_REGISTER_PARAMETER(BipartiteMatchingParam); DMLC_REGISTER_PARAMETER(BoxDecodeParam); - NNVM_REGISTER_OP(_contrib_box_nms) -.add_alias("_contrib_box_non_maximum_suppression") -.add_alias("_npx_box_nms") -.describe(R"code(Apply non-maximum suppression to input. + .add_alias("_contrib_box_non_maximum_suppression") + .add_alias("_npx_box_nms") + .describe(R"code(Apply non-maximum suppression to input. The output will be sorted in descending order according to ``score``. Boxes with overlaps larger than ``overlap_thresh``, smaller scores and background boxes @@ -94,33 +93,33 @@ Examples:: [0, 0, 0, 0, 0, 0], [0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] )code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FNumVisibleOutputs", BoxNMSNumVisibleOutputs) -.set_attr("FInferShape", BoxNMSShape) -.set_attr("FInferType", ElemwiseType<1, 2>) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("THasDeterministicOutput", true) -.set_attr("FCompute", BoxNMSForward) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_contrib_box_nms"}) -.add_argument("data", "NDArray-or-Symbol", "The input") -.add_arguments(BoxNMSParam::__FIELDS__()); + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr_parser(ParamParser) + .set_attr("FNumVisibleOutputs", BoxNMSNumVisibleOutputs) + .set_attr("FInferShape", BoxNMSShape) + .set_attr("FInferType", ElemwiseType<1, 2>) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("THasDeterministicOutput", true) + .set_attr("FCompute", BoxNMSForward) + .set_attr("FGradient", ElemwiseGradUseOut{"_backward_contrib_box_nms"}) + .add_argument("data", "NDArray-or-Symbol", "The input") + .add_arguments(BoxNMSParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_box_nms) -.set_num_inputs(4) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BoxNMSBackward) -.add_arguments(BoxNMSParam::__FIELDS__()); + .set_num_inputs(4) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("TIsBackward", true) + .set_attr("FCompute", BoxNMSBackward) + .add_arguments(BoxNMSParam::__FIELDS__()); NNVM_REGISTER_OP(_contrib_box_iou) -.add_alias("_npx_box_iou") -.describe(R"doc(Bounding box overlap of two arrays. + .add_alias("_npx_box_iou") + .describe(R"doc(Bounding box overlap of two arrays. The overlap is defined as Intersection-over-Union, aka, IOU. 
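Reviewer note on the _contrib_box_nms description above: the suppression rule it documents is the usual greedy NMS over score-sorted boxes. A compact sketch under that assumption; greedy_nms and the iou callback are illustrative names only.

// Editorial sketch (not part of the change): the greedy suppression rule that
// the _contrib_box_nms description documents; greedy_nms and the iou callback
// are illustrative names.
#include <functional>
#include <vector>

std::vector<int> greedy_nms(const std::vector<int>& order_by_score_desc,
                            const std::function<float(int, int)>& iou,
                            float overlap_thresh) {
  std::vector<int> kept;
  for (int cand : order_by_score_desc) {
    bool suppressed = false;
    for (int k : kept) {
      if (iou(cand, k) > overlap_thresh) {  // overlaps a higher-scored kept box
        suppressed = true;
        break;
      }
    }
    if (!suppressed)
      kept.push_back(cand);
  }
  return kept;
}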
- lhs: (a_1, a_2, ..., a_n, 4) array - rhs: (b_1, b_2, ..., b_n, 4) array @@ -137,32 +136,32 @@ NNVM_REGISTER_OP(_contrib_box_iou) box_iou(x, y, format='corner') = [[0.1428], [0.1428]] )doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", BoxOverlapShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FCompute", BoxOverlapForward) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_contrib_box_iou"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(BoxOverlapParam::__FIELDS__()); + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) + .set_attr("FInferShape", BoxOverlapShape) + .set_attr("FInferType", ElemwiseType<2, 1>) + .set_attr("FCompute", BoxOverlapForward) + .set_attr("FGradient", ElemwiseGradUseNone{"_backward_contrib_box_iou"}) + .add_argument("lhs", "NDArray-or-Symbol", "The first input") + .add_argument("rhs", "NDArray-or-Symbol", "The second input") + .add_arguments(BoxOverlapParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_box_iou) -.set_num_inputs(1) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BoxOverlapBackward) -.add_arguments(BoxOverlapParam::__FIELDS__()); + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr_parser(ParamParser) + .set_attr("TIsBackward", true) + .set_attr("FCompute", BoxOverlapBackward) + .add_arguments(BoxOverlapParam::__FIELDS__()); NNVM_REGISTER_OP(_contrib_bipartite_matching) -.add_alias("_npx_bipartite_matching") -.describe(R"doc(Compute bipartite matching. + .add_alias("_npx_bipartite_matching") + .describe(R"doc(Compute bipartite matching. 
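Reviewer note: the matching documented here is a greedy assignment over score-sorted entries, as in the bipartite_matching kernel earlier in this diff. A plain C++ sketch, assuming a flattened row-major N x M score matrix; greedy_match is an illustrative name and not part of this change.

// Editorial sketch (not part of the change): greedy bipartite matching over a
// score-sorted index list; the score matrix is assumed flattened row-major
// with shape (num_row, num_col).
#include <vector>

void greedy_match(const std::vector<int>& sorted_idx,  // indices into the flattened matrix
                  const std::vector<float>& scores,    // flattened (num_row, num_col) scores
                  int num_col,
                  float threshold,
                  bool is_ascend,
                  int topk,
                  std::vector<float>* row_marker,  // size num_row, -1 == unmatched
                  std::vector<float>* col_marker) {  // size num_col, -1 == unmatched
  int count = 0;
  for (int idx : sorted_idx) {
    const int r = idx / num_col;
    const int c = idx % num_col;
    if ((*row_marker)[r] != -1 || (*col_marker)[c] != -1)
      continue;  // row or column already taken
    const bool pass = is_ascend ? scores[idx] < threshold : scores[idx] > threshold;
    if (!pass)
      continue;
    (*row_marker)[r] = static_cast<float>(c);
    (*col_marker)[c] = static_cast<float>(r);
    if (topk > 0 && ++count >= topk)
      break;  // respect the topk limit; -1 means no limit
  }
}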
The matching is performed on score matrix with shape [B, N, M] - B: batch_size - N: number of rows to match @@ -184,73 +183,78 @@ NNVM_REGISTER_OP(_contrib_bipartite_matching) y = [2, 0] )doc" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("THasDeterministicOutput", true) -.set_attr("FInferShape", MatchingShape) -.set_attr("FInferType", ElemwiseType<1, 2>) -.set_attr("FCompute", BipartiteMatchingForward) -.set_attr("FGradient", - ElemwiseGradUseNone{"_backward_contrib_bipartite_matching"}) -.add_argument("data", "NDArray-or-Symbol", "The input") -.add_arguments(BipartiteMatchingParam::__FIELDS__()); + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr_parser(ParamParser) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("THasDeterministicOutput", true) + .set_attr("FInferShape", MatchingShape) + .set_attr("FInferType", ElemwiseType<1, 2>) + .set_attr("FCompute", BipartiteMatchingForward) + .set_attr("FGradient", + ElemwiseGradUseNone{"_backward_contrib_bipartite_matching"}) + .add_argument("data", "NDArray-or-Symbol", "The input") + .add_arguments(BipartiteMatchingParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_contrib_bipartite_matching) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BipartiteMatchingBackward) -.add_arguments(BipartiteMatchingParam::__FIELDS__()); + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("TIsBackward", true) + .set_attr("FCompute", BipartiteMatchingBackward) + .add_arguments(BipartiteMatchingParam::__FIELDS__()); NNVM_REGISTER_OP(_contrib_box_encode) -.add_alias("_npx_box_encode") -.describe(R"doc(Encode bounding boxes training target with normalized center offsets. + .add_alias("_npx_box_encode") + .describe(R"doc(Encode bounding boxes training target with normalized center offsets. Input bounding boxes are using corner type: `x_{min}, y_{min}, x_{max}, y_{max}`.) 
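Reviewer note on the box_encode description above: the normalized center-offset targets follow the transform in the box_encode kernel earlier in this diff. A single-anchor sketch; encode_one is an illustrative helper and boxes are corner encoded [xmin, ymin, xmax, ymax].

// Editorial sketch (not part of the change): the normalized center-offset
// target that box_encode produces for one matched (anchor, reference) pair.
#include <cmath>

void encode_one(const float anchor[4], const float ref[4],
                const float means[4], const float stds[4],
                float target[4]) {
  const float aw = anchor[2] - anchor[0], ah = anchor[3] - anchor[1];
  const float ax = anchor[0] + 0.5f * aw, ay = anchor[1] + 0.5f * ah;
  const float rw = ref[2] - ref[0], rh = ref[3] - ref[1];
  const float rx = ref[0] + 0.5f * rw, ry = ref[1] + 0.5f * rh;
  target[0] = ((rx - ax) / aw - means[0]) / stds[0];      // x offset
  target[1] = ((ry - ay) / ah - means[1]) / stds[1];      // y offset
  target[2] = (std::log(rw / aw) - means[2]) / stds[2];   // log width ratio
  target[3] = (std::log(rh / ah) - means[3]) / stds[3];   // log height ratio
}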
array )doc" ADD_FILELINE) -.set_num_inputs(6) -.set_num_outputs(2) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"samples", "matches", "anchors", "refs", "means", "stds"}; - }) -.set_attr("FInferShape", BoxEncodeShape) -.set_attr("FInferType", ElemwiseType<6, 2>) -.set_attr("FCompute", BoxEncodeForward) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("samples", "NDArray-or-Symbol", "(B, N) value +1 (positive), -1 (negative), " - "0 (ignore)") -.add_argument("matches", "NDArray-or-Symbol", "(B, N) value range [0, M)") -.add_argument("anchors", "NDArray-or-Symbol", "(B, N, 4) encoded in corner") -.add_argument("refs", "NDArray-or-Symbol", "(B, M, 4) encoded in corner") -.add_argument("means", "NDArray-or-Symbol", "(4,) Mean value to be subtracted from encoded values") -.add_argument("stds", "NDArray-or-Symbol", "(4,) Std value to be divided from encoded values"); + .set_num_inputs(6) + .set_num_outputs(2) + .set_attr( + "FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"samples", "matches", "anchors", "refs", "means", "stds"}; + }) + .set_attr("FInferShape", BoxEncodeShape) + .set_attr("FInferType", ElemwiseType<6, 2>) + .set_attr("FCompute", BoxEncodeForward) + .set_attr("FGradient", MakeZeroGradNodes) + .add_argument("samples", + "NDArray-or-Symbol", + "(B, N) value +1 (positive), -1 (negative), " + "0 (ignore)") + .add_argument("matches", "NDArray-or-Symbol", "(B, N) value range [0, M)") + .add_argument("anchors", "NDArray-or-Symbol", "(B, N, 4) encoded in corner") + .add_argument("refs", "NDArray-or-Symbol", "(B, M, 4) encoded in corner") + .add_argument("means", + "NDArray-or-Symbol", + "(4,) Mean value to be subtracted from encoded values") + .add_argument("stds", "NDArray-or-Symbol", "(4,) Std value to be divided from encoded values"); NNVM_REGISTER_OP(_contrib_box_decode) -.add_alias("_npx_box_decode") -.describe(R"doc(Decode bounding boxes training target with normalized center offsets. + .add_alias("_npx_box_decode") + .describe(R"doc(Decode bounding boxes training target with normalized center offsets. Input bounding boxes are using corner type: ``x_{min}, y_{min}, x_{max}, y_{max}`` or center type: ``x, y, width, height``.) 
array )doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "anchors"}; - }) -.set_attr("FInferShape", BoxDecodeShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FCompute", BoxDecodeForward) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("data", "NDArray-or-Symbol", "(B, N, 4) predicted bbox offset") -.add_argument("anchors", "NDArray-or-Symbol", "(1, N, 4) encoded in corner or center") -.add_arguments(BoxDecodeParam::__FIELDS__()); + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr_parser(ParamParser) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "anchors"}; + }) + .set_attr("FInferShape", BoxDecodeShape) + .set_attr("FInferType", ElemwiseType<2, 1>) + .set_attr("FCompute", BoxDecodeForward) + .set_attr("FGradient", MakeZeroGradNodes) + .add_argument("data", "NDArray-or-Symbol", "(B, N, 4) predicted bbox offset") + .add_argument("anchors", "NDArray-or-Symbol", "(1, N, 4) encoded in corner or center") + .add_arguments(BoxDecodeParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu index 7915273f64e8..f87c2dccc4da 100644 --- a/src/operator/contrib/bounding_box.cu +++ b/src/operator/contrib/bounding_box.cu @@ -17,12 +17,12 @@ * under the License. */ - /*! - * Copyright (c) 2017 by Contributors - * \file bounding_box.cu - * \brief Bounding box util functions and operators - * \author Joshua Zhang - */ +/*! + * Copyright (c) 2017 by Contributors + * \file bounding_box.cu + * \brief Bounding box util functions and operators + * \author Joshua Zhang + */ #include @@ -35,8 +35,8 @@ namespace op { namespace { -using mshadow::Tensor; using mshadow::Stream; +using mshadow::Tensor; template struct TempWorkspace { @@ -57,26 +57,29 @@ inline size_t ceil_div(size_t x, size_t y) { } inline size_t align(size_t x, size_t alignment) { - return ceil_div(x, alignment) * alignment; + return ceil_div(x, alignment) * alignment; } template -__global__ void FilterAndPrepareAuxDataKernel(const DType* data, DType* out, DType* scores, - index_t num_elements_per_batch, - const index_t element_width, - const index_t N, - const float threshold, - const int id_index, const int score_index, - const int background_id) { - index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - bool first_in_element = (tid % element_width == 0); +__global__ void FilterAndPrepareAuxDataKernel(const DType* data, + DType* out, + DType* scores, + index_t num_elements_per_batch, + const index_t element_width, + const index_t N, + const float threshold, + const int id_index, + const int score_index, + const int background_id) { + index_t tid = blockIdx.x * blockDim.x + threadIdx.x; + bool first_in_element = (tid % element_width == 0); index_t start_of_my_element = tid - (tid % element_width); if (tid < N) { - DType my_score = data[start_of_my_element + score_index]; + DType my_score = data[start_of_my_element + score_index]; bool filtered_out = my_score <= threshold; if (id_index != -1 && background_id != -1) { - DType my_id = data[start_of_my_element + id_index]; + DType my_id = data[start_of_my_element + id_index]; filtered_out = filtered_out || (my_id == background_id); } if (!filtered_out) { @@ -100,35 +103,40 @@ void FilterAndPrepareAuxData(const Tensor& data, const BoxNMSParam& param, Stream* s) { const int n_threads = 512; - index_t N = 
data.shape_.Size(); - const auto blocks = ceil_div(N, n_threads); - FilterAndPrepareAuxDataKernel<<::GetStream(s)>>>( - data.dptr_, out->dptr_, workspace.scores, - data.shape_[1], data.shape_[2], N, - param.valid_thresh, param.id_index, - param.score_index, param.background_id); + index_t N = data.shape_.Size(); + const auto blocks = ceil_div(N, n_threads); + FilterAndPrepareAuxDataKernel<<::GetStream(s)>>>( + data.dptr_, + out->dptr_, + workspace.scores, + data.shape_[1], + data.shape_[2], + N, + param.valid_thresh, + param.id_index, + param.score_index, + param.background_id); } template -__global__ void CompactDataKernel(const index_t* indices, const DType* source, - DType* destination, const index_t topk, - const index_t element_width, - const index_t num_elements_per_batch, - const int score_index, - const index_t N) { +__global__ void CompactDataKernel(const index_t* indices, + const DType* source, + DType* destination, + const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int score_index, + const index_t N) { const index_t tid_start = blockIdx.x * blockDim.x + threadIdx.x; for (index_t tid = tid_start; tid < N; tid += blockDim.x * gridDim.x) { - const index_t my_element = tid / element_width; + const index_t my_element = tid / element_width; const index_t my_element_in_batch = my_element % num_elements_per_batch; if (check_topk && my_element_in_batch >= topk) { destination[tid] = -1; } else { DType ret; const index_t source_element = indices[my_element]; - DType score = 0; + DType score = 0; if (check_score) { score = source[source_element * element_width + score_index]; } @@ -149,24 +157,30 @@ void CompactData(const Tensor& indices, const index_t topk, const int score_index, Stream* s) { - const int n_threads = 512; + const int n_threads = 512; const size_t max_blocks = 320; - index_t N = source.shape_.Size(); - const auto blocks = std::min(ceil_div(N, n_threads), max_blocks); + index_t N = source.shape_.Size(); + const auto blocks = std::min(ceil_div(N, n_threads), max_blocks); if (topk > 0) { - CompactDataKernel<<::GetStream(s)>>>( - indices.dptr_, source.dptr_, - destination->dptr_, topk, - source.shape_[2], source.shape_[1], - score_index, N); + CompactDataKernel + <<::GetStream(s)>>>(indices.dptr_, + source.dptr_, + destination->dptr_, + topk, + source.shape_[2], + source.shape_[1], + score_index, + N); } else { - CompactDataKernel<<::GetStream(s)>>>( - indices.dptr_, source.dptr_, - destination->dptr_, topk, - source.shape_[2], source.shape_[1], - score_index, N); + CompactDataKernel + <<::GetStream(s)>>>(indices.dptr_, + source.dptr_, + destination->dptr_, + topk, + source.shape_[2], + source.shape_[1], + score_index, + N); } } @@ -176,48 +190,49 @@ void WorkspaceForSort(const index_t num_elem, const int alignment, TempWorkspace* workspace) { const size_t sort_scores_temp_space = - mxnet::op::SortByKeyWorkspaceSize(num_elem, false, false); + mxnet::op::SortByKeyWorkspaceSize(num_elem, false, false); const size_t sort_topk_scores_temp_space = - mxnet::op::SortByKeyWorkspaceSize(topk, false, false); - workspace->scratch_space = align(std::max(sort_scores_temp_space, sort_topk_scores_temp_space), - alignment); + mxnet::op::SortByKeyWorkspaceSize(topk, false, false); + workspace->scratch_space = + align(std::max(sort_scores_temp_space, sort_topk_scores_temp_space), alignment); } template -__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result, - const index_t current_start, - const index_t num_elems, - 
const index_t num_batches, - const index_t num_blocks_per_row_batch, - const index_t num_blocks_per_row, - const index_t topk, - const index_t element_width, - const index_t num_elements_per_batch, - const int coord_index, - const int class_index, - const int score_index, - const float threshold); +__global__ void CalculateGreedyNMSResultsKernel(const DType* data, + uint32_t* result, + const index_t current_start, + const index_t num_elems, + const index_t num_batches, + const index_t num_blocks_per_row_batch, + const index_t num_blocks_per_row, + const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int coord_index, + const int class_index, + const int score_index, + const float threshold); template __global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results, - DType * data, - const index_t score_index, - const index_t element_width, - const index_t num_batches, - const index_t num_elems, - const index_t start_index, - const index_t topk); + DType* data, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elems, + const index_t start_index, + const index_t topk); template __global__ void ReduceNMSResultRestKernel(DType* data, - const uint32_t* nms_results, - const index_t score_index, - const index_t element_width, - const index_t num_batches, - const index_t num_elements_per_batch, - const index_t start_index, - const index_t topk, - const index_t num_blocks_per_batch); + const uint32_t* nms_results, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk, + const index_t num_blocks_per_batch); template struct NMS { @@ -228,43 +243,72 @@ struct NMS { const index_t topk, const BoxNMSParam& param, Stream* s) { - const int n_threads = 512; - const index_t num_batches = data->shape_[0]; + const int n_threads = 512; + const index_t num_batches = data->shape_[0]; const index_t num_elements_per_batch = data->shape_[1]; - const index_t element_width = data->shape_[2]; + const index_t element_width = data->shape_[2]; for (index_t current_start = 0; current_start < topk; current_start += THRESHOLD) { - const index_t n_elems = topk - current_start; + const index_t n_elems = topk - current_start; const index_t num_blocks_per_row_batch = ceil_div(n_elems, n_threads); - const index_t num_blocks_per_row = num_blocks_per_row_batch * num_batches; + const index_t num_blocks_per_row = num_blocks_per_row_batch * num_batches; const index_t n_blocks = THRESHOLD / (sizeof(uint32_t) * 8) * num_blocks_per_row; if (param.in_format == box_common_enum::kCorner) { CalculateGreedyNMSResultsKernel - <<::GetStream(s)>>>( - data->dptr_, scratch->dptr_, current_start, n_elems, num_batches, - num_blocks_per_row_batch, num_blocks_per_row, topk, element_width, - num_elements_per_batch, param.coord_start, - param.force_suppress ? -1 : param.id_index, - param.score_index, param.overlap_thresh); + <<::GetStream(s)>>>( + data->dptr_, + scratch->dptr_, + current_start, + n_elems, + num_batches, + num_blocks_per_row_batch, + num_blocks_per_row, + topk, + element_width, + num_elements_per_batch, + param.coord_start, + param.force_suppress ? 
-1 : param.id_index, + param.score_index, + param.overlap_thresh); } else { CalculateGreedyNMSResultsKernel - <<::GetStream(s)>>>( - data->dptr_, scratch->dptr_, current_start, n_elems, num_batches, - num_blocks_per_row_batch, num_blocks_per_row, topk, element_width, - num_elements_per_batch, param.coord_start, - param.force_suppress ? -1 : param.id_index, - param.score_index, param.overlap_thresh); + <<::GetStream(s)>>>( + data->dptr_, + scratch->dptr_, + current_start, + n_elems, + num_batches, + num_blocks_per_row_batch, + num_blocks_per_row, + topk, + element_width, + num_elements_per_batch, + param.coord_start, + param.force_suppress ? -1 : param.id_index, + param.score_index, + param.overlap_thresh); } ReduceNMSResultTriangleKernel<<::GetStream(s)>>>( - scratch->dptr_, data->dptr_, param.score_index, - element_width, num_batches, num_elements_per_batch, - current_start, topk); - const index_t n_rest_elems = n_elems - THRESHOLD; + scratch->dptr_, + data->dptr_, + param.score_index, + element_width, + num_batches, + num_elements_per_batch, + current_start, + topk); + const index_t n_rest_elems = n_elems - THRESHOLD; const index_t num_rest_blocks_per_batch = ceil_div(n_rest_elems, n_threads); - const index_t num_rest_blocks = num_rest_blocks_per_batch * num_batches; + const index_t num_rest_blocks = num_rest_blocks_per_batch * num_batches; if (n_rest_elems > 0) { ReduceNMSResultRestKernel<<::GetStream(s)>>>( - data->dptr_, scratch->dptr_, param.score_index, element_width, - num_batches, num_elements_per_batch, current_start, topk, + data->dptr_, + scratch->dptr_, + param.score_index, + element_width, + num_batches, + num_elements_per_batch, + current_start, + topk, num_rest_blocks_per_batch); } } @@ -272,47 +316,52 @@ struct NMS { }; template -__device__ __forceinline__ DType calculate_area(const DType b0, const DType b1, - const DType b2, const DType b3) { - DType width = b2; +__device__ __forceinline__ DType +calculate_area(const DType b0, const DType b1, const DType b2, const DType b3) { + DType width = b2; DType height = b3; if (encode == box_common_enum::kCorner) { width -= b0; height -= b1; } - if (width < 0 || height < 0) return 0; + if (width < 0 || height < 0) + return 0; return width * height; } template -__device__ __forceinline__ DType calculate_intersection(const DType a0, const DType a1, - const DType a2, const DType a3, - const DType b0, const DType b1, - const DType b2, const DType b3) { +__device__ __forceinline__ DType calculate_intersection(const DType a0, + const DType a1, + const DType a2, + const DType a3, + const DType b0, + const DType b1, + const DType b2, + const DType b3) { DType wx, wy; if (encode == box_common_enum::kCorner) { - const DType left = a0 > b0 ? a0 : b0; + const DType left = a0 > b0 ? a0 : b0; const DType bottom = a1 > b1 ? a1 : b1; - const DType right = a2 < b2 ? a2 : b2; - const DType top = a3 < b3 ? a3 : b3; - wx = right - left; - wy = top - bottom; + const DType right = a2 < b2 ? a2 : b2; + const DType top = a3 < b3 ? a3 : b3; + wx = right - left; + wy = top - bottom; } else { - const DType al = 2 * a0 - a2; - const DType ar = 2 * a0 + a2; - const DType bl = 2 * b0 - b2; - const DType br = 2 * b0 + b2; - const DType left = bl > al ? bl : al; - const DType right = br < ar ? 
br : ar; - wx = right - left; - const DType ab = 2 * a1 - a3; - const DType at = 2 * a1 + a3; - const DType bb = 2 * b1 - b3; - const DType bt = 2 * b1 + b3; + const DType al = 2 * a0 - a2; + const DType ar = 2 * a0 + a2; + const DType bl = 2 * b0 - b2; + const DType br = 2 * b0 + b2; + const DType left = bl > al ? bl : al; + const DType right = br < ar ? br : ar; + wx = right - left; + const DType ab = 2 * a1 - a3; + const DType at = 2 * a1 + a3; + const DType bb = 2 * b1 - b3; + const DType bt = 2 * b1 + b3; const DType bottom = bb > ab ? bb : ab; - const DType top = bt < at ? bt : at; - wy = top - bottom; - wy = wy / 4; // To compensate for both wx and wy being 2x too large + const DType top = bt < at ? bt : at; + wy = top - bottom; + wy = wy / 4; // To compensate for both wx and wy being 2x too large } if (wx <= 0 || wy <= 0) { return 0; @@ -322,35 +371,36 @@ __device__ __forceinline__ DType calculate_intersection(const DType a0, const DT } template -__launch_bounds__(512) -__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result, - const index_t current_start, - const index_t num_elems, - const index_t num_batches, - const index_t num_blocks_per_row_batch, - const index_t num_blocks_per_row, - const index_t topk, - const index_t element_width, - const index_t num_elements_per_batch, - const int coord_index, - const int class_index, - const int score_index, - const float threshold) { - constexpr int max_elem_width = 20; +__launch_bounds__(512) __global__ + void CalculateGreedyNMSResultsKernel(const DType* data, + uint32_t* result, + const index_t current_start, + const index_t num_elems, + const index_t num_batches, + const index_t num_blocks_per_row_batch, + const index_t num_blocks_per_row, + const index_t topk, + const index_t element_width, + const index_t num_elements_per_batch, + const int coord_index, + const int class_index, + const int score_index, + const float threshold) { + constexpr int max_elem_width = 20; constexpr int num_other_boxes = sizeof(uint32_t) * 8; __shared__ DType other_boxes[max_elem_width * num_other_boxes]; __shared__ DType other_boxes_areas[num_other_boxes]; - const index_t my_row = blockIdx.x / num_blocks_per_row; - const index_t my_block_offset_in_row = blockIdx.x % num_blocks_per_row; + const index_t my_row = blockIdx.x / num_blocks_per_row; + const index_t my_block_offset_in_row = blockIdx.x % num_blocks_per_row; const index_t my_block_offset_in_batch = my_block_offset_in_row % num_blocks_per_row_batch; - const index_t my_batch = (my_block_offset_in_row) / num_blocks_per_row_batch; - const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x + - current_start + threadIdx.x; + const index_t my_batch = (my_block_offset_in_row) / num_blocks_per_row_batch; + const index_t my_element_in_batch = + my_block_offset_in_batch * blockDim.x + current_start + threadIdx.x; // Load other boxes - const index_t offset = (my_batch * num_elements_per_batch + - current_start + my_row * num_other_boxes) * - element_width; + const index_t offset = + (my_batch * num_elements_per_batch + current_start + my_row * num_other_boxes) * + element_width; for (int i = threadIdx.x; i < element_width * num_other_boxes; i += blockDim.x) { other_boxes[i] = data[offset + i]; } @@ -358,22 +408,23 @@ __global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* res if (threadIdx.x < num_other_boxes) { const int other_boxes_offset = element_width * threadIdx.x; - const DType their_area = calculate_area( - other_boxes[other_boxes_offset 
+ coord_index + 0], - other_boxes[other_boxes_offset + coord_index + 1], - other_boxes[other_boxes_offset + coord_index + 2], - other_boxes[other_boxes_offset + coord_index + 3]); + const DType their_area = + calculate_area(other_boxes[other_boxes_offset + coord_index + 0], + other_boxes[other_boxes_offset + coord_index + 1], + other_boxes[other_boxes_offset + coord_index + 2], + other_boxes[other_boxes_offset + coord_index + 3]); other_boxes_areas[threadIdx.x] = their_area; } __syncthreads(); - if (my_element_in_batch >= topk) return; + if (my_element_in_batch >= topk) + return; DType my_box[4]; DType my_class = -1; DType my_score = -1; - const index_t my_offset = (my_batch * num_elements_per_batch + my_element_in_batch) * - element_width; + const index_t my_offset = + (my_batch * num_elements_per_batch + my_element_in_batch) * element_width; my_score = data[my_offset + score_index]; #pragma unroll for (int i = 0; i < 4; ++i) { @@ -393,12 +444,15 @@ __global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* res other_boxes[other_boxes_offset + score_index] != -1) { const DType their_area = other_boxes_areas[i]; - const DType intersect = calculate_intersection( - my_box[0], my_box[1], my_box[2], my_box[3], - other_boxes[other_boxes_offset + coord_index + 0], - other_boxes[other_boxes_offset + coord_index + 1], - other_boxes[other_boxes_offset + coord_index + 2], - other_boxes[other_boxes_offset + coord_index + 3]); + const DType intersect = + calculate_intersection(my_box[0], + my_box[1], + my_box[2], + my_box[3], + other_boxes[other_boxes_offset + coord_index + 0], + other_boxes[other_boxes_offset + coord_index + 1], + other_boxes[other_boxes_offset + coord_index + 2], + other_boxes[other_boxes_offset + coord_index + 3]); if (intersect > threshold * (my_area + their_area - intersect)) { ret = ret | (1u << i); } @@ -409,48 +463,45 @@ __global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* res } template -__launch_bounds__(NMS::THRESHOLD) -__global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results, - DType * data, - const index_t score_index, - const index_t element_width, - const index_t num_batches, - const index_t num_elements_per_batch, - const index_t start_index, - const index_t topk) { - constexpr int n_threads = NMS::THRESHOLD; - constexpr int warp_size = 32; - const index_t my_batch = blockIdx.x; +__launch_bounds__(NMS::THRESHOLD) __global__ + void ReduceNMSResultTriangleKernel(uint32_t* nms_results, + DType* data, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk) { + constexpr int n_threads = NMS::THRESHOLD; + constexpr int warp_size = 32; + const index_t my_batch = blockIdx.x; const index_t my_element_in_batch = threadIdx.x + start_index; - const index_t my_element = my_batch * topk + my_element_in_batch; - const int my_warp = threadIdx.x / warp_size; - const int my_lane = threadIdx.x % warp_size; + const index_t my_element = my_batch * topk + my_element_in_batch; + const int my_warp = threadIdx.x / warp_size; + const int my_lane = threadIdx.x % warp_size; __shared__ uint32_t current_valid_boxes[n_threads / warp_size]; - const uint32_t full_mask = 0xFFFFFFFF; - const uint32_t my_lane_mask = 1 << my_lane; + const uint32_t full_mask = 0xFFFFFFFF; + const uint32_t my_lane_mask = 1 << my_lane; const uint32_t earlier_threads_mask = (1 << (my_lane + 1)) - 1; - uint32_t valid = my_lane_mask; - uint32_t 
valid_boxes = full_mask; + uint32_t valid = my_lane_mask; + uint32_t valid_boxes = full_mask; - uint32_t my_next_mask = my_element_in_batch < topk ? - nms_results[my_element]: - full_mask; + uint32_t my_next_mask = my_element_in_batch < topk ? nms_results[my_element] : full_mask; #pragma unroll for (int i = 0; i < n_threads / warp_size; ++i) { uint32_t my_mask = my_next_mask; - my_next_mask = (((i + 1) < n_threads / warp_size) && - (my_element_in_batch < topk)) ? - nms_results[(i + 1) * topk * num_batches + my_element]: - full_mask; + my_next_mask = (((i + 1) < n_threads / warp_size) && (my_element_in_batch < topk)) + ? nms_results[(i + 1) * topk * num_batches + my_element] + : full_mask; if (my_warp == i && !__all_sync(full_mask, my_mask == full_mask)) { my_mask = my_mask | earlier_threads_mask; // Loop over warp_size - 1 because the last // thread does not contribute to the mask anyway #pragma unroll for (int j = 0; j < warp_size - 1; ++j) { - const uint32_t mask = __shfl_sync(full_mask, valid ? my_mask : full_mask, j); - valid = valid & mask; + const uint32_t mask = __shfl_sync(full_mask, valid ? my_mask : full_mask, j); + valid = valid & mask; } valid_boxes = __ballot_sync(full_mask, valid); } @@ -466,47 +517,48 @@ __global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results, nms_results[my_element] = valid_boxes; } if (valid == 0) { - data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + - score_index] = -1; + data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + score_index] = + -1; } } template -__launch_bounds__(512) -__global__ void ReduceNMSResultRestKernel(DType* data, - const uint32_t* nms_results, - const index_t score_index, - const index_t element_width, - const index_t num_batches, - const index_t num_elements_per_batch, - const index_t start_index, - const index_t topk, - const index_t num_blocks_per_batch) { - constexpr int num_other_boxes = sizeof(uint32_t) * 8; - constexpr int num_iterations = NMS::THRESHOLD / num_other_boxes; - constexpr int warp_size = 32; +__launch_bounds__(512) __global__ + void ReduceNMSResultRestKernel(DType* data, + const uint32_t* nms_results, + const index_t score_index, + const index_t element_width, + const index_t num_batches, + const index_t num_elements_per_batch, + const index_t start_index, + const index_t topk, + const index_t num_blocks_per_batch) { + constexpr int num_other_boxes = sizeof(uint32_t) * 8; + constexpr int num_iterations = NMS::THRESHOLD / num_other_boxes; + constexpr int warp_size = 32; const index_t my_block_offset_in_batch = blockIdx.x % num_blocks_per_batch; - const index_t my_batch = blockIdx.x / num_blocks_per_batch; - const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x + - start_index + NMS::THRESHOLD + threadIdx.x; + const index_t my_batch = blockIdx.x / num_blocks_per_batch; + const index_t my_element_in_batch = + my_block_offset_in_batch * blockDim.x + start_index + NMS::THRESHOLD + threadIdx.x; const index_t my_element = my_batch * topk + my_element_in_batch; - if (my_element_in_batch >= topk) return; + if (my_element_in_batch >= topk) + return; bool valid = true; #pragma unroll for (int i = 0; i < num_iterations; ++i) { - const uint32_t my_mask = nms_results[i * topk * num_batches + my_element]; + const uint32_t my_mask = nms_results[i * topk * num_batches + my_element]; const uint32_t valid_boxes = nms_results[my_batch * topk + i * warp_size + start_index]; const bool no_hit = (valid_boxes & (~my_mask)) == 0; - valid = valid && 
no_hit; + valid = valid && no_hit; } if (!valid) { - data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + - score_index] = -1; + data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width + score_index] = + -1; } } @@ -517,47 +569,45 @@ TempWorkspace GetWorkspace(const index_t num_batch, const index_t topk, const OpContext& ctx) { TempWorkspace workspace; - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); const int alignment = 128; // Get the workspace size - workspace.scores_temp_space = 2 * align(num_batch * num_elem * sizeof(DType), alignment); + workspace.scores_temp_space = 2 * align(num_batch * num_elem * sizeof(DType), alignment); workspace.indices_temp_spaces = 2 * align(num_batch * num_elem * sizeof(index_t), alignment); WorkspaceForSort(num_elem, topk, alignment, &workspace); // Place for a buffer workspace.buffer_space = align(num_batch * num_elem * width_elem * sizeof(DType), alignment); - workspace.nms_scratch_space = align(NMS::THRESHOLD / (sizeof(uint32_t) * 8) * - num_batch * topk * sizeof(uint32_t), alignment); + workspace.nms_scratch_space = + align(NMS::THRESHOLD / (sizeof(uint32_t) * 8) * num_batch * topk * sizeof(uint32_t), + alignment); - const size_t workspace_size = workspace.scores_temp_space + - workspace.scratch_space + - workspace.buffer_space + - workspace.nms_scratch_space + + const size_t workspace_size = workspace.scores_temp_space + workspace.scratch_space + + workspace.buffer_space + workspace.nms_scratch_space + workspace.indices_temp_spaces; // Obtain the memory for workspace - Tensor scratch_memory = ctx.requested[box_nms_enum::kTempSpace] - .get_space_typed(mshadow::Shape1(ceil_div(workspace_size, sizeof(double))), s); + Tensor scratch_memory = + ctx.requested[box_nms_enum::kTempSpace].get_space_typed( + mshadow::Shape1(ceil_div(workspace_size, sizeof(double))), s); // Populate workspace pointers - workspace.scores = reinterpret_cast(scratch_memory.dptr_); - workspace.scratch = reinterpret_cast(workspace.scores) + - workspace.scores_temp_space; - workspace.buffer = reinterpret_cast(workspace.scratch + - workspace.scratch_space); - workspace.nms_scratch = reinterpret_cast( - reinterpret_cast(workspace.buffer) + - workspace.buffer_space); - workspace.indices = reinterpret_cast( - reinterpret_cast(workspace.nms_scratch) + - workspace.nms_scratch_space); + workspace.scores = reinterpret_cast(scratch_memory.dptr_); + workspace.scratch = reinterpret_cast(workspace.scores) + workspace.scores_temp_space; + workspace.buffer = reinterpret_cast(workspace.scratch + workspace.scratch_space); + workspace.nms_scratch = reinterpret_cast(reinterpret_cast(workspace.buffer) + + workspace.buffer_space); + workspace.indices = reinterpret_cast(reinterpret_cast(workspace.nms_scratch) + + workspace.nms_scratch_space); return workspace; } template -__global__ void ExtractScoresKernel(const DType* data, DType* scores, - const index_t N, const int element_width, - const int score_index) { +__global__ void ExtractScoresKernel(const DType* data, + DType* scores, + const index_t N, + const int element_width, + const int score_index) { const index_t tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid < N) { scores[tid] = data[tid * element_width + score_index]; @@ -576,31 +626,31 @@ void CompactNMSResults(const Tensor& data, const index_t topk, Stream* s) { using mshadow::Shape1; - constexpr int n_threads = 512; - const index_t num_elements = scores->shape_.Size(); + constexpr int n_threads = 512; + const index_t 
num_elements = scores->shape_.Size(); const index_t num_elements_per_batch = data.shape_[1]; - const index_t num_batches = data.shape_[0]; - const int element_width = data.shape_[2]; - const index_t n_blocks = ceil_div(num_elements, n_threads); + const index_t num_batches = data.shape_[0]; + const int element_width = data.shape_[2]; + const index_t n_blocks = ceil_div(num_elements, n_threads); ExtractScoresKernel<<::GetStream(s)>>>( data.dptr_, scores->dptr_, num_elements, element_width, score_index); *indices = mshadow::expr::range(0, num_elements); for (index_t i = 0; i < num_batches; ++i) { // Sort each batch separately - Tensor scores_batch(scores->dptr_ + i * num_elements_per_batch, - Shape1(topk), - s); - Tensor indices_batch(indices->dptr_ + i * num_elements_per_batch, - Shape1(topk), - s); - Tensor sorted_scores_batch(sorted_scores->dptr_ + i * num_elements_per_batch, - Shape1(topk), - s); - Tensor sorted_indices_batch(sorted_indices->dptr_ + i * num_elements_per_batch, - Shape1(topk), - s); - mxnet::op::SortByKey(scores_batch, indices_batch, false, scratch, - 0, 8 * sizeof(DType), &sorted_scores_batch, + Tensor scores_batch(scores->dptr_ + i * num_elements_per_batch, Shape1(topk), s); + Tensor indices_batch( + indices->dptr_ + i * num_elements_per_batch, Shape1(topk), s); + Tensor sorted_scores_batch( + sorted_scores->dptr_ + i * num_elements_per_batch, Shape1(topk), s); + Tensor sorted_indices_batch( + sorted_indices->dptr_ + i * num_elements_per_batch, Shape1(topk), s); + mxnet::op::SortByKey(scores_batch, + indices_batch, + false, + scratch, + 0, + 8 * sizeof(DType), + &sorted_scores_batch, &sorted_indices_batch); } CompactData(*sorted_indices, data, out, topk, score_index, s); @@ -621,80 +671,84 @@ void BoxNMSForwardGPU_notemp(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]"; const BoxNMSParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_; - int indim = in_shape.ndim(); - int num_batch = indim <= 2? 1 : in_shape.ProdShape(0, indim - 2); - int num_elem = in_shape[indim - 2]; - int width_elem = in_shape[indim - 1]; + Stream* s = ctx.get_stream(); + mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_; + int indim = in_shape.ndim(); + int num_batch = indim <= 2 ? 1 : in_shape.ProdShape(0, indim - 2); + int num_elem = in_shape[indim - 2]; + int width_elem = in_shape[indim - 1]; MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { - Tensor data = inputs[box_nms_enum::kData] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); - Tensor out = outputs[box_nms_enum::kOut] - .get_with_shape(Shape3(num_batch, num_elem, width_elem), s); + Tensor data = inputs[box_nms_enum::kData].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); + Tensor out = outputs[box_nms_enum::kOut].get_with_shape( + Shape3(num_batch, num_elem, width_elem), s); // Special case for topk == 0 if (param.topk == 0) { - if (req[0] != kNullOp && - req[0] != kWriteInplace) { + if (req[0] != kNullOp && req[0] != kWriteInplace) { out = mshadow::expr::F(data); } return; } - index_t topk = param.topk > 0 ? std::min(param.topk, num_elem) : num_elem; - const auto& workspace = GetWorkspace(num_batch, num_elem, - width_elem, topk, ctx); + index_t topk = param.topk > 0 ? 
std::min(param.topk, num_elem) : num_elem; + const auto& workspace = GetWorkspace(num_batch, num_elem, width_elem, topk, ctx); FilterAndPrepareAuxData(data, &out, workspace, param, s); Tensor scores(workspace.scores, Shape1(num_batch * num_elem), s); - Tensor sorted_scores(workspace.scores + scores.MSize(), - Shape1(num_batch * num_elem), s); + Tensor sorted_scores( + workspace.scores + scores.MSize(), Shape1(num_batch * num_elem), s); Tensor indices(workspace.indices, Shape1(num_batch * num_elem), s); - Tensor sorted_indices(workspace.indices + indices.MSize(), - Shape1(num_batch * num_elem), s); - Tensor scratch(reinterpret_cast(workspace.scratch), - Shape1(workspace.scratch_space), s); - Tensor buffer(workspace.buffer, - Shape3(num_batch, num_elem, width_elem), s); - Tensor nms_scratch(workspace.nms_scratch, - Shape2(NMS::THRESHOLD / (sizeof(uint32_t) * 8), - topk * num_batch), - s); + Tensor sorted_indices( + workspace.indices + indices.MSize(), Shape1(num_batch * num_elem), s); + Tensor scratch( + reinterpret_cast(workspace.scratch), Shape1(workspace.scratch_space), s); + Tensor buffer(workspace.buffer, Shape3(num_batch, num_elem, width_elem), s); + Tensor nms_scratch( + workspace.nms_scratch, + Shape2(NMS::THRESHOLD / (sizeof(uint32_t) * 8), topk * num_batch), + s); indices = mshadow::expr::range(0, num_batch * num_elem); for (index_t i = 0; i < num_batch; ++i) { // Sort each batch separately - Tensor scores_batch(scores.dptr_ + i * num_elem, - Shape1(num_elem), - s); - Tensor indices_batch(indices.dptr_ + i * num_elem, - Shape1(num_elem), - s); - Tensor sorted_scores_batch(sorted_scores.dptr_ + i * num_elem, - Shape1(num_elem), - s); - Tensor sorted_indices_batch(sorted_indices.dptr_ + i * num_elem, - Shape1(num_elem), - s); - mxnet::op::SortByKey(scores_batch, indices_batch, false, &scratch, 0, - 8 * sizeof(DType), &sorted_scores_batch, + Tensor scores_batch(scores.dptr_ + i * num_elem, Shape1(num_elem), s); + Tensor indices_batch(indices.dptr_ + i * num_elem, Shape1(num_elem), s); + Tensor sorted_scores_batch( + sorted_scores.dptr_ + i * num_elem, Shape1(num_elem), s); + Tensor sorted_indices_batch( + sorted_indices.dptr_ + i * num_elem, Shape1(num_elem), s); + mxnet::op::SortByKey(scores_batch, + indices_batch, + false, + &scratch, + 0, + 8 * sizeof(DType), + &sorted_scores_batch, &sorted_indices_batch); } CompactData(sorted_indices, out, &buffer, topk, -1, s); NMS nms; nms(&buffer, &nms_scratch, topk, param, s); - CompactNMSResults(buffer, &out, &indices, &scores, &sorted_indices, - &sorted_scores, &scratch, param.score_index, topk, s); + CompactNMSResults(buffer, + &out, + &indices, + &scores, + &sorted_indices, + &sorted_scores, + &scratch, + param.score_index, + topk, + s); // convert encoding if (param.in_format != param.out_format) { if (box_common_enum::kCenter == param.out_format) { - mxnet::op::mxnet_op::Kernel::Launch(s, num_batch * num_elem, - out.dptr_ + param.coord_start, width_elem); + mxnet::op::mxnet_op::Kernel::Launch( + s, num_batch * num_elem, out.dptr_ + param.coord_start, width_elem); } else { - mxnet::op::mxnet_op::Kernel::Launch(s, num_batch * num_elem, - out.dptr_ + param.coord_start, width_elem); + mxnet::op::mxnet_op::Kernel::Launch( + s, num_batch * num_elem, out.dptr_ + param.coord_start, width_elem); } } }); @@ -717,30 +771,25 @@ void BoxNMSForwardGPU(const nnvm::NodeAttrs& attrs, BoxNMSForward(attrs, ctx, inputs, req, outputs); } - -NNVM_REGISTER_OP(_contrib_box_nms) -.set_attr("FCompute", BoxNMSForwardGPU); 
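// ---------------------------------------------------------------------------
// Reviewer note (not part of the patch): the reformatting above does not change
// the suppression rule used by CalculateGreedyNMSResultsKernel, which marks a
// box when `intersect > threshold * (my_area + their_area - intersect)`, i.e.
// when IoU > threshold. A minimal host-side sketch of that test, assuming boxes
// in corner encoding [xmin, ymin, xmax, ymax]; the function name and layout are
// illustrative only:
#include <algorithm>

template <typename DType>
bool exceeds_iou_threshold(const DType* a, const DType* b, float threshold) {
  const DType wx = std::min(a[2], b[2]) - std::max(a[0], b[0]);  // overlap width
  const DType wy = std::min(a[3], b[3]) - std::max(a[1], b[1]);  // overlap height
  if (wx <= 0 || wy <= 0)
    return false;  // disjoint boxes never suppress each other
  const DType inter  = wx * wy;
  const DType area_a = (a[2] - a[0]) * (a[3] - a[1]);
  const DType area_b = (b[2] - b[0]) * (b[3] - b[1]);
  // Same criterion as the kernel: intersection > threshold * union.
  return inter > threshold * (area_a + area_b - inter);
}
// ---------------------------------------------------------------------------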
+NNVM_REGISTER_OP(_contrib_box_nms).set_attr("FCompute", BoxNMSForwardGPU); NNVM_REGISTER_OP(_backward_contrib_box_nms) -.set_attr("FCompute", BoxNMSBackward); + .set_attr("FCompute", BoxNMSBackward); -NNVM_REGISTER_OP(_contrib_box_iou) -.set_attr("FCompute", BoxOverlapForward); +NNVM_REGISTER_OP(_contrib_box_iou).set_attr("FCompute", BoxOverlapForward); NNVM_REGISTER_OP(_backward_contrib_box_iou) -.set_attr("FCompute", BoxOverlapBackward); + .set_attr("FCompute", BoxOverlapBackward); NNVM_REGISTER_OP(_contrib_bipartite_matching) -.set_attr("FCompute", BipartiteMatchingForward); + .set_attr("FCompute", BipartiteMatchingForward); NNVM_REGISTER_OP(_backward_contrib_bipartite_matching) -.set_attr("FCompute", BipartiteMatchingBackward); + .set_attr("FCompute", BipartiteMatchingBackward); -NNVM_REGISTER_OP(_contrib_box_encode) -.set_attr("FCompute", BoxEncodeForward); +NNVM_REGISTER_OP(_contrib_box_encode).set_attr("FCompute", BoxEncodeForward); -NNVM_REGISTER_OP(_contrib_box_decode) -.set_attr("FCompute", BoxDecodeForward); +NNVM_REGISTER_OP(_contrib_box_decode).set_attr("FCompute", BoxDecodeForward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/count_sketch-inl.h b/src/operator/contrib/count_sketch-inl.h index f67856a398a4..4053d83a5471 100644 --- a/src/operator/contrib/count_sketch-inl.h +++ b/src/operator/contrib/count_sketch-inl.h @@ -22,7 +22,7 @@ * \file count_sketch-inl.h * \brief count_sketch operator and symbol * \author Chen Zhu -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_COUNT_SKETCH_INL_H_ #define MXNET_OPERATOR_CONTRIB_COUNT_SKETCH_INL_H_ #include @@ -39,98 +39,110 @@ namespace op { // Declare enumeration of input order to make code more intuitive. // These enums are only visible within this header namespace CountSketch { -enum CountSketchOpInputs{kData, kH, kS}; -enum CountSketchOpOutputs{kOut}; +enum CountSketchOpInputs { kData, kH, kS }; +enum CountSketchOpOutputs { kOut }; } // namespace CountSketch // seems that we can infer all the parameters from data shapes at the moment struct CountSketchParam : public dmlc::Parameter { - int out_dim; - int processing_batch_size; - DMLC_DECLARE_PARAMETER(CountSketchParam) { - DMLC_DECLARE_FIELD(out_dim) - .describe("The output dimension."); - DMLC_DECLARE_FIELD(processing_batch_size).set_default(32) + int out_dim; + int processing_batch_size; + DMLC_DECLARE_PARAMETER(CountSketchParam) { + DMLC_DECLARE_FIELD(out_dim).describe("The output dimension."); + DMLC_DECLARE_FIELD(processing_batch_size) + .set_default(32) .describe("How many sketch vectors to process at one time."); - } + } }; -template +template class CountSketchOp : public Operator { public: - explicit CountSketchOp(CountSketchParam param) { - this->param_ = param; - } + explicit CountSketchOp(CountSketchParam param) { + this->param_ = param; + } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - CHECK_EQ(in_data.size(), 3); - CHECK_EQ(out_data.size(), 1); - Stream *s = ctx.get_stream(); - - // use FlatTo2D to preseve the possible 4D shape - // h and s should be 1d vectors - Tensor data = in_data[CountSketch::kData].FlatTo2D(s); - - const mxnet::TShape& hshape = in_data[CountSketch::kH].shape_; - const mxnet::TShape& sshape = in_data[CountSketch::kS].shape_; - Tensor h = in_data[CountSketch::kH].get_with_shape( - Shape1(hshape.ProdShape(0, hshape.ndim())), s); - Tensor ss = in_data[CountSketch::kS].get_with_shape( - 
Shape1(sshape.ProdShape(0, sshape.ndim())), s); - Tensor out = out_data[CountSketch::kOut].FlatTo2D(s); - n_samples = data.shape_[0]; - in_dim = data.shape_[1]; + virtual void Forward(const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data, + const std::vector& aux_args) { + using namespace mshadow; + CHECK_EQ(in_data.size(), 3); + CHECK_EQ(out_data.size(), 1); + Stream* s = ctx.get_stream(); + + // use FlatTo2D to preseve the possible 4D shape + // h and s should be 1d vectors + Tensor data = in_data[CountSketch::kData].FlatTo2D(s); + + const mxnet::TShape& hshape = in_data[CountSketch::kH].shape_; + const mxnet::TShape& sshape = in_data[CountSketch::kS].shape_; + Tensor h = in_data[CountSketch::kH].get_with_shape( + Shape1(hshape.ProdShape(0, hshape.ndim())), s); + Tensor ss = in_data[CountSketch::kS].get_with_shape( + Shape1(sshape.ProdShape(0, sshape.ndim())), s); + Tensor out = out_data[CountSketch::kOut].FlatTo2D(s); + n_samples = data.shape_[0]; + in_dim = data.shape_[1]; // firstly set out to zero as we will use sum out = 0; - CountSketchForward(out, data, h, ss, n_samples, - this->param_.processing_batch_size, in_dim, this->param_.out_dim); - } + CountSketchForward(out, + data, + h, + ss, + n_samples, + this->param_.processing_batch_size, + in_dim, + this->param_.out_dim); + } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + virtual void Backward(const OpContext& ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { using namespace mshadow; - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); Tensor ograd = out_grad[CountSketch::kOut].FlatTo2D(s); Tensor dgrad = in_grad[CountSketch::kData].FlatTo2D(s); const mxnet::TShape& hshape = in_data[CountSketch::kH].shape_; const mxnet::TShape& sshape = in_data[CountSketch::kS].shape_; - Tensor h = in_data[CountSketch::kH].get_with_shape( - Shape1(hshape.ProdShape(0, hshape.ndim())), s); + Tensor h = in_data[CountSketch::kH].get_with_shape( + Shape1(hshape.ProdShape(0, hshape.ndim())), s); Tensor ss = in_data[CountSketch::kS].get_with_shape( - Shape1(sshape.ProdShape(0, sshape.ndim())), s); + Shape1(sshape.ProdShape(0, sshape.ndim())), s); - CountSketchBackward(dgrad, ograd, h, ss, n_samples, - this->param_.processing_batch_size, in_dim, this->param_.out_dim); - } + CountSketchBackward(dgrad, + ograd, + h, + ss, + n_samples, + this->param_.processing_batch_size, + in_dim, + this->param_.out_dim); + } private: - CountSketchParam param_; - int n_samples; - int in_dim; + CountSketchParam param_; + int n_samples; + int in_dim; }; // class CountSketchOp // Declare Factory Function -template +template Operator* CreateOp(CountSketchParam param, int dtype); #if DMLC_USE_CXX11 class CountSketchProp : public OperatorProperty { public: - std::vector ListArguments() const override { - return {"data", "h", "s"}; - } - std::vector ListOutputs() const override { + std::vector ListArguments() const override { + return {"data", "h", "s"}; + } + std::vector ListOutputs() const override { return {"output"}; } int NumOutputs() const override { @@ -144,44 +156,45 @@ class CountSketchProp : public OperatorProperty { return param_.__DICT__(); } - bool InferShape(mxnet::ShapeVector *in_shape, - 
mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + bool InferShape(mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape, + mxnet::ShapeVector* aux_shape) const override { using namespace mshadow; - CHECK_EQ(in_shape->size(), 3) <<"Input:[data, h, s]"; - const mxnet::TShape &dshape = (*in_shape)[CountSketch::kData]; + CHECK_EQ(in_shape->size(), 3) << "Input:[data, h, s]"; + const mxnet::TShape& dshape = (*in_shape)[CountSketch::kData]; // require data to be known - if (mxnet::op::shape_is_none(dshape)) return false; + if (mxnet::op::shape_is_none(dshape)) + return false; out_shape->clear(); if (dshape.ndim() == 4) { // check the shapes of h and s - CHECK_EQ((*in_shape)[CountSketch::kH][1], dshape[3]) - << "H should be 2D tensor with same length as input shape[3], " - << (*in_shape)[CountSketch::kH][1] << " v.s. " << dshape[3]; - CHECK_EQ((*in_shape)[CountSketch::kS][1], dshape[3]) - << "S should be 2D tensor with same length as input shape[3], " - << (*in_shape)[CountSketch::kS][1] << " v.s. " << dshape[3]; - - out_shape->push_back(Shape4(dshape[0], dshape[1], dshape[2], param_.out_dim)); + CHECK_EQ((*in_shape)[CountSketch::kH][1], dshape[3]) + << "H should be 2D tensor with same length as input shape[3], " + << (*in_shape)[CountSketch::kH][1] << " v.s. " << dshape[3]; + CHECK_EQ((*in_shape)[CountSketch::kS][1], dshape[3]) + << "S should be 2D tensor with same length as input shape[3], " + << (*in_shape)[CountSketch::kS][1] << " v.s. " << dshape[3]; + + out_shape->push_back(Shape4(dshape[0], dshape[1], dshape[2], param_.out_dim)); } else if (dshape.ndim() == 2) { - CHECK_EQ((*in_shape)[CountSketch::kH][1], dshape[1]) - << "H should be 2D tensor with same length as input shape[1], " - << (*in_shape)[CountSketch::kH][1] << " v.s. " << dshape[1]; - CHECK_EQ((*in_shape)[CountSketch::kS][1], dshape[1]) - << "S should be 2D tensor with same length as input shape[1], " - << (*in_shape)[CountSketch::kS][1] << " v.s. " << dshape[1]; - out_shape->push_back(Shape2(dshape[0], param_.out_dim)); + CHECK_EQ((*in_shape)[CountSketch::kH][1], dshape[1]) + << "H should be 2D tensor with same length as input shape[1], " + << (*in_shape)[CountSketch::kH][1] << " v.s. " << dshape[1]; + CHECK_EQ((*in_shape)[CountSketch::kS][1], dshape[1]) + << "S should be 2D tensor with same length as input shape[1], " + << (*in_shape)[CountSketch::kS][1] << " v.s. 
" << dshape[1]; + out_shape->push_back(Shape2(dshape[0], param_.out_dim)); } else { - CHECK_EQ(dshape.ndim(), 2) <<"Data should be 2D or 4D!"; - return false; + CHECK_EQ(dshape.ndim(), 2) << "Data should be 2D or 4D!"; + return false; } return true; } - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { + bool InferType(std::vector* in_type, + std::vector* out_type, + std::vector* aux_type) const override { CHECK_GE(in_type->size(), 1); int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; @@ -199,7 +212,7 @@ class CountSketchProp : public OperatorProperty { OperatorProperty* Copy() const override { CountSketchProp* cs_sym = new CountSketchProp(); - cs_sym->param_ = this->param_; + cs_sym->param_ = this->param_; return cs_sym; } @@ -208,19 +221,20 @@ class CountSketchProp : public OperatorProperty { } // declare dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {out_grad[CountSketch::kOut], in_data[CountSketch::kData], - in_data[CountSketch::kH], in_data[CountSketch::kS]}; + std::vector DeclareBackwardDependency(const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data) const override { + return {out_grad[CountSketch::kOut], + in_data[CountSketch::kData], + in_data[CountSketch::kH], + in_data[CountSketch::kS]}; } std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const override { + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& in_grad) const override { return {{in_data[CountSketch::kData], in_grad[CountSketch::kData]}}; } @@ -229,11 +243,12 @@ class CountSketchProp : public OperatorProperty { return nullptr; } - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + Operator* CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const override; private: - CountSketchParam param_; + CountSketchParam param_; }; #endif } // namespace op diff --git a/src/operator/contrib/count_sketch.cc b/src/operator/contrib/count_sketch.cc index 4b6504e564ee..e75e97a2b07d 100644 --- a/src/operator/contrib/count_sketch.cc +++ b/src/operator/contrib/count_sketch.cc @@ -22,28 +22,29 @@ * \file count_sketch.cc * \brief count_sketch op * \author Chen Zhu -*/ + */ #include "./count_sketch-inl.h" namespace mxnet { namespace op { -template<> -Operator *CreateOp(CountSketchParam param, int dtype) { - LOG(FATAL) << "CountSketch is only available for GPU."; - return nullptr; +template <> +Operator* CreateOp(CountSketchParam param, int dtype) { + LOG(FATAL) << "CountSketch is only available for GPU."; + return nullptr; } -Operator *CountSketchProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const { - mxnet::ShapeVector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferType(in_type, &out_type, &aux_type)); - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); +Operator* CountSketchProp::CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const { + mxnet::ShapeVector out_shape, aux_shape; + std::vector out_type, aux_type; + CHECK(InferType(in_type, 
&out_type, &aux_type)); + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); } DMLC_REGISTER_PARAMETER(CountSketchParam); MXNET_REGISTER_OP_PROPERTY(_contrib_count_sketch, CountSketchProp) -.describe(R"code(Apply CountSketch to input: map a d-dimension data to k-dimension data" + .describe(R"code(Apply CountSketch to input: map a d-dimension data to k-dimension data" .. note:: `count_sketch` is only available on GPU. @@ -65,10 +66,10 @@ Example:: [3.2, 0, 0, -5.7, 6.6]] )code" ADD_FILELINE) -.add_argument("data", "NDArray-or-Symbol", "Input data to the CountSketchOp.") -.add_argument("h", "NDArray-or-Symbol", "The index vector") -.add_argument("s", "NDArray-or-Symbol", "The sign vector") -.add_arguments(CountSketchParam::__FIELDS__()); + .add_argument("data", "NDArray-or-Symbol", "Input data to the CountSketchOp.") + .add_argument("h", "NDArray-or-Symbol", "The index vector") + .add_argument("s", "NDArray-or-Symbol", "The sign vector") + .add_arguments(CountSketchParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/count_sketch.cu b/src/operator/contrib/count_sketch.cu index c6370c09d1c1..71cd8ddb5c8d 100644 --- a/src/operator/contrib/count_sketch.cu +++ b/src/operator/contrib/count_sketch.cu @@ -22,15 +22,13 @@ * \file count_sketch.cu * \brief count_sketch op * \author Chen Zhu, Yang Shi -*/ + */ #include "./count_sketch-inl.h" #include #include #include - - -#define WARPS_PER_BLOCK 1 +#define WARPS_PER_BLOCK 1 #define THREADS_PER_BLOCK 512 namespace mshadow { @@ -48,21 +46,26 @@ __device__ void atomic_add(double* address, double val) { // #atomic-functions // NOLINT_NEXT_LINE(runtime/int) - unsigned long long int* address_as_ull = (unsigned long long int*) address; // NOLINT(*) - unsigned long long int old = *address_as_ull, assumed; // NOLINT(*) + unsigned long long int* address_as_ull = (unsigned long long int*)address; // NOLINT(*) + unsigned long long int old = *address_as_ull, assumed; // NOLINT(*) do { assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); + old = atomicCAS( + address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); // Note: uses integer comparison to avoid hang in case of NaN // (since NaN != NaN) } while (assumed != old); } template -__global__ void sketch_forward_kernel(const int nthreads, DType *out, const DType *h, - const DType *s, const DType *in, const int n_smaples, - const int in_dim, const int out_dim) { +__global__ void sketch_forward_kernel(const int nthreads, + DType* out, + const DType* h, + const DType* s, + const DType* in, + const int n_smaples, + const int in_dim, + const int out_dim) { // input: n_smaples * in_dim // output: n_smaples * out_dim const int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -71,120 +74,134 @@ __global__ void sketch_forward_kernel(const int nthreads, DType *out, const DTyp } // nthreads is the maximum of thread indices, should be equal to in_dim // index is point index - const int i_indim = index % in_dim; + const int i_indim = index % in_dim; const int i_sample = index / in_dim; // get the target location in the output - const int target = i_sample*out_dim + h[i_indim]; + const int target = i_sample * out_dim + h[i_indim]; atomic_add(out + target, s[i_indim] * in[index]); } template -__global__ void sketch_backward_kernel(const int nthreads, DType *in_grad, const DType *h, - const DType *s, const DType *out_grad, const int 
n_smaples, - const int in_dim, const int out_dim) { +__global__ void sketch_backward_kernel(const int nthreads, + DType* in_grad, + const DType* h, + const DType* s, + const DType* out_grad, + const int n_smaples, + const int in_dim, + const int out_dim) { // only calculate gradient regarding x // can also calculate gradient regarding s if needed - const int index = blockIdx.x * blockDim.x + threadIdx.x; - const int i_indim = index % in_dim; + const int index = blockIdx.x * blockDim.x + threadIdx.x; + const int i_indim = index % in_dim; const int i_sample = index / in_dim; - const int i_outdim = i_sample*out_dim + h[i_indim]; - in_grad[index] = out_grad[i_outdim] * s[i_indim]; + const int i_outdim = i_sample * out_dim + h[i_indim]; + in_grad[index] = out_grad[i_outdim] * s[i_indim]; } } // namespace cuda // CountSketch Forward template -inline void CountSketchForward(const Tensor &out, - const Tensor &in, - const Tensor &h, - const Tensor &s, +inline void CountSketchForward(const Tensor& out, + const Tensor& in, + const Tensor& h, + const Tensor& s, const int n_samples, const int processing_batch_size, const int in_dim, const int out_dim) { - DType *out_ptr = out.dptr_; - const DType *in_ptr = in.dptr_; - const DType *h_ptr = h.dptr_; - const DType *s_ptr = s.dptr_; - int upper_bound = n_samples/processing_batch_size; - if (n_samples%processing_batch_size == 0) { - upper_bound = upper_bound-1; + DType* out_ptr = out.dptr_; + const DType* in_ptr = in.dptr_; + const DType* h_ptr = h.dptr_; + const DType* s_ptr = s.dptr_; + int upper_bound = n_samples / processing_batch_size; + if (n_samples % processing_batch_size == 0) { + upper_bound = upper_bound - 1; } // guarantee there are at least one iteration - upper_bound = upper_bound > 0? upper_bound:0; - int bstart = 0; - for ( int i = 0; i <= upper_bound; i++ ) { + upper_bound = upper_bound > 0 ? upper_bound : 0; + int bstart = 0; + for (int i = 0; i <= upper_bound; i++) { const int batchlen = min(processing_batch_size, n_samples - bstart); const int nthreads = batchlen * in_dim; // to make number of threads the same as input const int threads_per_block = min(THREADS_PER_BLOCK, nthreads); - int nblocks = (nthreads + threads_per_block - 1) / threads_per_block; - cuda::sketch_forward_kernel<<>>( - nthreads, out_ptr+bstart*out_dim, h_ptr, - s_ptr, in_ptr+bstart*in_dim, batchlen, - in_dim, out_dim); + int nblocks = (nthreads + threads_per_block - 1) / threads_per_block; + cuda::sketch_forward_kernel<<>>(nthreads, + out_ptr + bstart * out_dim, + h_ptr, + s_ptr, + in_ptr + bstart * in_dim, + batchlen, + in_dim, + out_dim); cudaError_t err = cudaDeviceSynchronize(); CHECK_EQ(err, cudaSuccess) << "Error occured! 
CUDA: " << cudaGetErrorString(err); - bstart = (i+1)*batchlen; + bstart = (i + 1) * batchlen; } } -template -inline void CountSketchBackward(const Tensor &in_grad, - const Tensor &out_grad, - const Tensor &h, - const Tensor &s, +template +inline void CountSketchBackward(const Tensor& in_grad, + const Tensor& out_grad, + const Tensor& h, + const Tensor& s, const int n_samples, const int processing_batch_size, const int in_dim, const int out_dim) { - DType *in_grad_ptr = in_grad.dptr_; - const DType *out_grad_ptr = out_grad.dptr_; - const DType *h_ptr = h.dptr_; - const DType *s_ptr = s.dptr_; - int upper_bound = n_samples/processing_batch_size; - if (n_samples%processing_batch_size == 0) { - upper_bound = upper_bound-1; + DType* in_grad_ptr = in_grad.dptr_; + const DType* out_grad_ptr = out_grad.dptr_; + const DType* h_ptr = h.dptr_; + const DType* s_ptr = s.dptr_; + int upper_bound = n_samples / processing_batch_size; + if (n_samples % processing_batch_size == 0) { + upper_bound = upper_bound - 1; } // guarantee there are at least one iteration upper_bound = upper_bound > 0 ? upper_bound : 0; - int bstart = 0; - for ( int i = 0; i <= upper_bound; i++ ) { + int bstart = 0; + for (int i = 0; i <= upper_bound; i++) { const int batchlen = min(processing_batch_size, n_samples - bstart); const int nthreads = batchlen * in_dim; // to make number of threads the same as input const int threads_per_block = min(THREADS_PER_BLOCK, nthreads); - int nblocks = (nthreads + threads_per_block - 1) / threads_per_block; - cuda::sketch_backward_kernel<<>>( - nthreads, in_grad_ptr+bstart*in_dim, h_ptr, - s_ptr, out_grad_ptr+bstart*out_dim, batchlen, - in_dim, out_dim); + int nblocks = (nthreads + threads_per_block - 1) / threads_per_block; + cuda::sketch_backward_kernel + <<>>(nthreads, + in_grad_ptr + bstart * in_dim, + h_ptr, + s_ptr, + out_grad_ptr + bstart * out_dim, + batchlen, + in_dim, + out_dim); cudaError_t err = cudaDeviceSynchronize(); CHECK_EQ(err, cudaSuccess) << "Error occured! CUDA: " << cudaGetErrorString(err); - bstart = (i+1)*batchlen; + bstart = (i + 1) * batchlen; } } } // namespace mshadow namespace mxnet { namespace op { -template<> +template <> Operator* CreateOp(CountSketchParam param, int dtype) { - Operator *op = nullptr; + Operator* op = nullptr; switch (dtype) { - case mshadow::kFloat32: - op = new CountSketchOp(param); - break; - case mshadow::kFloat64: - op = new CountSketchOp(param); - break; - case mshadow::kFloat16: - LOG(FATAL) << "float16 count sketch layer is currently" - "not supported."; - break; - default: - LOG(FATAL) << "Unsupported type " << dtype; + case mshadow::kFloat32: + op = new CountSketchOp(param); + break; + case mshadow::kFloat64: + op = new CountSketchOp(param); + break; + case mshadow::kFloat16: + LOG(FATAL) << "float16 count sketch layer is currently" + "not supported."; + break; + default: + LOG(FATAL) << "Unsupported type " << dtype; } return op; } diff --git a/src/operator/contrib/deformable_psroi_pooling-inl.h b/src/operator/contrib/deformable_psroi_pooling-inl.h index f848f531b0cd..908c36ce65cc 100644 --- a/src/operator/contrib/deformable_psroi_pooling-inl.h +++ b/src/operator/contrib/deformable_psroi_pooling-inl.h @@ -18,12 +18,12 @@ */ /*! 
-* Copyright (c) 2017 Microsoft -* Licensed under The Apache-2.0 License [see LICENSE for details] -* \file deformable_psroi_pooling-inl.h -* \brief deformable psroi pooling operator and symbol -* \author Yi Li, Guodong Zhang, Jifeng Dai -*/ + * Copyright (c) 2017 Microsoft + * Licensed under The Apache-2.0 License [see LICENSE for details] + * \file deformable_psroi_pooling-inl.h + * \brief deformable psroi pooling operator and symbol + * \author Yi Li, Guodong Zhang, Jifeng Dai + */ #ifndef MXNET_OPERATOR_CONTRIB_DEFORMABLE_PSROI_POOLING_INL_H_ #define MXNET_OPERATOR_CONTRIB_DEFORMABLE_PSROI_POOLING_INL_H_ @@ -37,16 +37,15 @@ #include "../mshadow_op.h" #include "../operator_common.h" - namespace mxnet { namespace op { - // Declare enumeration of input order to make code more intuitive. - // These enums are only visible within this header +// Declare enumeration of input order to make code more intuitive. +// These enums are only visible within this header namespace deformablepsroipool { - enum DeformablePSROIPoolingOpInputs { kData, kBox, kTrans }; - enum DeformablePSROIPoolingOpOutputs { kOut, kTopCount }; -} // deformablepsroipool +enum DeformablePSROIPoolingOpInputs { kData, kBox, kTrans }; +enum DeformablePSROIPoolingOpOutputs { kOut, kTopCount }; +} // namespace deformablepsroipool struct DeformablePSROIPoolingParam : public dmlc::Parameter { // mxnet::TShape pooled_size; @@ -59,35 +58,36 @@ struct DeformablePSROIPoolingParam : public dmlc::Parameter +template class DeformablePSROIPoolingOp : public Operator { public: explicit DeformablePSROIPoolingOp(DeformablePSROIPoolingParam p) { this->param_ = p; } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + virtual void Forward(const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data, + const std::vector& aux_args) { using namespace mshadow; - size_t in_expected = param_.no_trans? 2 : 3; + size_t in_expected = param_.no_trans ? 
2 : 3; size_t out_expected = 2; CHECK_EQ(in_data.size(), in_expected); CHECK_EQ(out_data.size(), out_expected); @@ -95,38 +95,48 @@ class DeformablePSROIPoolingOp : public Operator { in_data[deformablepsroipool::kBox].shape_[0]); CHECK_EQ(out_data[deformablepsroipool::kTopCount].shape_[0], in_data[deformablepsroipool::kBox].shape_[0]); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); Tensor data = in_data[deformablepsroipool::kData].get(s); Tensor bbox = in_data[deformablepsroipool::kBox].get(s); - Tensor out = out_data[deformablepsroipool::kOut].get(s); - Tensor top_count = out_data[deformablepsroipool::kTopCount] - .get(s); + Tensor out = out_data[deformablepsroipool::kOut].get(s); + Tensor top_count = + out_data[deformablepsroipool::kTopCount].get(s); CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(bbox.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); CHECK_EQ(top_count.CheckContiguous(), true); - out = -FLT_MAX; + out = -FLT_MAX; top_count = 0.0f; Tensor trans{nullptr, mshadow::Shape4(0, 0, 0, 0)}; if (!param_.no_trans) { trans = in_data[deformablepsroipool::kTrans].get(s); } - DeformablePSROIPoolForward(out, data, bbox, trans, top_count, param_.no_trans, - param_.spatial_scale, param_.output_dim, param_.group_size, param_.pooled_size, - param_.part_size, param_.sample_per_part, param_.trans_std); + DeformablePSROIPoolForward(out, + data, + bbox, + trans, + top_count, + param_.no_trans, + param_.spatial_scale, + param_.output_dim, + param_.group_size, + param_.pooled_size, + param_.part_size, + param_.sample_per_part, + param_.trans_std); } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + virtual void Backward(const OpContext& ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { using namespace mshadow; - size_t in_expected = param_.no_trans ? 2 : 3; + size_t in_expected = param_.no_trans ? 
2 : 3; size_t out_expected = 2; CHECK_EQ(in_data.size(), in_expected); CHECK_EQ(out_data.size(), out_expected); @@ -134,26 +144,26 @@ class DeformablePSROIPoolingOp : public Operator { in_data[deformablepsroipool::kBox].shape_[0]); CHECK_EQ(out_data[deformablepsroipool::kTopCount].shape_[0], in_data[deformablepsroipool::kBox].shape_[0]); - CHECK_NE(req[deformablepsroipool::kData], kWriteInplace) << - "DeformablePSROIPooling: Backward doesn't support kWriteInplace."; - CHECK_NE(req[deformablepsroipool::kBox], kWriteInplace) << - "DeformablePSROIPooling: Backward doesn't support kWriteInplace."; + CHECK_NE(req[deformablepsroipool::kData], kWriteInplace) + << "DeformablePSROIPooling: Backward doesn't support kWriteInplace."; + CHECK_NE(req[deformablepsroipool::kBox], kWriteInplace) + << "DeformablePSROIPooling: Backward doesn't support kWriteInplace."; // CHECK_NE(req[deformablepsroipool::kTrans], kWriteInplace) << // "DeformablePSROIPooling: Backward doesn't support kWriteInplace."; - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); Tensor grad_out = out_grad[deformablepsroipool::kOut].get(s); - Tensor data = in_data[deformablepsroipool::kData].get(s); - Tensor bbox = in_data[deformablepsroipool::kBox].get(s); - Tensor top_count = out_data[deformablepsroipool::kTopCount] - .get(s); - Tensor grad_in = in_grad[deformablepsroipool::kData].get(s); + Tensor data = in_data[deformablepsroipool::kData].get(s); + Tensor bbox = in_data[deformablepsroipool::kBox].get(s); + Tensor top_count = + out_data[deformablepsroipool::kTopCount].get(s); + Tensor grad_in = in_grad[deformablepsroipool::kData].get(s); Tensor grad_roi = in_grad[deformablepsroipool::kBox].get(s); Tensor grad_trans{nullptr, mshadow::Shape4(0, 0, 0, 0)}; Tensor trans{nullptr, mshadow::Shape4(0, 0, 0, 0)}; if (!param_.no_trans) { CHECK_EQ(in_grad.size(), 3); - trans = in_data[deformablepsroipool::kTrans].get(s); + trans = in_data[deformablepsroipool::kTrans].get(s); grad_trans = in_grad[deformablepsroipool::kTrans].get(s); } @@ -167,9 +177,21 @@ class DeformablePSROIPoolingOp : public Operator { if (!param_.no_trans) { Assign(grad_trans, req[deformablepsroipool::kTrans], 0); } - DeformablePSROIPoolBackwardAcc(grad_in, grad_trans, grad_out, data, bbox, trans, - top_count, param_.no_trans, param_.spatial_scale, param_.output_dim, param_.group_size, - param_.pooled_size, param_.part_size, param_.sample_per_part, param_.trans_std); + DeformablePSROIPoolBackwardAcc(grad_in, + grad_trans, + grad_out, + data, + bbox, + trans, + top_count, + param_.no_trans, + param_.spatial_scale, + param_.output_dim, + param_.group_size, + param_.pooled_size, + param_.part_size, + param_.sample_per_part, + param_.trans_std); Assign(grad_roi, req[deformablepsroipool::kBox], 0); } @@ -178,7 +200,7 @@ class DeformablePSROIPoolingOp : public Operator { }; // class DeformablePSROIPoolingOp // Decalre Factory function, used for dispatch specialization -template +template Operator* CreateOp(DeformablePSROIPoolingParam param, int dtype); #if DMLC_USE_CXX11 @@ -186,14 +208,14 @@ class DeformablePSROIPoolingProp : public OperatorProperty { public: std::vector ListArguments() const override { if (param_.no_trans) { - return{ "data", "rois" }; + return {"data", "rois"}; } else { - return{ "data", "rois", "trans" }; + return {"data", "rois", "trans"}; } } std::vector ListOutputs() const override { - return{ "output", "top_count" }; + return {"output", "top_count"}; } int NumOutputs() const override { @@ -215,9 +237,9 @@ class DeformablePSROIPoolingProp : public 
OperatorProperty { return param_.__DICT__(); } - bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + bool InferShape(mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape, + mxnet::ShapeVector* aux_shape) const override { using namespace mshadow; if (param_.no_trans) { CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]"; @@ -241,15 +263,15 @@ class DeformablePSROIPoolingProp : public OperatorProperty { // top_count: [num_rois, c, pooled_h, pooled_w] out_shape->clear(); out_shape->push_back( - Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size)); + Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size)); out_shape->push_back( - Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size)); + Shape4(bshape[0], param_.output_dim, param_.pooled_size, param_.pooled_size)); return true; } - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { + bool InferType(std::vector* in_type, + std::vector* out_type, + std::vector* aux_type) const override { CHECK_GE(in_type->size(), 2); int dtype = (*in_type)[0]; CHECK_EQ(dtype, (*in_type)[1]); @@ -263,7 +285,7 @@ class DeformablePSROIPoolingProp : public OperatorProperty { OperatorProperty* Copy() const override { DeformablePSROIPoolingProp* deformable_psroi_pooling_sym = new DeformablePSROIPoolingProp(); - deformable_psroi_pooling_sym->param_ = this->param_; + deformable_psroi_pooling_sym->param_ = this->param_; return deformable_psroi_pooling_sym; } @@ -272,29 +294,31 @@ class DeformablePSROIPoolingProp : public OperatorProperty { } // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency(const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { + std::vector DeclareBackwardDependency(const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data) const override { if (param_.no_trans) { - return{ out_grad[deformablepsroipool::kOut], in_data[deformablepsroipool::kData], - in_data[deformablepsroipool::kBox], out_data[deformablepsroipool::kTopCount] }; + return {out_grad[deformablepsroipool::kOut], + in_data[deformablepsroipool::kData], + in_data[deformablepsroipool::kBox], + out_data[deformablepsroipool::kTopCount]}; } else { - return{ out_grad[deformablepsroipool::kOut], in_data[deformablepsroipool::kData], - in_data[deformablepsroipool::kBox], in_data[deformablepsroipool::kTrans], - out_data[deformablepsroipool::kTopCount] }; + return {out_grad[deformablepsroipool::kOut], + in_data[deformablepsroipool::kData], + in_data[deformablepsroipool::kBox], + in_data[deformablepsroipool::kTrans], + out_data[deformablepsroipool::kTopCount]}; } } - Operator* CreateOperator(Context ctx) const override { LOG(FATAL) << "Not Implemented."; return nullptr; } Operator* CreateOperatorEx(Context ctx, - mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; - + mxnet::ShapeVector* in_shape, + std::vector* in_type) const override; private: DeformablePSROIPoolingParam param_; diff --git a/src/operator/contrib/deformable_psroi_pooling.cc b/src/operator/contrib/deformable_psroi_pooling.cc index 697376dd573f..77c57c752ec1 100644 --- a/src/operator/contrib/deformable_psroi_pooling.cc +++ b/src/operator/contrib/deformable_psroi_pooling.cc @@ -23,7 +23,7 @@ * \file deformable_psroi_pooling.cc * \brief * \author Yi Li, Guodong Zhang, Jifeng Dai -*/ + */ 
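// ---------------------------------------------------------------------------
// Reviewer note (not part of the patch): the CPU path below relies on
// bilinear_interp_cpu, which blends the four texels around a fractional (x, y)
// with weights (1 - dx)(1 - dy), (1 - dx)dy, dx(1 - dy) and dx*dy. A minimal
// standalone sketch of the same formula, assuming a row-major map of the given
// width and coordinates already clamped to the valid range (as the callers do);
// the function name is illustrative only:
#include <cmath>

template <typename DType>
DType bilinear_sample(const DType* data, DType x, DType y, int width) {
  const int x1 = static_cast<int>(std::floor(x));
  const int x2 = static_cast<int>(std::ceil(x));
  const int y1 = static_cast<int>(std::floor(y));
  const int y2 = static_cast<int>(std::ceil(y));
  const DType dx = x - x1;  // fractional offset inside the cell, in [0, 1)
  const DType dy = y - y1;
  return (1 - dx) * (1 - dy) * data[y1 * width + x1]  // (x1, y1)
         + (1 - dx) * dy * data[y2 * width + x1]      // (x1, y2)
         + dx * (1 - dy) * data[y1 * width + x2]      // (x2, y1)
         + dx * dy * data[y2 * width + x2];           // (x2, y2)
}
// ---------------------------------------------------------------------------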
#include "./deformable_psroi_pooling-inl.h" #include #include @@ -31,350 +31,413 @@ #include #include +using std::ceil; +using std::floor; using std::max; using std::min; -using std::floor; -using std::ceil; using std::round; namespace mshadow { - template - inline DType bilinear_interp_cpu(const DType* data, - const DType x, const DType y, - const index_t width, const index_t height) { - index_t x1 = floor(x); - index_t x2 = ceil(x); - index_t y1 = floor(y); - index_t y2 = ceil(y); - DType dist_x = static_cast(x - x1); - DType dist_y = static_cast(y - y1); - DType value11 = data[y1 * width + x1]; - DType value12 = data[y2 * width + x1]; - DType value21 = data[y1 * width + x2]; - DType value22 = data[y2 * width + x2]; - DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + - dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; - return value; - } - - template - inline void DeformablePSROIPoolForwardCPU(const index_t count, const DType* bottom_data, - const DType spatial_scale, const index_t channels, - const index_t height, const index_t width, - const index_t pooled_height, const index_t pooled_width, - const DType* bottom_rois, const DType* bottom_trans, - const bool no_trans, const DType trans_std, - const index_t sample_per_part, const index_t output_dim, - const index_t group_size, const index_t part_size, - const index_t num_classes, - const index_t channels_each_class, - DType* top_data, DType* top_count) { - const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); +template +inline DType bilinear_interp_cpu(const DType* data, + const DType x, + const DType y, + const index_t width, + const index_t height) { + index_t x1 = floor(x); + index_t x2 = ceil(x); + index_t y1 = floor(y); + index_t y2 = ceil(y); + DType dist_x = static_cast(x - x1); + DType dist_y = static_cast(y - y1); + DType value11 = data[y1 * width + x1]; + DType value12 = data[y2 * width + x1]; + DType value21 = data[y1 * width + x2]; + DType value22 = data[y2 * width + x2]; + DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +template +inline void DeformablePSROIPoolForwardCPU(const index_t count, + const DType* bottom_data, + const DType spatial_scale, + const index_t channels, + const index_t height, + const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, + const DType trans_std, + const index_t sample_per_part, + const index_t output_dim, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class, + DType* top_data, + DType* top_count) { + const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); #pragma omp parallel for num_threads(omp_threads) - for (index_t index = 0; index < count; index++) { - // The output is in order (n, ctop, ph, pw) - index_t pw = index % pooled_width; - index_t ph = (index / pooled_width) % pooled_height; - index_t ctop = (index / pooled_width / pooled_height) % output_dim; - index_t n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const DType* offset_bottom_rois = bottom_rois + n * 5; - index_t roi_batch_ind = offset_bottom_rois[0]; - DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - DType roi_start_h = 
static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 - DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); - - // Compute w and h at bottom - DType bin_size_h = roi_height / static_cast(pooled_height); - DType bin_size_w = roi_width / static_cast(pooled_width); - - DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); - DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - - index_t part_h = floor(static_cast(ph) / pooled_height * part_size); - index_t part_w = floor(static_cast(pw) / pooled_width * part_size); - index_t class_id = ctop / channels_each_class; - DType trans_x = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2) - * part_size + part_h) - * part_size + part_w] * trans_std; - DType trans_y = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2 + 1) - * part_size + part_h) - * part_size + part_w] * trans_std; - - DType wstart = static_cast(pw) * bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - DType sum = 0; - index_t count = 0; - index_t gw = floor(static_cast(pw) * group_size / pooled_width); - index_t gh = floor(static_cast(ph) * group_size / pooled_height); - gw = min(max(gw, static_cast(0)), group_size - 1); - gh = min(max(gh, static_cast(0)), group_size - 1); - - const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; - for (index_t ih = 0; ih < sample_per_part; ih++) { - for (index_t iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw * sub_bin_size_w; - DType h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { - continue; - } - w = min(max(w, static_cast(0)), static_cast(width - 1)); - h = min(max(h, static_cast(0)), static_cast(height - 1)); - index_t c = (ctop * group_size + gh) * group_size + gw; - DType val = bilinear_interp_cpu(offset_bottom_data + c * height * width, - w, h, width, height); - sum += val; - count++; + for (index_t index = 0; index < count; index++) { + // The output is in order (n, ctop, ph, pw) + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const DType* offset_bottom_rois = bottom_rois + n * 5; + index_t roi_batch_ind = offset_bottom_rois[0]; + DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 + DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); + + // Compute w and h at bottom + DType bin_size_h = roi_height / static_cast(pooled_height); + DType bin_size_w = roi_width / static_cast(pooled_width); + + DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; + DType trans_x = + no_trans + ? static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + + part_w] * + trans_std; + DType trans_y = + no_trans ? static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * + part_size + + part_w] * + trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + DType sum = 0; + index_t count = 0; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; } + w = min(max(w, static_cast(0)), static_cast(width - 1)); + h = min(max(h, static_cast(0)), static_cast(height - 1)); + index_t c = (ctop * group_size + gh) * group_size + gw; + DType val = + bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; } - top_data[index] = count == 0 ? static_cast(0) : sum / count; - top_count[index] = count; } + top_data[index] = count == 0 ? static_cast(0) : sum / count; + top_count[index] = count; } - - template - inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - const DType *bottom_data = data.dptr_; - const DType *bottom_rois = bbox.dptr_; - const DType *bottom_trans = no_trans ? nullptr : trans.dptr_; - DType *top_data = out.dptr_; - DType *top_count_data = top_count.dptr_; - const index_t count = out.shape_.Size(); - const index_t channels = data.size(1); - const index_t height = data.size(2); - const index_t width = data.size(3); - const index_t pooled_height = pooled_size; - const index_t pooled_width = pooled_size; - const index_t num_classes = no_trans ? 1 : trans.size(1) / 2; - const index_t channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; - DeformablePSROIPoolForwardCPU(count, bottom_data, spatial_scale, - channels, height, width, - pooled_height, pooled_width, - bottom_rois, bottom_trans, - no_trans, trans_std, sample_per_part, - output_dim, group_size, part_size, num_classes, - channels_each_class, top_data, top_count_data); - - return; - } - - template - inline void DeformablePSROIPoolBackwardAccCPU(const index_t count, const DType* top_diff, - const DType* top_count, const index_t num_rois, - const DType spatial_scale, const index_t channels, - const index_t height, const index_t width, - const index_t pooled_height, - const index_t pooled_width, - const index_t output_dim, - DType* bottom_data_diff, - DType* bottom_trans_diff, - const DType* bottom_data, - const DType* bottom_rois, - const DType* bottom_trans, - const bool no_trans, - const DType trans_std, - const index_t sample_per_part, - const index_t group_size, - const index_t part_size, - const index_t num_classes, - const index_t channels_each_class) { - for (index_t index = 0; index < count; index++) { - // The output is in order (n, ctop, ph, pw) - index_t pw = index % pooled_width; - index_t ph = (index / pooled_width) % pooled_height; - index_t ctop = (index / pooled_width / pooled_height) % output_dim; - index_t n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const DType* offset_bottom_rois = bottom_rois + n * 5; - index_t roi_batch_ind = offset_bottom_rois[0]; - DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 - DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); - - // Compute w and h at bottom - DType bin_size_h = roi_height / static_cast(pooled_height); - DType bin_size_w = roi_width / static_cast(pooled_width); - - DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); - DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - - index_t part_h = floor(static_cast(ph) / pooled_height * part_size); - index_t part_w = floor(static_cast(pw) / pooled_width * part_size); - index_t class_id = ctop / channels_each_class; - DType trans_x = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2) - * part_size + part_h) - * part_size + part_w] * trans_std; - DType trans_y = no_trans ? 
static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2 + 1) - * part_size + part_h) - * part_size + part_w] * trans_std; - - DType wstart = static_cast(pw) * bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - if (top_count[index] <= 0) { - continue; - } - DType diff_val = top_diff[index] / top_count[index]; - const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; - DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; - index_t gw = floor(static_cast(pw)* group_size / pooled_width); - index_t gh = floor(static_cast(ph)* group_size / pooled_height); - gw = min(max(gw, static_cast(0)), group_size - 1); - gh = min(max(gh, static_cast(0)), group_size - 1); - - for (index_t ih = 0; ih < sample_per_part; ih++) { - for (index_t iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw * sub_bin_size_w; - DType h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { - continue; - } - w = min(max(w, static_cast(0)), static_cast(width - 1)); - h = min(max(h, static_cast(0)), static_cast(height - 1)); - index_t c = (ctop * group_size + gh) * group_size + gw; - // backward on feature - index_t x0 = floor(w); - index_t x1 = ceil(w); - index_t y0 = floor(h); - index_t y1 = ceil(h); - DType dist_x = w - x0, dist_y = h - y0; - DType q00 = (1 - dist_x) * (1 - dist_y); - DType q01 = (1 - dist_x) * dist_y; - DType q10 = dist_x * (1 - dist_y); - DType q11 = dist_x * dist_y; - index_t bottom_index_base = c * height * width; - offset_bottom_data_diff[bottom_index_base + y0 * width + x0] += q00 * diff_val; - offset_bottom_data_diff[bottom_index_base + y1 * width + x0] += q01 * diff_val; - offset_bottom_data_diff[bottom_index_base + y0 * width + x1] += q10 * diff_val; - offset_bottom_data_diff[bottom_index_base + y1 * width + x1] += q11 * diff_val; - - if (no_trans) { - continue; - } - DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; - DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; - DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; - DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; - DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); - diff_x *= trans_std * diff_val * roi_width; - DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); - diff_y *= trans_std * diff_val * roi_height; - - index_t offset_trans_diff = (((n * num_classes + class_id) * 2) - * part_size + part_h) * part_size + part_w; - bottom_trans_diff[offset_trans_diff] += diff_x; - bottom_trans_diff[offset_trans_diff + part_size * part_size] += diff_y; +} + +template +inline void DeformablePSROIPoolForward(const Tensor& out, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + const DType* bottom_data = data.dptr_; + const DType* bottom_rois = bbox.dptr_; + const DType* bottom_trans = no_trans ? 
nullptr : trans.dptr_; + DType* top_data = out.dptr_; + DType* top_count_data = top_count.dptr_; + const index_t count = out.shape_.Size(); + const index_t channels = data.size(1); + const index_t height = data.size(2); + const index_t width = data.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans.size(1) / 2; + const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; + DeformablePSROIPoolForwardCPU(count, + bottom_data, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + bottom_rois, + bottom_trans, + no_trans, + trans_std, + sample_per_part, + output_dim, + group_size, + part_size, + num_classes, + channels_each_class, + top_data, + top_count_data); + + return; +} + +template +inline void DeformablePSROIPoolBackwardAccCPU(const index_t count, + const DType* top_diff, + const DType* top_count, + const index_t num_rois, + const DType spatial_scale, + const index_t channels, + const index_t height, + const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const index_t output_dim, + DType* bottom_data_diff, + DType* bottom_trans_diff, + const DType* bottom_data, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, + const DType trans_std, + const index_t sample_per_part, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class) { + for (index_t index = 0; index < count; index++) { + // The output is in order (n, ctop, ph, pw) + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop = (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const DType* offset_bottom_rois = bottom_rois + n * 5; + index_t roi_batch_ind = offset_bottom_rois[0]; + DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + DType roi_width = max(roi_end_w - roi_start_w, static_cast(0.1)); // avoid 0 + DType roi_height = max(roi_end_h - roi_start_h, static_cast(0.1)); + + // Compute w and h at bottom + DType bin_size_h = roi_height / static_cast(pooled_height); + DType bin_size_w = roi_width / static_cast(pooled_width); + + DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; + DType trans_x = + no_trans + ? static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + + part_w] * + trans_std; + DType trans_y = + no_trans ? 
static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * + part_size + + part_w] * + trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) { + continue; + } + DType diff_val = top_diff[index] / top_count[index]; + const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; + } + w = min(max(w, static_cast(0)), static_cast(width - 1)); + h = min(max(h, static_cast(0)), static_cast(height - 1)); + index_t c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + index_t x0 = floor(w); + index_t x1 = ceil(w); + index_t y0 = floor(h); + index_t y1 = ceil(h); + DType dist_x = w - x0, dist_y = h - y0; + DType q00 = (1 - dist_x) * (1 - dist_y); + DType q01 = (1 - dist_x) * dist_y; + DType q10 = dist_x * (1 - dist_y); + DType q11 = dist_x * dist_y; + index_t bottom_index_base = c * height * width; + offset_bottom_data_diff[bottom_index_base + y0 * width + x0] += q00 * diff_val; + offset_bottom_data_diff[bottom_index_base + y1 * width + x0] += q01 * diff_val; + offset_bottom_data_diff[bottom_index_base + y0 * width + x1] += q10 * diff_val; + offset_bottom_data_diff[bottom_index_base + y1 * width + x1] += q11 * diff_val; + + if (no_trans) { + continue; } + DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); + diff_x *= trans_std * diff_val * roi_width; + DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); + diff_y *= trans_std * diff_val * roi_height; + + index_t offset_trans_diff = + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w; + bottom_trans_diff[offset_trans_diff] += diff_x; + bottom_trans_diff[offset_trans_diff + part_size * part_size] += diff_y; } } } - - template - inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - const DType *top_diff = out_grad.dptr_; - const DType *bottom_data = data.dptr_; - const DType *bottom_rois = bbox.dptr_; - const DType *bottom_trans = no_trans ? 
nullptr : trans.dptr_; - DType *bottom_data_diff = in_grad.dptr_; - DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_; - const DType *top_count_data = top_count.dptr_; - const index_t count = out_grad.shape_.Size(); - const index_t num_rois = bbox.size(0); - const index_t channels = in_grad.size(1); - const index_t height = in_grad.size(2); - const index_t width = in_grad.size(3); - const index_t pooled_height = pooled_size; - const index_t pooled_width = pooled_size; - const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; - const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; - DeformablePSROIPoolBackwardAccCPU(count, top_diff, top_count_data, num_rois, - spatial_scale, channels, height, width, - pooled_height, pooled_width, output_dim, - bottom_data_diff, bottom_trans_diff, - bottom_data, bottom_rois, bottom_trans, - no_trans, trans_std, sample_per_part, - group_size, part_size, num_classes, - channels_each_class); - - return; - } +} + +template +inline void DeformablePSROIPoolBackwardAcc(const Tensor& in_grad, + const Tensor& trans_grad, + const Tensor& out_grad, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + const DType* top_diff = out_grad.dptr_; + const DType* bottom_data = data.dptr_; + const DType* bottom_rois = bbox.dptr_; + const DType* bottom_trans = no_trans ? nullptr : trans.dptr_; + DType* bottom_data_diff = in_grad.dptr_; + DType* bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_; + const DType* top_count_data = top_count.dptr_; + const index_t count = out_grad.shape_.Size(); + const index_t num_rois = bbox.size(0); + const index_t channels = in_grad.size(1); + const index_t height = in_grad.size(2); + const index_t width = in_grad.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; + const index_t channels_each_class = no_trans ? 
output_dim : output_dim / num_classes;
+  DeformablePSROIPoolBackwardAccCPU(count,
+                                    top_diff,
+                                    top_count_data,
+                                    num_rois,
+                                    spatial_scale,
+                                    channels,
+                                    height,
+                                    width,
+                                    pooled_height,
+                                    pooled_width,
+                                    output_dim,
+                                    bottom_data_diff,
+                                    bottom_trans_diff,
+                                    bottom_data,
+                                    bottom_rois,
+                                    bottom_trans,
+                                    no_trans,
+                                    trans_std,
+                                    sample_per_part,
+                                    group_size,
+                                    part_size,
+                                    num_classes,
+                                    channels_each_class);
+
+  return;
+}
 }  // namespace mshadow
 namespace mxnet {
 namespace op {
-  template<>
-  Operator *CreateOp<cpu>(DeformablePSROIPoolingParam param, int dtype) {
-    Operator* op = nullptr;
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new DeformablePSROIPoolingOp<DType>(param);
-    });
-    return op;
-  }
-
-  Operator *DeformablePSROIPoolingProp::CreateOperatorEx(Context ctx,
-                                                         mxnet::ShapeVector *in_shape,
-                                                         std::vector<int> *in_type) const {
-    mxnet::ShapeVector out_shape, aux_shape;
-    std::vector<int> out_type, aux_type;
-    CHECK(InferType(in_type, &out_type, &aux_type));
-    CHECK(InferShape(in_shape, &out_shape, &aux_shape));
-    DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
-  }
-
-  DMLC_REGISTER_PARAMETER(DeformablePSROIPoolingParam);
-
-  MXNET_REGISTER_OP_PROPERTY(_contrib_DeformablePSROIPooling, DeformablePSROIPoolingProp)
-  .describe("Performs deformable position-sensitive region-of-interest pooling on inputs.\n"
-            "The DeformablePSROIPooling operation is described in https://arxiv.org/abs/1703.06211 ."
-            "batch_size will change to the number of region bounding boxes after DeformablePSROIPooling")
+template <>
+Operator* CreateOp<cpu>(DeformablePSROIPoolingParam param, int dtype) {
+  Operator* op = nullptr;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new DeformablePSROIPoolingOp<DType>(param); });
+  return op;
+}
+
+Operator* DeformablePSROIPoolingProp::CreateOperatorEx(Context ctx,
+                                                       mxnet::ShapeVector* in_shape,
+                                                       std::vector<int>* in_type) const {
+  mxnet::ShapeVector out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
+}
+
+DMLC_REGISTER_PARAMETER(DeformablePSROIPoolingParam);
+
+MXNET_REGISTER_OP_PROPERTY(_contrib_DeformablePSROIPooling, DeformablePSROIPoolingProp)
+    .describe(
+        "Performs deformable position-sensitive region-of-interest pooling on inputs.\n"
+        "The DeformablePSROIPooling operation is described in https://arxiv.org/abs/1703.06211 ."
+        "batch_size will change to the number of region bounding boxes after "
+        "DeformablePSROIPooling")
     .add_argument("data", "Symbol", "Input data to the pooling operator, a 4D Feature maps")
-  .add_argument("rois", "Symbol", "Bounding box coordinates, a 2D array of "
-                "[[batch_index, x1, y1, x2, y2]]. (x1, y1) and (x2, y2) are top left and down right corners "
-                "of designated region of interest. batch_index indicates the index of corresponding image "
-                "in the input data")
+    .add_argument(
+        "rois",
+        "Symbol",
+        "Bounding box coordinates, a 2D array of "
+        "[[batch_index, x1, y1, x2, y2]]. (x1, y1) and (x2, y2) are top left and down right "
+        "corners "
+        "of designated region of interest. 
batch_index indicates the index of corresponding image " + "in the input data") .add_argument("trans", "Symbol", "transition parameter") .add_arguments(DeformablePSROIPoolingParam::__FIELDS__()); } // namespace op diff --git a/src/operator/contrib/deformable_psroi_pooling.cu b/src/operator/contrib/deformable_psroi_pooling.cu index 2206b5aa67b3..b1888ad6098f 100644 --- a/src/operator/contrib/deformable_psroi_pooling.cu +++ b/src/operator/contrib/deformable_psroi_pooling.cu @@ -23,7 +23,7 @@ * \file deformable_psroi_pooling.cu * \brief * \author Yi Li, Guodong Zhang, Jifeng Dai -*/ + */ #include "./deformable_psroi_pooling-inl.h" #include #include @@ -32,381 +32,458 @@ #include "../../common/cuda/utils.h" #include "../mxnet_op.h" -#define DeformablePSROIPOOLING_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ +#define DeformablePSROIPOOLING_CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ } while (0) namespace mshadow { namespace cuda { - template - __device__ DType bilinear_interp(const DType* data, - const DType x, const DType y, - const index_t width, const index_t height) { - index_t x1 = floor(x); - index_t x2 = ceil(x); - index_t y1 = floor(y); - index_t y2 = ceil(y); - DType dist_x = static_cast(x - x1); - DType dist_y = static_cast(y - y1); - DType value11 = data[y1 * width + x1]; - DType value12 = data[y2 * width + x1]; - DType value21 = data[y1 * width + x2]; - DType value22 = data[y2 * width + x2]; - DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + - dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; - return value; - } - - template - __global__ void DeformablePSROIPoolForwardKernel(const index_t count, - const DType* bottom_data, - const DType spatial_scale, - const index_t channels, - const index_t height, const index_t width, - const index_t pooled_height, - const index_t pooled_width, - const DType* bottom_rois, - const DType* bottom_trans, - const bool no_trans, const DType trans_std, - const index_t sample_per_part, - const index_t output_dim, - const index_t group_size, - const index_t part_size, - const index_t num_classes, - const index_t channels_each_class, - DType* top_data, DType* top_count) { - CUDA_KERNEL_LOOP(index, count) { - // The output is in order (n, ctop, ph, pw) - index_t pw = index % pooled_width; - index_t ph = (index / pooled_width) % pooled_height; - index_t ctop = (index / pooled_width / pooled_height) % output_dim; - index_t n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const DType* offset_bottom_rois = bottom_rois + n * 5; - index_t roi_batch_ind = offset_bottom_rois[0]; - DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) 
* spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - DType roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0 - DType roi_height = max(roi_end_h - roi_start_h, 0.1); - - // Compute w and h at bottom - DType bin_size_h = roi_height / static_cast(pooled_height); - DType bin_size_w = roi_width / static_cast(pooled_width); - - DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); - DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - - index_t part_h = floor(static_cast(ph) / pooled_height * part_size); - index_t part_w = floor(static_cast(pw) / pooled_width * part_size); - index_t class_id = ctop / channels_each_class; - DType trans_x = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2) - * part_size + part_h) - * part_size + part_w] * trans_std; - DType trans_y = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2 + 1) - * part_size + part_h) - * part_size + part_w] * trans_std; - - DType wstart = static_cast(pw) * bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - DType sum = 0; - index_t count = 0; - index_t gw = floor(static_cast(pw) * group_size / pooled_width); - index_t gh = floor(static_cast(ph) * group_size / pooled_height); - gw = min(max(gw, static_cast(0)), group_size - 1); - gh = min(max(gh, static_cast(0)), group_size - 1); - - const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; - for (index_t ih = 0; ih < sample_per_part; ih++) { - for (index_t iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw * sub_bin_size_w; - DType h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { - continue; - } - w = min(max(w, 0.), width - 1.); - h = min(max(h, 0.), height - 1.); - index_t c = (ctop * group_size + gh) * group_size + gw; - DType val = bilinear_interp(offset_bottom_data + c * height * width, - w, h, width, height); - sum += val; - count++; +template +__device__ DType bilinear_interp(const DType* data, + const DType x, + const DType y, + const index_t width, + const index_t height) { + index_t x1 = floor(x); + index_t x2 = ceil(x); + index_t y1 = floor(y); + index_t y2 = ceil(y); + DType dist_x = static_cast(x - x1); + DType dist_y = static_cast(y - y1); + DType value11 = data[y1 * width + x1]; + DType value12 = data[y2 * width + x1]; + DType value21 = data[y1 * width + x2]; + DType value22 = data[y2 * width + x2]; + DType value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +template +__global__ void DeformablePSROIPoolForwardKernel(const index_t count, + const DType* bottom_data, + const DType spatial_scale, + const index_t channels, + const index_t height, + const index_t width, + const index_t pooled_height, + const index_t pooled_width, + const DType* bottom_rois, + const DType* bottom_trans, + const bool no_trans, + const DType trans_std, + const index_t sample_per_part, + const index_t output_dim, + const index_t group_size, + const index_t part_size, + const index_t num_classes, + const index_t channels_each_class, + DType* top_data, + DType* top_count) { + CUDA_KERNEL_LOOP(index, count) { + // The output is in order (n, ctop, ph, pw) + index_t pw = index % pooled_width; + index_t ph = (index / pooled_width) % pooled_height; + index_t ctop 
= (index / pooled_width / pooled_height) % output_dim; + index_t n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const DType* offset_bottom_rois = bottom_rois + n * 5; + index_t roi_batch_ind = offset_bottom_rois[0]; + DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + DType roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0 + DType roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + DType bin_size_h = roi_height / static_cast(pooled_height); + DType bin_size_w = roi_width / static_cast(pooled_width); + + DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + index_t part_h = floor(static_cast(ph) / pooled_height * part_size); + index_t part_w = floor(static_cast(pw) / pooled_width * part_size); + index_t class_id = ctop / channels_each_class; + DType trans_x = + no_trans + ? static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + + part_w] * + trans_std; + DType trans_y = + no_trans ? static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * + part_size + + part_w] * + trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + DType sum = 0; + index_t count = 0; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + const DType* offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + index_t c = (ctop * group_size + gh) * group_size + gw; + DType val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; } - top_data[index] = count == 0 ? static_cast(0) : sum / count; - top_count[index] = count; } + top_data[index] = count == 0 ? static_cast(0) : sum / count; + top_count[index] = count; } - - template - inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - const DType *bottom_data = data.dptr_; - const DType *bottom_rois = bbox.dptr_; - const DType *bottom_trans = no_trans ? 
nullptr : trans.dptr_; - DType *top_data = out.dptr_; - DType *top_count_data = top_count.dptr_; - const index_t count = out.shape_.Size(); - const index_t channels = data.size(1); - const index_t height = data.size(2); - const index_t width = data.size(3); - const index_t pooled_height = pooled_size; - const index_t pooled_width = pooled_size; - const index_t num_classes = no_trans ? 1 : trans.size(1) / 2; - const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; - - cudaStream_t stream = Stream::GetStream(out.stream_); - DeformablePSROIPoolForwardKernel<<< - mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum, - 0, stream>>>(count, bottom_data, spatial_scale, channels, height, width, - pooled_height, pooled_width, bottom_rois, bottom_trans, - no_trans, trans_std, sample_per_part, output_dim, - group_size, part_size, num_classes, - channels_each_class, top_data, top_count_data); - DeformablePSROIPOOLING_CUDA_CHECK(cudaGetLastError()); - } - - - template - __global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count, - const DType* top_diff, - const DType* top_count, - const index_t num_rois, - const DType spatial_scale, - const index_t channels, - const index_t height, - const index_t width, - const index_t pooled_height, - const index_t pooled_width, - const index_t output_dim, - DType* bottom_data_diff, - DType* bottom_trans_diff, - const DType* bottom_data, - const DType* bottom_rois, - const DType* bottom_trans, - const bool no_trans, - const DType trans_std, - const index_t sample_per_part, - const index_t group_size, - const index_t part_size, - const index_t num_classes, - const index_t channels_each_class) { - CUDA_KERNEL_LOOP(index, count) { - // The output is in order (n, ctop, ph, pw) - index_t pw = index % pooled_width; - index_t ph = (index / pooled_width) % pooled_height; - index_t ctop = (index / pooled_width / pooled_height) % output_dim; - index_t n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const DType* offset_bottom_rois = bottom_rois + n * 5; - index_t roi_batch_ind = offset_bottom_rois[0]; - DType roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - DType roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - DType roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - DType roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - DType roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0 - DType roi_height = max(roi_end_h - roi_start_h, 0.1); - - // Compute w and h at bottom - DType bin_size_h = roi_height / static_cast(pooled_height); - DType bin_size_w = roi_width / static_cast(pooled_width); - - DType sub_bin_size_h = bin_size_h / static_cast(sample_per_part); - DType sub_bin_size_w = bin_size_w / static_cast(sample_per_part); - - index_t part_h = floor(static_cast(ph) / pooled_height * part_size); - index_t part_w = floor(static_cast(pw) / pooled_width * part_size); - index_t class_id = ctop / channels_each_class; - DType trans_x = no_trans ? static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2) - * part_size + part_h) - * part_size + part_w] * trans_std; - DType trans_y = no_trans ? 
static_cast(0) : - bottom_trans[(((n * num_classes + class_id) * 2 + 1) - * part_size + part_h) - * part_size + part_w] * trans_std; - - DType wstart = static_cast(pw) * bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - DType hstart = static_cast(ph) * bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - if (top_count[index] <= 0) { - continue; - } - DType diff_val = top_diff[index] / top_count[index]; - const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; - DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; - index_t gw = floor(static_cast(pw) * group_size / pooled_width); - index_t gh = floor(static_cast(ph) * group_size / pooled_height); - gw = min(max(gw, static_cast(0)), group_size - 1); - gh = min(max(gh, static_cast(0)), group_size - 1); - - for (index_t ih = 0; ih < sample_per_part; ih++) { - for (index_t iw = 0; iw < sample_per_part; iw++) { - DType w = wstart + iw * sub_bin_size_w; - DType h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { - continue; - } - w = min(max(w, 0.), width - 1.); - h = min(max(h, 0.), height - 1.); - index_t c = (ctop * group_size + gh) * group_size + gw; - // backward on feature - index_t x0 = floor(w); - index_t x1 = ceil(w); - index_t y0 = floor(h); - index_t y1 = ceil(h); - DType dist_x = w - x0, dist_y = h - y0; - DType q00 = (1 - dist_x) * (1 - dist_y); - DType q01 = (1 - dist_x) * dist_y; - DType q10 = dist_x * (1 - dist_y); - DType q11 = dist_x * dist_y; - index_t bottom_index_base = c * height * width; - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); - - if (no_trans) { - continue; - } - DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; - DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; - DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; - DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; - DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); - diff_x *= trans_std * diff_val * roi_width; - DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); - diff_y *= trans_std * diff_val * roi_height; - - atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) - * part_size + part_h) - * part_size + part_w, diff_x); - atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) - * part_size + part_h) - * part_size + part_w, diff_y); +} + +template +inline void DeformablePSROIPoolForward(const Tensor& out, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + const DType* bottom_data = data.dptr_; + const DType* bottom_rois = bbox.dptr_; + const DType* bottom_trans = no_trans ? 
nullptr : trans.dptr_;
+  DType* top_data = out.dptr_;
+  DType* top_count_data = top_count.dptr_;
+  const index_t count = out.shape_.Size();
+  const index_t channels = data.size(1);
+  const index_t height = data.size(2);
+  const index_t width = data.size(3);
+  const index_t pooled_height = pooled_size;
+  const index_t pooled_width = pooled_size;
+  const index_t num_classes = no_trans ? 1 : trans.size(1) / 2;
+  const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
+  DeformablePSROIPoolForwardKernel<DType>
+      <<<mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum, 0, stream>>>(
+          count,
+          bottom_data,
+          spatial_scale,
+          channels,
+          height,
+          width,
+          pooled_height,
+          pooled_width,
+          bottom_rois,
+          bottom_trans,
+          no_trans,
+          trans_std,
+          sample_per_part,
+          output_dim,
+          group_size,
+          part_size,
+          num_classes,
+          channels_each_class,
+          top_data,
+          top_count_data);
+  DeformablePSROIPOOLING_CUDA_CHECK(cudaGetLastError());
+}
+
+template <typename DType>
+__global__ void DeformablePSROIPoolBackwardAccKernel(const index_t count,
+                                                     const DType* top_diff,
+                                                     const DType* top_count,
+                                                     const index_t num_rois,
+                                                     const DType spatial_scale,
+                                                     const index_t channels,
+                                                     const index_t height,
+                                                     const index_t width,
+                                                     const index_t pooled_height,
+                                                     const index_t pooled_width,
+                                                     const index_t output_dim,
+                                                     DType* bottom_data_diff,
+                                                     DType* bottom_trans_diff,
+                                                     const DType* bottom_data,
+                                                     const DType* bottom_rois,
+                                                     const DType* bottom_trans,
+                                                     const bool no_trans,
+                                                     const DType trans_std,
+                                                     const index_t sample_per_part,
+                                                     const index_t group_size,
+                                                     const index_t part_size,
+                                                     const index_t num_classes,
+                                                     const index_t channels_each_class) {
+  CUDA_KERNEL_LOOP(index, count) {
+    // The output is in order (n, ctop, ph, pw)
+    index_t pw = index % pooled_width;
+    index_t ph = (index / pooled_width) % pooled_height;
+    index_t ctop = (index / pooled_width / pooled_height) % output_dim;
+    index_t n = index / pooled_width / pooled_height / output_dim;
+
+    // [start, end) interval for spatial sampling
+    const DType* offset_bottom_rois = bottom_rois + n * 5;
+    index_t roi_batch_ind = offset_bottom_rois[0];
+    DType roi_start_w = static_cast<DType>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+    DType roi_start_h = static_cast<DType>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+    DType roi_end_w = static_cast<DType>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+    DType roi_end_h = static_cast<DType>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+    // Force too small ROIs to be 1x1
+    DType roi_width = max(roi_end_w - roi_start_w, 0.1);  // avoid 0
+    DType roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+    // Compute w and h at bottom
+    DType bin_size_h = roi_height / static_cast<DType>(pooled_height);
+    DType bin_size_w = roi_width / static_cast<DType>(pooled_width);
+
+    DType sub_bin_size_h = bin_size_h / static_cast<DType>(sample_per_part);
+    DType sub_bin_size_w = bin_size_w / static_cast<DType>(sample_per_part);
+
+    index_t part_h = floor(static_cast<DType>(ph) / pooled_height * part_size);
+    index_t part_w = floor(static_cast<DType>(pw) / pooled_width * part_size);
+    index_t class_id = ctop / channels_each_class;
+    DType trans_x =
+        no_trans
+            ? static_cast<DType>(0)
+            : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size +
+                           part_w] *
+                  trans_std;
+    DType trans_y =
+        no_trans ? 
static_cast(0) + : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * + part_size + + part_w] * + trans_std; + + DType wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + DType hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) { + continue; + } + DType diff_val = top_diff[index] / top_count[index]; + const DType* offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + DType* offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + index_t gw = floor(static_cast(pw) * group_size / pooled_width); + index_t gh = floor(static_cast(ph) * group_size / pooled_height); + gw = min(max(gw, static_cast(0)), group_size - 1); + gh = min(max(gh, static_cast(0)), group_size - 1); + + for (index_t ih = 0; ih < sample_per_part; ih++) { + for (index_t iw = 0; iw < sample_per_part; iw++) { + DType w = wstart + iw * sub_bin_size_w; + DType h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) { + continue; } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + index_t c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + index_t x0 = floor(w); + index_t x1 = ceil(w); + index_t y0 = floor(h); + index_t y1 = ceil(h); + DType dist_x = w - x0, dist_y = h - y0; + DType q00 = (1 - dist_x) * (1 - dist_y); + DType q01 = (1 - dist_x) * dist_y; + DType q10 = dist_x * (1 - dist_y); + DType q11 = dist_x * dist_y; + index_t bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) { + continue; + } + DType U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + DType U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + DType U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + DType U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + DType diff_x = U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y); + diff_x *= trans_std * diff_val * roi_width; + DType diff_y = U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x); + diff_y *= trans_std * diff_val * roi_height; + + atomicAdd(bottom_trans_diff + + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + + part_w, + diff_x); + atomicAdd(bottom_trans_diff + + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + + part_w, + diff_y); } } } - - - template - inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - const DType *top_diff = out_grad.dptr_; - const DType *bottom_data = data.dptr_; - const DType *bottom_rois = bbox.dptr_; - const DType *bottom_trans = no_trans ? 
nullptr : trans.dptr_; - DType *bottom_data_diff = in_grad.dptr_; - DType *bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_; - const DType *top_count_data = top_count.dptr_; - const index_t count = out_grad.shape_.Size(); - const index_t num_rois = bbox.size(0); - const index_t channels = in_grad.size(1); - const index_t height = in_grad.size(2); - const index_t width = in_grad.size(3); - const index_t pooled_height = pooled_size; - const index_t pooled_width = pooled_size; - const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; - const index_t channels_each_class = no_trans ? output_dim : output_dim / num_classes; - - cudaStream_t stream = Stream::GetStream(in_grad.stream_); - DeformablePSROIPoolBackwardAccKernel<<< - mxnet::op::mxnet_op::cuda_get_num_blocks(count), kBaseThreadNum, - 0, stream >>>(count, top_diff, top_count_data, num_rois, spatial_scale, - channels, height, width, pooled_height, pooled_width, - output_dim, bottom_data_diff, bottom_trans_diff, - bottom_data, bottom_rois, bottom_trans, - no_trans, trans_std, sample_per_part, group_size, - part_size, num_classes, channels_each_class); - DeformablePSROIPOOLING_CUDA_CHECK(cudaGetLastError()); - } +} + +template +inline void DeformablePSROIPoolBackwardAcc(const Tensor& in_grad, + const Tensor& trans_grad, + const Tensor& out_grad, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + const DType* top_diff = out_grad.dptr_; + const DType* bottom_data = data.dptr_; + const DType* bottom_rois = bbox.dptr_; + const DType* bottom_trans = no_trans ? nullptr : trans.dptr_; + DType* bottom_data_diff = in_grad.dptr_; + DType* bottom_trans_diff = no_trans ? nullptr : trans_grad.dptr_; + const DType* top_count_data = top_count.dptr_; + const index_t count = out_grad.shape_.Size(); + const index_t num_rois = bbox.size(0); + const index_t channels = in_grad.size(1); + const index_t height = in_grad.size(2); + const index_t width = in_grad.size(3); + const index_t pooled_height = pooled_size; + const index_t pooled_width = pooled_size; + const index_t num_classes = no_trans ? 1 : trans_grad.size(1) / 2; + const index_t channels_each_class = no_trans ? 
output_dim : output_dim / num_classes; + + cudaStream_t stream = Stream::GetStream(in_grad.stream_); + DeformablePSROIPoolBackwardAccKernel + <<>>( + count, + top_diff, + top_count_data, + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + output_dim, + bottom_data_diff, + bottom_trans_diff, + bottom_data, + bottom_rois, + bottom_trans, + no_trans, + trans_std, + sample_per_part, + group_size, + part_size, + num_classes, + channels_each_class); + DeformablePSROIPOOLING_CUDA_CHECK(cudaGetLastError()); +} } // namespace cuda - template - inline void DeformablePSROIPoolForward(const Tensor &out, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - cuda::DeformablePSROIPoolForward(out, data, bbox, trans, top_count, - no_trans, spatial_scale, output_dim, - group_size, pooled_size, part_size, - sample_per_part, trans_std); - } - - template - inline void DeformablePSROIPoolBackwardAcc(const Tensor &in_grad, - const Tensor &trans_grad, - const Tensor &out_grad, - const Tensor &data, - const Tensor &bbox, - const Tensor &trans, - const Tensor &top_count, - const bool no_trans, const float spatial_scale, - const index_t output_dim, const index_t group_size, - const index_t pooled_size, const index_t part_size, - const index_t sample_per_part, const float trans_std) { - cuda::DeformablePSROIPoolBackwardAcc(in_grad, trans_grad, out_grad, data, bbox, - trans, top_count, no_trans, spatial_scale, - output_dim, group_size, pooled_size, - part_size, sample_per_part, trans_std); - } +template +inline void DeformablePSROIPoolForward(const Tensor& out, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + cuda::DeformablePSROIPoolForward(out, + data, + bbox, + trans, + top_count, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} + +template +inline void DeformablePSROIPoolBackwardAcc(const Tensor& in_grad, + const Tensor& trans_grad, + const Tensor& out_grad, + const Tensor& data, + const Tensor& bbox, + const Tensor& trans, + const Tensor& top_count, + const bool no_trans, + const float spatial_scale, + const index_t output_dim, + const index_t group_size, + const index_t pooled_size, + const index_t part_size, + const index_t sample_per_part, + const float trans_std) { + cuda::DeformablePSROIPoolBackwardAcc(in_grad, + trans_grad, + out_grad, + data, + bbox, + trans, + top_count, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} } // namespace mshadow - namespace mxnet { namespace op { - template<> - Operator* CreateOp(DeformablePSROIPoolingParam param, int dtype) { - Operator* op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new DeformablePSROIPoolingOp(param); - }); - return op; - } +template <> +Operator* CreateOp(DeformablePSROIPoolingParam param, int dtype) { + Operator* op = nullptr; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new DeformablePSROIPoolingOp(param); }); + return op; +} } // 
namespace op } // namespace mxnet diff --git a/src/operator/contrib/dgl_graph-inl.h b/src/operator/contrib/dgl_graph-inl.h index f31071b896e3..918c8c88f1e0 100644 --- a/src/operator/contrib/dgl_graph-inl.h +++ b/src/operator/contrib/dgl_graph-inl.h @@ -37,7 +37,7 @@ namespace mxnet { namespace op { -template +template void DGLAdjacencyForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -49,14 +49,14 @@ void DGLAdjacencyForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs[0].storage_type(), kCSRStorage); CHECK_EQ(outputs[0].storage_type(), kCSRStorage); CHECK_EQ(req[0], kWriteTo); - const TBlob &in_idx = inputs[0].aux_data(csr::kIdx); - const TBlob &in_indptr = inputs[0].aux_data(csr::kIndPtr); + const TBlob& in_idx = inputs[0].aux_data(csr::kIdx); + const TBlob& in_indptr = inputs[0].aux_data(csr::kIndPtr); outputs[0].CheckAndAllocData(in_idx.shape_); outputs[0].CheckAndAllocAuxData(csr::kIdx, in_idx.shape_); outputs[0].CheckAndAllocAuxData(csr::kIndPtr, in_indptr.shape_); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); Fill(s, outputs[0].data(), req[0], 1.0); mxnet_op::copy(s, outputs[0].aux_data(csr::kIdx), in_idx); mxnet_op::copy(s, outputs[0].aux_data(csr::kIndPtr), in_indptr); diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc index c8e27f38d1a6..5b53843e61b3 100644 --- a/src/operator/contrib/dgl_graph.cc +++ b/src/operator/contrib/dgl_graph.cc @@ -46,19 +46,19 @@ typedef int64_t dgl_id_t; class ArrayHeap { public: explicit ArrayHeap(const std::vector& prob, unsigned int seed) { - generator_ = std::mt19937(seed); + generator_ = std::mt19937(seed); distribution_ = std::uniform_real_distribution(0.0, 1.0); - vec_size_ = prob.size(); - bit_len_ = ceil(log2(vec_size_)); - limit_ = 1 << bit_len_; + vec_size_ = prob.size(); + bit_len_ = ceil(log2(vec_size_)); + limit_ = 1 << bit_len_; // allocate twice the size heap_.resize(limit_ << 1, 0); // allocate the leaves - for (int i = limit_; i < vec_size_+limit_; ++i) { - heap_[i] = prob[i-limit_]; + for (int i = limit_; i < vec_size_ + limit_; ++i) { + heap_[i] = prob[i - limit_]; } // iterate up the tree (this is O(m)) - for (int i = bit_len_-1; i >= 0; --i) { + for (int i = bit_len_ - 1; i >= 0; --i) { for (int j = (1 << i); j < (1 << (i + 1)); ++j) { heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1]; } @@ -71,7 +71,7 @@ class ArrayHeap { */ void Delete(size_t index) { size_t i = index + limit_; - float w = heap_[i]; + float w = heap_[i]; for (int j = bit_len_; j >= 0; --j) { heap_[i] -= w; i = i >> 1; @@ -94,7 +94,7 @@ class ArrayHeap { */ size_t Sample() { float xi = heap_[1] * distribution_(generator_); - int i = 1; + int i = 1; while (i < limit_) { i = i << 1; if (xi >= heap_[i]) { @@ -131,17 +131,10 @@ struct NeighborSampleParam : public dmlc::Parameter { dgl_id_t num_neighbor; dgl_id_t max_num_vertices; DMLC_DECLARE_PARAMETER(NeighborSampleParam) { - DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) - .describe("Number of input NDArray."); - DMLC_DECLARE_FIELD(num_hops) - .set_default(1) - .describe("Number of hops."); - DMLC_DECLARE_FIELD(num_neighbor) - .set_default(2) - .describe("Number of neighbor."); - DMLC_DECLARE_FIELD(max_num_vertices) - .set_default(100) - .describe("Max number of vertices."); + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2).describe("Number of input NDArray."); + DMLC_DECLARE_FIELD(num_hops).set_default(1).describe("Number of hops."); + DMLC_DECLARE_FIELD(num_neighbor).set_default(2).describe("Number of 
neighbor."); + DMLC_DECLARE_FIELD(max_num_vertices).set_default(100).describe("Max number of vertices."); } }; @@ -153,8 +146,8 @@ DMLC_REGISTER_PARAMETER(NeighborSampleParam); static bool CSRNeighborUniformSampleStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 1; @@ -181,7 +174,7 @@ static bool CSRNeighborUniformSampleStorageType(const nnvm::NodeAttrs& attrs, } // sub_layer for (size_t i = 0; i < num_subgraphs; i++) { - if (!type_assign(&(*out_attrs)[i + 2*num_subgraphs], mxnet::kDefaultStorage)) { + if (!type_assign(&(*out_attrs)[i + 2 * num_subgraphs], mxnet::kDefaultStorage)) { success = false; } } @@ -197,10 +190,9 @@ static bool CSRNeighborUniformSampleStorageType(const nnvm::NodeAttrs& attrs, static bool CSRNeighborNonUniformSampleStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); + std::vector* in_attrs, + std::vector* out_attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 2; CHECK_EQ(out_attrs->size(), 4 * num_subgraphs); @@ -229,13 +221,13 @@ static bool CSRNeighborNonUniformSampleStorageType(const nnvm::NodeAttrs& attrs, } // sub_probability for (size_t i = 0; i < num_subgraphs; i++) { - if (!type_assign(&(*out_attrs)[i + 2*num_subgraphs], mxnet::kDefaultStorage)) { + if (!type_assign(&(*out_attrs)[i + 2 * num_subgraphs], mxnet::kDefaultStorage)) { success = false; } } // sub_layer for (size_t i = 0; i < num_subgraphs; i++) { - if (!type_assign(&(*out_attrs)[i + 3*num_subgraphs], mxnet::kDefaultStorage)) { + if (!type_assign(&(*out_attrs)[i + 3 * num_subgraphs], mxnet::kDefaultStorage)) { success = false; } } @@ -249,10 +241,9 @@ static bool CSRNeighborNonUniformSampleStorageType(const nnvm::NodeAttrs& attrs, * Check uniform Shape */ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 1; CHECK_EQ(out_attrs->size(), 3 * num_subgraphs); @@ -287,7 +278,7 @@ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs, mxnet::TShape out_layer_shape(1, -1); out_layer_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { - SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_layer_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, i + 2 * num_subgraphs, out_layer_shape); success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs)); } @@ -298,10 +289,9 @@ static bool CSRNeighborUniformSampleShape(const nnvm::NodeAttrs& attrs, * Check non-uniform Shape */ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 2; CHECK_EQ(out_attrs->size(), 
4 * num_subgraphs); @@ -339,14 +329,14 @@ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs, mxnet::TShape out_prob_shape(1, -1); out_prob_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { - SHAPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, out_prob_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, i + 2 * num_subgraphs, out_prob_shape); success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 2 * num_subgraphs)); } // sub_layer mxnet::TShape out_layer_shape(1, -1); out_layer_shape[0] = params.max_num_vertices; for (size_t i = 0; i < num_subgraphs; i++) { - SHAPE_ASSIGN_CHECK(*out_attrs, i + 3*num_subgraphs, out_prob_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, i + 3 * num_subgraphs, out_prob_shape); success = success && !mxnet::op::shape_is_none(out_attrs->at(i + 3 * num_subgraphs)); } @@ -357,10 +347,9 @@ static bool CSRNeighborNonUniformSampleShape(const nnvm::NodeAttrs& attrs, * Check uniform Type */ static bool CSRNeighborUniformSampleType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); + std::vector* in_attrs, + std::vector* out_attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 1; CHECK_EQ(out_attrs->size(), 3 * num_subgraphs); @@ -369,11 +358,9 @@ static bool CSRNeighborUniformSampleType(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < num_subgraphs; i++) { TYPE_ASSIGN_CHECK(*out_attrs, i, in_attrs->at(1)); TYPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, in_attrs->at(1)); - success = success && - out_attrs->at(i) != -1 && - out_attrs->at(i + num_subgraphs) != -1 && - out_attrs->at(i + 2*num_subgraphs) != -1; + TYPE_ASSIGN_CHECK(*out_attrs, i + 2 * num_subgraphs, in_attrs->at(1)); + success = success && out_attrs->at(i) != -1 && out_attrs->at(i + num_subgraphs) != -1 && + out_attrs->at(i + 2 * num_subgraphs) != -1; } return success; @@ -383,10 +370,9 @@ static bool CSRNeighborUniformSampleType(const nnvm::NodeAttrs& attrs, * Check non-uniform Type */ static bool CSRNeighborNonUniformSampleType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); + std::vector* in_attrs, + std::vector* out_attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); size_t num_subgraphs = params.num_args - 2; CHECK_EQ(out_attrs->size(), 4 * num_subgraphs); @@ -395,22 +381,17 @@ static bool CSRNeighborNonUniformSampleType(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < num_subgraphs; i++) { TYPE_ASSIGN_CHECK(*out_attrs, i, in_attrs->at(2)); TYPE_ASSIGN_CHECK(*out_attrs, i + num_subgraphs, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, i + 2*num_subgraphs, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*out_attrs, i + 3*num_subgraphs, in_attrs->at(2)); - success = success && - out_attrs->at(i) != -1 && - out_attrs->at(i + num_subgraphs) != -1 && - out_attrs->at(i + 2*num_subgraphs) != -1 && - out_attrs->at(i + 3*num_subgraphs) != -1; + TYPE_ASSIGN_CHECK(*out_attrs, i + 2 * num_subgraphs, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*out_attrs, i + 3 * num_subgraphs, in_attrs->at(2)); + success = success && out_attrs->at(i) != -1 && out_attrs->at(i + num_subgraphs) != -1 && + out_attrs->at(i + 2 * num_subgraphs) != -1 && + out_attrs->at(i + 3 * num_subgraphs) != -1; } return success; } -static void RandomSample(size_t 
set_size,
-                         size_t num,
-                         std::vector<size_t>* out,
-                         unsigned int seed) {
+static void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out, unsigned int seed) {
   std::mt19937 generator(seed);
   std::unordered_set<size_t> sampled_idxs;
   std::uniform_int_distribution<size_t> distribution(0, set_size - 1);
@@ -423,11 +404,9 @@ static void RandomSample(size_t set_size,
   }
 }
 
-static void NegateSet(const std::vector<size_t> &idxs,
-                      size_t set_size,
-                      std::vector<size_t>* out) {
+static void NegateSet(const std::vector<size_t>& idxs, size_t set_size, std::vector<size_t>* out) {
   // idxs must have been sorted.
-  auto it = idxs.begin();
+  auto it  = idxs.begin();
   size_t i = 0;
   CHECK_GT(set_size, idxs.back());
   for (; i < set_size && it != idxs.end(); i++) {
@@ -469,8 +448,7 @@ static void GetUniformSample(const dgl_id_t* val_list,
   } else {
     std::vector<size_t> negate;
     negate.reserve(ver_len - max_num_neighbor);
-    RandomSample(ver_len, ver_len - max_num_neighbor,
-                 &negate, seed);
+    RandomSample(ver_len, ver_len - max_num_neighbor, &negate, seed);
     std::sort(negate.begin(), negate.end());
     NegateSet(negate, ver_len, &sorted_idxs);
   }
@@ -515,8 +493,8 @@ static void GetNonUniformSample(const float* probability,
   out_ver->resize(max_num_neighbor);
   out_edge->resize(max_num_neighbor);
   for (size_t i = 0; i < max_num_neighbor; ++i) {
-    size_t idx = sp_index[i];
-    out_ver->at(i) = col_list[idx];
+    size_t idx      = sp_index[i];
+    out_ver->at(i)  = col_list[idx];
     out_edge->at(i) = val_list[idx];
   }
   sort(out_ver->begin(), out_ver->end());
@@ -529,20 +507,19 @@
 struct neigh_list {
   std::vector<dgl_id_t> neighs;
   std::vector<dgl_id_t> edges;
-  neigh_list(std::vector<dgl_id_t> _neighs,
-             std::vector<dgl_id_t> _edges)
-    : neighs(std::move(_neighs)), edges(std::move(_edges)) {}
+  neigh_list(std::vector<dgl_id_t> _neighs, std::vector<dgl_id_t> _edges)
+      : neighs(std::move(_neighs)), edges(std::move(_edges)) {}
 };
 
 /*
  * Sample sub-graph from csr graph
 */
-static void SampleSubgraph(const NDArray &csr,
-                           const NDArray &seed_arr,
-                           const NDArray &sampled_ids,
-                           const NDArray &sub_csr,
+static void SampleSubgraph(const NDArray& csr,
+                           const NDArray& seed_arr,
+                           const NDArray& sampled_ids,
+                           const NDArray& sub_csr,
                            float* sub_prob,
-                           const NDArray &sub_layer,
+                           const NDArray& sub_layer,
                            const float* probability,
                            int num_hops,
                            size_t num_neighbor,
@@ -553,10 +530,10 @@ static void SampleSubgraph(const NDArray &csr,
   const dgl_id_t* val_list = csr.data().dptr<dgl_id_t>();
   const dgl_id_t* col_list = csr.aux_data(csr::kIdx).dptr<dgl_id_t>();
-  const dgl_id_t* indptr = csr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
-  const dgl_id_t* seed = seed_arr.data().dptr<dgl_id_t>();
-  dgl_id_t* out = sampled_ids.data().dptr<dgl_id_t>();
-  dgl_id_t* out_layer = sub_layer.data().dptr<dgl_id_t>();
+  const dgl_id_t* indptr   = csr.aux_data(csr::kIndPtr).dptr<dgl_id_t>();
+  const dgl_id_t* seed     = seed_arr.data().dptr<dgl_id_t>();
+  dgl_id_t* out            = sampled_ids.data().dptr<dgl_id_t>();
+  dgl_id_t* out_layer      = sub_layer.data().dptr<dgl_id_t>();
 
   // BFS traverse the graph and sample vertices
   //
@@ -584,9 +561,8 @@
   // A vertex in the vector only needs to be accessed once. If there is a vertex behind idx
   // isn't in the last level, we will sample its neighbors. If not, the while loop terminates.
   size_t idx = 0;
-  while (idx < sub_vers.size() &&
-         sub_ver_mp.size() < max_num_vertices) {
-    dgl_id_t dst_id = sub_vers[idx].first;
+  while (idx < sub_vers.size() && sub_ver_mp.size() < max_num_vertices) {
+    dgl_id_t dst_id    = sub_vers[idx].first;
     int cur_node_level = sub_vers[idx].second;
     idx++;
     // If the node is in the last level, we don't need to sample neighbors
@@ -596,7 +572,7 @@
     tmp_sampled_src_list.clear();
     tmp_sampled_edge_list.clear();
-    dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
+    dgl_id_t ver_len = *(indptr + dst_id + 1) - *(indptr + dst_id);
     if (probability == nullptr) { // uniform-sample
       GetUniformSample(val_list + *(indptr + dst_id),
                        col_list + *(indptr + dst_id),
@@ -607,13 +583,13 @@
                        random_seed);
     } else { // non-uniform-sample
       GetNonUniformSample(probability,
-                          val_list + *(indptr + dst_id),
-                          col_list + *(indptr + dst_id),
-                          ver_len,
-                          num_neighbor,
-                          &tmp_sampled_src_list,
-                          &tmp_sampled_edge_list,
-                          random_seed);
+                          val_list + *(indptr + dst_id),
+                          col_list + *(indptr + dst_id),
+                          ver_len,
+                          num_neighbor,
+                          &tmp_sampled_src_list,
+                          &tmp_sampled_edge_list,
+                          random_seed);
     }
     CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
     size_t pos = neighbor_list.size();
@@ -621,15 +597,15 @@
     // First we push the size of neighbor vector
     neighbor_list.push_back(tmp_sampled_edge_list.size());
     // Then push the vertices
-    for (dgl_id_t & i : tmp_sampled_src_list) {
+    for (dgl_id_t& i : tmp_sampled_src_list) {
      neighbor_list.push_back(i);
    }
    // Finally we push the edge list
-    for (dgl_id_t & i : tmp_sampled_edge_list) {
+    for (dgl_id_t& i : tmp_sampled_edge_list) {
      neighbor_list.push_back(i);
    }
    num_edges += tmp_sampled_src_list.size();
-    for (dgl_id_t & i : tmp_sampled_src_list) {
+    for (dgl_id_t& i : tmp_sampled_src_list) {
      // If we have sampled the max number of vertices, we have to stop.
if (sub_ver_mp.size() >= max_num_vertices) break; @@ -646,8 +622,8 @@ static void SampleSubgraph(const NDArray &csr, for (; idx < sub_vers.size(); idx++) { if (sub_vers[idx].second < num_hops) { LOG(WARNING) - << "The sampling is truncated because we have reached the max number of vertices\n" - << "Please use a smaller number of seeds or a small neighborhood"; + << "The sampling is truncated because we have reached the max number of vertices\n" + << "Please use a smaller number of seeds or a small neighborhood"; break; } } @@ -655,12 +631,13 @@ static void SampleSubgraph(const NDArray &csr, // Copy sub_ver_mp to output[0] // Copy layer size_t num_vertices = sub_ver_mp.size(); - std::sort(sub_vers.begin(), sub_vers.end(), - [](const std::pair &a1, const std::pair &a2) { - return a1.first < a2.first; - }); + std::sort(sub_vers.begin(), + sub_vers.end(), + [](const std::pair& a1, const std::pair& a2) { + return a1.first < a2.first; + }); for (size_t i = 0; i < sub_vers.size(); i++) { - out[i] = sub_vers[i].first; + out[i] = sub_vers[i].first; out_layer[i] = sub_vers[i].second; } // The last element stores the actual @@ -671,29 +648,30 @@ static void SampleSubgraph(const NDArray &csr, if (sub_prob != nullptr) { for (size_t i = 0; i < sub_ver_mp.size(); ++i) { dgl_id_t idx = out[i]; - sub_prob[i] = probability[idx]; + sub_prob[i] = probability[idx]; } } // Construct sub_csr_graph mxnet::TShape shape_1(1, -1); mxnet::TShape shape_2(1, -1); shape_1[0] = num_edges; - shape_2[0] = max_num_vertices+1; + shape_2[0] = max_num_vertices + 1; sub_csr.CheckAndAllocData(shape_1); sub_csr.CheckAndAllocAuxData(csr::kIdx, shape_1); sub_csr.CheckAndAllocAuxData(csr::kIndPtr, shape_2); - dgl_id_t* val_list_out = sub_csr.data().dptr(); - dgl_id_t* col_list_out = sub_csr.aux_data(1).dptr(); - dgl_id_t* indptr_out = sub_csr.aux_data(0).dptr(); - indptr_out[0] = 0; + dgl_id_t* val_list_out = sub_csr.data().dptr(); + dgl_id_t* col_list_out = sub_csr.aux_data(1).dptr(); + dgl_id_t* indptr_out = sub_csr.aux_data(0).dptr(); + indptr_out[0] = 0; size_t collected_nedges = 0; // Both the out array and neigh_pos are sorted. By scanning the two arrays, we can see // which vertices have neighbors and which don't. 
- std::sort(neigh_pos.begin(), neigh_pos.end(), - [](const std::pair &a1, const std::pair &a2) { - return a1.first < a2.first; - }); + std::sort(neigh_pos.begin(), + neigh_pos.end(), + [](const std::pair& a1, const std::pair& a2) { + return a1.first < a2.first; + }); size_t idx_with_neigh = 0; for (size_t i = 0; i < num_vertices; i++) { dgl_id_t dst_id = *(out + i); @@ -706,19 +684,16 @@ static void SampleSubgraph(const NDArray &csr, edge_size = neighbor_list[pos]; CHECK_LE(pos + edge_size * 2 + 1, neighbor_list.size()); - std::copy_n(neighbor_list.begin() + pos + 1, - edge_size, - col_list_out + collected_nedges); - std::copy_n(neighbor_list.begin() + pos + edge_size + 1, - edge_size, - val_list_out + collected_nedges); + std::copy_n(neighbor_list.begin() + pos + 1, edge_size, col_list_out + collected_nedges); + std::copy_n( + neighbor_list.begin() + pos + edge_size + 1, edge_size, val_list_out + collected_nedges); collected_nedges += edge_size; idx_with_neigh++; } - indptr_out[i+1] = indptr_out[i] + edge_size; + indptr_out[i + 1] = indptr_out[i] + edge_size; } - for (size_t i = num_vertices+1; i <= max_num_vertices; ++i) { - indptr_out[i] = indptr_out[i-1]; + for (size_t i = num_vertices + 1; i <= max_num_vertices; ++i) { + indptr_out[i] = indptr_out[i - 1]; } } @@ -726,28 +701,28 @@ static void SampleSubgraph(const NDArray &csr, * Operator: contrib_csr_neighbor_uniform_sample */ static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const NeighborSampleParam& params = nnvm::get(attrs.parsed); int num_subgraphs = inputs.size() - 1; CHECK_EQ(outputs.size(), 3 * num_subgraphs); - mshadow::Stream *s = ctx.get_stream(); - mshadow::Random *prnd = ctx.requested[0].get_random(s); - unsigned int seed = prnd->GetRandInt(); + mshadow::Stream* s = ctx.get_stream(); + mshadow::Random* prnd = ctx.requested[0].get_random(s); + unsigned int seed = prnd->GetRandInt(); #pragma omp parallel for for (int i = 0; i < num_subgraphs; i++) { - SampleSubgraph(inputs[0], // graph_csr - inputs[i + 1], // seed vector - outputs[i], // sample_id - outputs[i + 1*num_subgraphs], // sub_csr - nullptr, // sample_id_probability - outputs[i + 2*num_subgraphs], // sample_id_layer - nullptr, // probability + SampleSubgraph(inputs[0], // graph_csr + inputs[i + 1], // seed vector + outputs[i], // sample_id + outputs[i + 1 * num_subgraphs], // sub_csr + nullptr, // sample_id_probability + outputs[i + 2 * num_subgraphs], // sample_id_layer + nullptr, // probability params.num_hops, params.num_neighbor, params.max_num_vertices, @@ -760,7 +735,7 @@ static void CSRNeighborUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_dgl_csr_neighbor_uniform_sample) -.describe(R"code(This operator samples sub-graphs from a csr graph via an + .describe(R"code(This operator samples sub-graphs from a csr graph via an uniform probability. The operator is designed for DGL. 
The operator outputs three sets of NDArrays to represent the sampled results @@ -800,38 +775,37 @@ of max_num_vertices, and the valid number of vertices is the same as the ones in )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); - return params.num_args; -}) -.set_num_outputs([](const NodeAttrs& attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); - size_t num_subgraphs = params.num_args - 1; - return num_subgraphs * 3; -}) -.set_attr("FInferStorageType", CSRNeighborUniformSampleStorageType) -.set_attr("FInferShape", CSRNeighborUniformSampleShape) -.set_attr("FInferType", CSRNeighborUniformSampleType) -.set_attr("FComputeEx", CSRNeighborUniformSampleComputeExCPU) -.set_attr("FResourceRequest", [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kRandom}; -}) -.add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix") -.add_argument("seed_arrays", "NDArray-or-Symbol[]", "seed vertices") -.set_attr("key_var_num_args", "num_args") -.add_arguments(NeighborSampleParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs([](const NodeAttrs& attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); + size_t num_subgraphs = params.num_args - 1; + return num_subgraphs * 3; + }) + .set_attr("FInferStorageType", CSRNeighborUniformSampleStorageType) + .set_attr("FInferShape", CSRNeighborUniformSampleShape) + .set_attr("FInferType", CSRNeighborUniformSampleType) + .set_attr("FComputeEx", CSRNeighborUniformSampleComputeExCPU) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kRandom}; + }) + .add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix") + .add_argument("seed_arrays", "NDArray-or-Symbol[]", "seed vertices") + .set_attr("key_var_num_args", "num_args") + .add_arguments(NeighborSampleParam::__FIELDS__()); /* * Operator: contrib_csr_neighbor_non_uniform_sample */ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { const NeighborSampleParam& params = nnvm::get(attrs.parsed); int num_subgraphs = inputs.size() - 2; @@ -839,19 +813,19 @@ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs const float* probability = inputs[1].data().dptr(); - mshadow::Stream *s = ctx.get_stream(); - mshadow::Random *prnd = ctx.requested[0].get_random(s); - unsigned int seed = prnd->GetRandInt(); + mshadow::Stream* s = ctx.get_stream(); + mshadow::Random* prnd = ctx.requested[0].get_random(s); + unsigned int seed = prnd->GetRandInt(); #pragma omp parallel for for (int i = 0; i < num_subgraphs; i++) { - float* sub_prob = outputs[i+2*num_subgraphs].data().dptr(); - SampleSubgraph(inputs[0], // graph_csr - inputs[i + 2], // seed vector - outputs[i], // sample_id - outputs[i + 1*num_subgraphs], // sub_csr - sub_prob, // sample_id_probability - outputs[i + 3*num_subgraphs], // sample_id_layer + float* sub_prob = outputs[i + 2 * num_subgraphs].data().dptr(); + SampleSubgraph(inputs[0], // graph_csr + inputs[i + 2], // seed vector + outputs[i], // sample_id + outputs[i + 1 * 
num_subgraphs], // sub_csr + sub_prob, // sample_id_probability + outputs[i + 3 * num_subgraphs], // sample_id_layer probability, params.num_hops, params.num_neighbor, @@ -865,7 +839,7 @@ static void CSRNeighborNonUniformSampleComputeExCPU(const nnvm::NodeAttrs& attrs } NNVM_REGISTER_OP(_contrib_dgl_csr_neighbor_non_uniform_sample) -.describe(R"code(This operator samples sub-graph from a csr graph via an + .describe(R"code(This operator samples sub-graph from a csr graph via an non-uniform probability. The operator is designed for DGL. The operator outputs four sets of NDArrays to represent the sampled results @@ -910,30 +884,29 @@ of max_num_vertices, and the valid number of vertices is the same as the ones in )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); - return params.num_args; -}) -.set_num_outputs([](const NodeAttrs& attrs) { - const NeighborSampleParam& params = - nnvm::get(attrs.parsed); - size_t num_subgraphs = params.num_args - 2; - return num_subgraphs * 4; -}) -.set_attr("FInferStorageType", CSRNeighborNonUniformSampleStorageType) -.set_attr("FInferShape", CSRNeighborNonUniformSampleShape) -.set_attr("FInferType", CSRNeighborNonUniformSampleType) -.set_attr("FComputeEx", CSRNeighborNonUniformSampleComputeExCPU) -.set_attr("FResourceRequest", [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kRandom}; -}) -.add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix") -.add_argument("probability", "NDArray-or-Symbol", "probability vector") -.add_argument("seed_arrays", "NDArray-or-Symbol[]", "seed vertices") -.set_attr("key_var_num_args", "num_args") -.add_arguments(NeighborSampleParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs([](const NodeAttrs& attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const NeighborSampleParam& params = nnvm::get(attrs.parsed); + size_t num_subgraphs = params.num_args - 2; + return num_subgraphs * 4; + }) + .set_attr("FInferStorageType", CSRNeighborNonUniformSampleStorageType) + .set_attr("FInferShape", CSRNeighborNonUniformSampleShape) + .set_attr("FInferType", CSRNeighborNonUniformSampleType) + .set_attr("FComputeEx", CSRNeighborNonUniformSampleComputeExCPU) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kRandom}; + }) + .add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix") + .add_argument("probability", "NDArray-or-Symbol", "probability vector") + .add_argument("seed_arrays", "NDArray-or-Symbol[]", "seed vertices") + .set_attr("key_var_num_args", "num_args") + .add_arguments(NeighborSampleParam::__FIELDS__()); ///////////////////////// Create induced subgraph /////////////////////////// @@ -941,10 +914,10 @@ struct DGLSubgraphParam : public dmlc::Parameter { int num_args; bool return_mapping; DMLC_DECLARE_PARAMETER(DGLSubgraphParam) { - DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) - .describe("Number of input arguments, including all symbol inputs."); + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2).describe( + "Number of input arguments, including all symbol inputs."); DMLC_DECLARE_FIELD(return_mapping) - .describe("Return mapping of vid and eid between the subgraph and the parent graph."); + .describe("Return mapping of vid and eid between the subgraph and the parent graph."); } }; // struct DGLSubgraphParam @@ -953,24 
+926,24 @@ DMLC_REGISTER_PARAMETER(DGLSubgraphParam); static bool DGLSubgraphStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->at(0), kCSRStorage); for (size_t i = 1; i < in_attrs->size(); i++) CHECK_EQ(in_attrs->at(i), kDefaultStorage); - bool success = true; + bool success = true; *dispatch_mode = DispatchMode::kFComputeEx; - for (int & out_attr : *out_attrs) { + for (int& out_attr : *out_attrs) { if (!type_assign(&out_attr, mxnet::kCSRStorage)) - success = false; + success = false; } return success; } static bool DGLSubgraphShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { const DGLSubgraphParam& params = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->at(0).ndim(), 2U); for (size_t i = 1; i < in_attrs->size(); i++) @@ -979,28 +952,28 @@ static bool DGLSubgraphShape(const nnvm::NodeAttrs& attrs, size_t num_g = params.num_args - 1; for (size_t i = 0; i < num_g; i++) { mxnet::TShape gshape(2, -1); - gshape[0] = in_attrs->at(i + 1)[0]; - gshape[1] = in_attrs->at(i + 1)[0]; + gshape[0] = in_attrs->at(i + 1)[0]; + gshape[1] = in_attrs->at(i + 1)[0]; out_attrs->at(i) = gshape; } for (size_t i = num_g; i < out_attrs->size(); i++) { mxnet::TShape gshape(2, -1); - gshape[0] = in_attrs->at(i - num_g + 1)[0]; - gshape[1] = in_attrs->at(i - num_g + 1)[0]; + gshape[0] = in_attrs->at(i - num_g + 1)[0]; + gshape[1] = in_attrs->at(i - num_g + 1)[0]; out_attrs->at(i) = gshape; } return true; } static bool DGLSubgraphType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { const DGLSubgraphParam& params = nnvm::get(attrs.parsed); - size_t num_g = params.num_args - 1; + size_t num_g = params.num_args - 1; for (size_t i = 0; i < num_g; i++) { CHECK_EQ(in_attrs->at(i + 1), mshadow::kInt64); } - for (int & out_attr : *out_attrs) { + for (int& out_attr : *out_attrs) { out_attr = in_attrs->at(0); } return true; @@ -1014,8 +987,9 @@ class Bitmap { size_t hash(dgl_id_t id) const { return id & mask; } + public: - Bitmap(const dgl_id_t *vid_data, int64_t len): map(size) { + Bitmap(const dgl_id_t* vid_data, int64_t len) : map(size) { for (int64_t i = 0; i < len; ++i) { map[hash(vid_data[i])] = true; } @@ -1034,27 +1008,30 @@ class HashTableChecker { Bitmap map; public: - HashTableChecker(const dgl_id_t *vid_data, int64_t len): map(vid_data, len) { + HashTableChecker(const dgl_id_t* vid_data, int64_t len) : map(vid_data, len) { oldv2newv.reserve(len); for (int64_t i = 0; i < len; ++i) { oldv2newv[vid_data[i]] = i; } } - void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len, - std::vector *new_col_idx, - std::vector *orig_eids) { + void CollectOnRow(const dgl_id_t col_idx[], + const dgl_id_t eids[], + size_t row_len, + std::vector* new_col_idx, + std::vector* orig_eids) { // TODO(zhengda) I need to make sure the column index in each row is sorted. 
for (size_t j = 0; j < row_len; ++j) { const dgl_id_t oldsucc = col_idx[j]; - const dgl_id_t eid = eids[j]; + const dgl_id_t eid = eids[j]; Collect(oldsucc, eid, new_col_idx, orig_eids); } } - void Collect(const dgl_id_t old_id, const dgl_id_t old_eid, - std::vector *col_idx, - std::vector *orig_eids) { + void Collect(const dgl_id_t old_id, + const dgl_id_t old_eid, + std::vector* col_idx, + std::vector* orig_eids) { if (!map.test(old_id)) return; @@ -1068,12 +1045,14 @@ class HashTableChecker { } }; -static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, - const NDArray &sub_csr, const NDArray *old_eids) { - const TBlob &data = varr.data(); - int64_t num_vertices = csr_arr.shape()[0]; - const size_t len = varr.shape()[0]; - const dgl_id_t *vid_data = data.dptr(); +static void GetSubgraph(const NDArray& csr_arr, + const NDArray& varr, + const NDArray& sub_csr, + const NDArray* old_eids) { + const TBlob& data = varr.data(); + int64_t num_vertices = csr_arr.shape()[0]; + const size_t len = varr.shape()[0]; + const dgl_id_t* vid_data = data.dptr(); HashTableChecker def_check(vid_data, len); // check if varr is sorted. CHECK(std::is_sorted(vid_data, vid_data + len)) << "The input vertex list has to be sorted"; @@ -1084,17 +1063,20 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, std::vector orig_eids; col_idx.reserve(len * 50); orig_eids.reserve(len * 50); - const dgl_id_t *eids = csr_arr.data().dptr(); - const dgl_id_t *indptr = csr_arr.aux_data(csr::kIndPtr).dptr(); - const dgl_id_t *indices = csr_arr.aux_data(csr::kIdx).dptr(); + const dgl_id_t* eids = csr_arr.data().dptr(); + const dgl_id_t* indptr = csr_arr.aux_data(csr::kIndPtr).dptr(); + const dgl_id_t* indices = csr_arr.aux_data(csr::kIdx).dptr(); for (size_t i = 0; i < len; ++i) { const dgl_id_t oldvid = vid_data[i]; - CHECK_LT(oldvid, num_vertices) << "Vertex Id " << oldvid << " isn't in a graph of " - << num_vertices << " vertices"; + CHECK_LT(oldvid, num_vertices) + << "Vertex Id " << oldvid << " isn't in a graph of " << num_vertices << " vertices"; size_t row_start = indptr[oldvid]; - size_t row_len = indptr[oldvid + 1] - indptr[oldvid]; - def_check.CollectOnRow(indices + row_start, eids + row_start, row_len, - &col_idx, old_eids == nullptr ? nullptr : &orig_eids); + size_t row_len = indptr[oldvid + 1] - indptr[oldvid]; + def_check.CollectOnRow(indices + row_start, + eids + row_start, + row_len, + &col_idx, + old_eids == nullptr ? 
nullptr : &orig_eids); row_idx[i + 1] = col_idx.size(); } @@ -1108,11 +1090,11 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, sub_csr.CheckAndAllocData(nz_shape); sub_csr.CheckAndAllocAuxData(csr::kIdx, nz_shape); sub_csr.CheckAndAllocAuxData(csr::kIndPtr, indptr_shape); - dgl_id_t *indices_out = sub_csr.aux_data(csr::kIdx).dptr(); - dgl_id_t *indptr_out = sub_csr.aux_data(csr::kIndPtr).dptr(); + dgl_id_t* indices_out = sub_csr.aux_data(csr::kIdx).dptr(); + dgl_id_t* indptr_out = sub_csr.aux_data(csr::kIndPtr).dptr(); std::copy(col_idx.begin(), col_idx.end(), indices_out); std::copy(row_idx.begin(), row_idx.end(), indptr_out); - dgl_id_t *sub_eids = sub_csr.data().dptr(); + dgl_id_t* sub_eids = sub_csr.data().dptr(); for (int64_t i = 0; i < nz_shape[0]; i++) sub_eids[i] = i; @@ -1121,9 +1103,9 @@ static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, old_eids->CheckAndAllocData(nz_shape); old_eids->CheckAndAllocAuxData(csr::kIdx, nz_shape); old_eids->CheckAndAllocAuxData(csr::kIndPtr, indptr_shape); - dgl_id_t *indices_out = old_eids->aux_data(csr::kIdx).dptr(); - dgl_id_t *indptr_out = old_eids->aux_data(csr::kIndPtr).dptr(); - dgl_id_t *sub_eids = old_eids->data().dptr(); + dgl_id_t* indices_out = old_eids->aux_data(csr::kIdx).dptr(); + dgl_id_t* indptr_out = old_eids->aux_data(csr::kIndPtr).dptr(); + dgl_id_t* sub_eids = old_eids->data().dptr(); std::copy(col_idx.begin(), col_idx.end(), indices_out); std::copy(row_idx.begin(), row_idx.end(), indptr_out); std::copy(orig_eids.begin(), orig_eids.end(), sub_eids); @@ -1136,16 +1118,16 @@ static void DGLSubgraphComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const DGLSubgraphParam& params = nnvm::get(attrs.parsed); - int num_g = params.num_args - 1; + int num_g = params.num_args - 1; #pragma omp parallel for for (int i = 0; i < num_g; i++) { - const NDArray *old_eids = params.return_mapping ? &outputs[i + num_g] : nullptr; + const NDArray* old_eids = params.return_mapping ? &outputs[i + num_g] : nullptr; GetSubgraph(inputs[0], inputs[i + 1], outputs[i], old_eids); } } NNVM_REGISTER_OP(_contrib_dgl_subgraph) -.describe(R"code(This operator constructs an induced subgraph for + .describe(R"code(This operator constructs an induced subgraph for a given set of vertices from a graph. The operator accepts multiple sets of vertices as input. For each set of vertices, it returns a pair of CSR matrices if return_mapping is True: the first matrix contains edges @@ -1170,38 +1152,40 @@ edge Ids. 
[0, 5, 0]] )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const DGLSubgraphParam& params = nnvm::get(attrs.parsed); - return params.num_args; -}) -.set_num_outputs([](const NodeAttrs& attrs) { - const DGLSubgraphParam& params = nnvm::get(attrs.parsed); - int num_varray = params.num_args - 1; - if (params.return_mapping) - return num_varray * 2; - else - return num_varray; -}) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const DGLSubgraphParam& params = nnvm::get(attrs.parsed); - std::vector names; - names.reserve(params.num_args); - names.emplace_back("graph"); - for (int i = 1; i < params.num_args; ++i) - names.push_back("varray" + std::to_string(i - 1)); - return names; -}) -.set_attr("FInferStorageType", DGLSubgraphStorageType) -.set_attr("FInferShape", DGLSubgraphShape) -.set_attr("FInferType", DGLSubgraphType) -.set_attr("FComputeEx", DGLSubgraphComputeExCPU) -.set_attr("key_var_num_args", "num_args") -.add_argument("graph", "NDArray-or-Symbol", "Input graph where we sample vertices.") -.add_argument("data", "NDArray-or-Symbol[]", - "The input arrays that include data arrays and states.") -.add_arguments(DGLSubgraphParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs([](const NodeAttrs& attrs) { + const DGLSubgraphParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const DGLSubgraphParam& params = nnvm::get(attrs.parsed); + int num_varray = params.num_args - 1; + if (params.return_mapping) + return num_varray * 2; + else + return num_varray; + }) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const DGLSubgraphParam& params = + nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.emplace_back("graph"); + for (int i = 1; i < params.num_args; ++i) + names.push_back("varray" + std::to_string(i - 1)); + return names; + }) + .set_attr("FInferStorageType", DGLSubgraphStorageType) + .set_attr("FInferShape", DGLSubgraphShape) + .set_attr("FInferType", DGLSubgraphType) + .set_attr("FComputeEx", DGLSubgraphComputeExCPU) + .set_attr("key_var_num_args", "num_args") + .add_argument("graph", "NDArray-or-Symbol", "Input graph where we sample vertices.") + .add_argument("data", + "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") + .add_arguments(DGLSubgraphParam::__FIELDS__()); ///////////////////////// Edge Id /////////////////////////// @@ -1238,13 +1222,13 @@ inline bool EdgeIDStorageType(const nnvm::NodeAttrs& attrs, std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 3U) << "Only works for 2d arrays"; CHECK_EQ(out_attrs->size(), 1U); - int& in_stype = in_attrs->at(0); - int& out_stype = out_attrs->at(0); + int& in_stype = in_attrs->at(0); + int& out_stype = out_attrs->at(0); bool dispatched = false; if (!dispatched && in_stype == kCSRStorage) { // csr -> dns - dispatched = storage_type_assign(&out_stype, kDefaultStorage, - dispatch_mode, DispatchMode::kFComputeEx); + dispatched = + storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } if (!dispatched) { LOG(ERROR) << "Cannot dispatch edge_id storage type, only works for csr matrices"; @@ -1253,14 +1237,19 @@ inline bool EdgeIDStorageType(const nnvm::NodeAttrs& attrs, } struct edge_id_csr_forward { - template - MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, - const IType* in_indices, const IType* in_indptr, - const CType* u, const CType* 
v) { + template + MSHADOW_XINLINE static void Map(int i, + DType* out_data, + const DType* in_data, + const IType* in_indices, + const IType* in_indptr, + const CType* u, + const CType* v) { const int64_t target_row_id = static_cast(u[i]); - const IType target_col_id = static_cast(v[i]); - auto ptr = std::find(in_indices + in_indptr[target_row_id], - in_indices + in_indptr[target_row_id + 1], target_col_id); + const IType target_col_id = static_cast(v[i]); + auto ptr = std::find(in_indices + in_indptr[target_row_id], + in_indices + in_indptr[target_row_id + 1], + target_col_id); if (ptr == in_indices + in_indptr[target_row_id + 1]) { // does not exist in the range out_data[i] = DType(-1); @@ -1270,7 +1259,7 @@ struct edge_id_csr_forward { } }; -template +template void EdgeIDForwardCsrImpl(const OpContext& ctx, const std::vector& inputs, const OpReqType req, @@ -1278,40 +1267,45 @@ void EdgeIDForwardCsrImpl(const OpContext& ctx, using namespace mshadow; using namespace mxnet_op; using namespace csr; - if (req == kNullOp) return; + if (req == kNullOp) + return; CHECK_EQ(inputs.size(), 3U); CHECK_EQ(req, kWriteTo) << "EdgeID with CSR only supports kWriteTo"; - Stream *s = ctx.get_stream(); - const NDArray& u = inputs[1]; + Stream* s = ctx.get_stream(); + const NDArray& u = inputs[1]; const dim_t out_elems = u.shape().Size(); if (!inputs[0].storage_initialized()) { MSHADOW_TYPE_SWITCH(output.dtype(), DType, { Kernel, xpu>::Launch( - s, out_elems, output.data().dptr(), DType(-1)); + s, out_elems, output.data().dptr(), DType(-1)); }); return; } - const NDArray& data = inputs[0]; - const TBlob& in_data = data.data(); + const NDArray& data = inputs[0]; + const TBlob& in_data = data.data(); const TBlob& in_indices = data.aux_data(kIdx); - const TBlob& in_indptr = data.aux_data(kIndPtr); - const NDArray& v = inputs[2]; + const TBlob& in_indptr = data.aux_data(kIndPtr); + const NDArray& v = inputs[2]; CHECK_EQ(data.aux_type(kIdx), data.aux_type(kIndPtr)) - << "The dtypes of indices and indptr don't match"; + << "The dtypes of indices and indptr don't match"; MSHADOW_TYPE_SWITCH(data.dtype(), DType, { MSHADOW_IDX_TYPE_SWITCH(data.aux_type(kIdx), IType, { MSHADOW_TYPE_SWITCH(u.dtype(), CType, { - Kernel::Launch( - s, out_elems, output.data().dptr(), in_data.dptr(), - in_indices.dptr(), in_indptr.dptr(), - u.data().dptr(), v.data().dptr()); + Kernel::Launch(s, + out_elems, + output.data().dptr(), + in_data.dptr(), + in_indices.dptr(), + in_indptr.dptr(), + u.data().dptr(), + v.data().dptr()); }); }); }); } -template +template void EdgeIDForwardEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -1320,7 +1314,7 @@ void EdgeIDForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const auto in_stype = inputs[0].storage_type(); + const auto in_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); if (in_stype == kCSRStorage && out_stype == kDefaultStorage) { EdgeIDForwardCsrImpl(ctx, inputs, req[0], outputs[0]); @@ -1330,7 +1324,7 @@ void EdgeIDForwardEx(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_edge_id) -.describe(R"code(This operator implements the edge_id function for a graph + .describe(R"code(This operator implements the edge_id function for a graph stored in a CSR matrix (the value of the CSR stores the edge Id of the graph). output[i] = input[u[i], v[i]] if there is an edge between u[i] and v[i]], otherwise output[i] will be -1. 
Both u and v should be 1D vectors. @@ -1351,19 +1345,19 @@ The storage type of ``edge_id`` output depends on storage types of inputs - default and rsp inputs are not supported )code" ADD_FILELINE) -.set_num_inputs(3) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "u", "v"}; - }) -.set_attr("FInferShape", EdgeIDShape) -.set_attr("FInferType", EdgeIDType) -.set_attr("FInferStorageType", EdgeIDStorageType) -.set_attr("FComputeEx", EdgeIDForwardEx) -.add_argument("data", "NDArray-or-Symbol", "Input ndarray") -.add_argument("u", "NDArray-or-Symbol", "u ndarray") -.add_argument("v", "NDArray-or-Symbol", "v ndarray"); + .set_num_inputs(3) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "u", "v"}; + }) + .set_attr("FInferShape", EdgeIDShape) + .set_attr("FInferType", EdgeIDType) + .set_attr("FInferStorageType", EdgeIDStorageType) + .set_attr("FComputeEx", EdgeIDForwardEx) + .add_argument("data", "NDArray-or-Symbol", "Input ndarray") + .add_argument("u", "NDArray-or-Symbol", "u ndarray") + .add_argument("v", "NDArray-or-Symbol", "v ndarray"); ///////////////////////// DGL Adjacency /////////////////////////// @@ -1396,17 +1390,17 @@ inline bool DGLAdjacencyStorageType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U) << "Only works for 2d arrays"; CHECK_EQ(out_attrs->size(), 1U); int& out_stype = out_attrs->at(0); - bool dispatched = storage_type_assign(&out_stype, kCSRStorage, - dispatch_mode, DispatchMode::kFComputeEx); + bool dispatched = + storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx); if (!dispatched) { LOG(ERROR) << "Cannot dispatch output storage type: " << common::stype_string(out_stype) - << ". dgl_adjacency only works for csr matrices"; + << ". dgl_adjacency only works for csr matrices"; } return dispatched; } NNVM_REGISTER_OP(_contrib_dgl_adjacency) -.describe(R"code(This operator converts a CSR matrix whose values are edge Ids + .describe(R"code(This operator converts a CSR matrix whose values are edge Ids to an adjacency matrix whose values are ones. The output CSR matrix always has the data value of float32. @@ -1423,17 +1417,17 @@ the data value of float32. 
[ 0, 0, 1 ]] )code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) -.set_attr("FInferShape", DGLAdjacencyShape) -.set_attr("FInferType", DGLAdjacencyType) -.set_attr("FInferStorageType", DGLAdjacencyStorageType) -.set_attr("FComputeEx", DGLAdjacencyForwardEx) -.add_argument("data", "NDArray-or-Symbol", "Input ndarray"); + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) + .set_attr("FInferShape", DGLAdjacencyShape) + .set_attr("FInferType", DGLAdjacencyType) + .set_attr("FInferStorageType", DGLAdjacencyStorageType) + .set_attr("FComputeEx", DGLAdjacencyForwardEx) + .add_argument("data", "NDArray-or-Symbol", "Input ndarray"); ///////////////////////// Compact subgraphs /////////////////////////// @@ -1442,30 +1436,30 @@ struct SubgraphCompactParam : public dmlc::Parameter { bool return_mapping; mxnet::Tuple graph_sizes; DMLC_DECLARE_PARAMETER(SubgraphCompactParam) { - DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) - .describe("Number of input arguments."); + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2).describe("Number of input arguments."); DMLC_DECLARE_FIELD(return_mapping) - .describe("Return mapping of vid and eid between the subgraph and the parent graph."); - DMLC_DECLARE_FIELD(graph_sizes) - .describe("the number of vertices in each graph."); + .describe("Return mapping of vid and eid between the subgraph and the parent graph."); + DMLC_DECLARE_FIELD(graph_sizes).describe("the number of vertices in each graph."); } }; // struct SubgraphCompactParam DMLC_REGISTER_PARAMETER(SubgraphCompactParam); -static inline size_t get_num_graphs(const SubgraphCompactParam ¶ms) { +static inline size_t get_num_graphs(const SubgraphCompactParam& params) { // Each CSR needs a 1D array to store the original vertex Id for each row. return params.num_args / 2; } -static void CompactSubgraph(const NDArray &csr, const NDArray &vids, - const NDArray &out_csr, size_t graph_size) { - TBlob in_idx_data = csr.aux_data(csr::kIdx); - TBlob in_ptr_data = csr.aux_data(csr::kIndPtr); - const dgl_id_t *indices_in = in_idx_data.dptr(); - const dgl_id_t *indptr_in = in_ptr_data.dptr(); - const dgl_id_t *row_ids = vids.data().dptr(); - size_t num_elems = csr.aux_data(csr::kIdx).shape_.Size(); +static void CompactSubgraph(const NDArray& csr, + const NDArray& vids, + const NDArray& out_csr, + size_t graph_size) { + TBlob in_idx_data = csr.aux_data(csr::kIdx); + TBlob in_ptr_data = csr.aux_data(csr::kIndPtr); + const dgl_id_t* indices_in = in_idx_data.dptr(); + const dgl_id_t* indptr_in = in_ptr_data.dptr(); + const dgl_id_t* row_ids = vids.data().dptr(); + size_t num_elems = csr.aux_data(csr::kIdx).shape_.Size(); // The last element in vids is the actual number of vertices in the subgraph. 
CHECK_EQ(vids.shape()[0], in_ptr_data.shape_[0]); CHECK_EQ(static_cast(row_ids[vids.shape()[0] - 1]), graph_size); @@ -1489,16 +1483,16 @@ static void CompactSubgraph(const NDArray &csr, const NDArray &vids, out_csr.CheckAndAllocAuxData(csr::kIdx, nz_shape); out_csr.CheckAndAllocAuxData(csr::kIndPtr, indptr_shape); - dgl_id_t *indices_out = out_csr.aux_data(csr::kIdx).dptr(); - dgl_id_t *indptr_out = out_csr.aux_data(csr::kIndPtr).dptr(); - dgl_id_t *sub_eids = out_csr.data().dptr(); + dgl_id_t* indices_out = out_csr.aux_data(csr::kIdx).dptr(); + dgl_id_t* indptr_out = out_csr.aux_data(csr::kIndPtr).dptr(); + dgl_id_t* sub_eids = out_csr.data().dptr(); std::copy(indptr_in, indptr_in + indptr_shape[0], indptr_out); for (int64_t i = 0; i < nz_shape[0]; i++) { dgl_id_t old_id = indices_in[i]; - auto it = id_map.find(old_id); + auto it = id_map.find(old_id); CHECK(it != id_map.end()); indices_out[i] = it->second; - sub_eids[i] = i; + sub_eids[i] = i; } } @@ -1508,7 +1502,7 @@ static void SubgraphCompactComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - int num_g = get_num_graphs(params); + int num_g = get_num_graphs(params); #pragma omp parallel for for (int i = 0; i < num_g; i++) { CompactSubgraph(inputs[i], inputs[i + num_g], outputs[i], params.graph_sizes[i]); @@ -1518,10 +1512,10 @@ static void SubgraphCompactComputeExCPU(const nnvm::NodeAttrs& attrs, static bool SubgraphCompactStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - size_t num_g = get_num_graphs(params); + size_t num_g = get_num_graphs(params); CHECK_EQ(num_g * 2, in_attrs->size()); // These are the input subgraphs. for (size_t i = 0; i < num_g; i++) @@ -1530,9 +1524,9 @@ static bool SubgraphCompactStorageType(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < num_g; i++) CHECK_EQ(in_attrs->at(i + num_g), kDefaultStorage); - bool success = true; + bool success = true; *dispatch_mode = DispatchMode::kFComputeEx; - for (int & out_attr : *out_attrs) { + for (int& out_attr : *out_attrs) { if (!type_assign(&out_attr, mxnet::kCSRStorage)) success = false; } @@ -1540,10 +1534,10 @@ static bool SubgraphCompactStorageType(const nnvm::NodeAttrs& attrs, } static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - size_t num_g = get_num_graphs(params); + size_t num_g = get_num_graphs(params); CHECK_EQ(num_g * 2, in_attrs->size()); // These are the input subgraphs. 
for (size_t i = 0; i < num_g; i++) { @@ -1559,8 +1553,8 @@ static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs, for (size_t i = 0; i < num_g; i++) { mxnet::TShape gshape(2, -1); - gshape[0] = params.graph_sizes[i]; - gshape[1] = params.graph_sizes[i]; + gshape[0] = params.graph_sizes[i]; + gshape[1] = params.graph_sizes[i]; out_attrs->at(i) = gshape; if (params.return_mapping) out_attrs->at(i + num_g) = gshape; @@ -1569,19 +1563,19 @@ static bool SubgraphCompactShape(const nnvm::NodeAttrs& attrs, } static bool SubgraphCompactType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - for (int & in_attr : *in_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { + for (int& in_attr : *in_attrs) { CHECK_EQ(in_attr, mshadow::kInt64); } - for (int & out_attr : *out_attrs) { + for (int& out_attr : *out_attrs) { out_attr = mshadow::kInt64; } return true; } NNVM_REGISTER_OP(_contrib_dgl_graph_compact) -.describe(R"code(This operator compacts a CSR matrix generated by + .describe(R"code(This operator compacts a CSR matrix generated by dgl_csr_neighbor_uniform_sample and dgl_csr_neighbor_non_uniform_sample. The CSR matrices generated by these two operators may have many empty rows at the end and many empty columns. This operator removes these @@ -1612,38 +1606,39 @@ empty rows and empty columns. [8, 9, 0, 0, 0]]) )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - return params.num_args; -}) -.set_num_outputs([](const NodeAttrs& attrs) { - const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - int num_varray = get_num_graphs(params); - if (params.return_mapping) - return num_varray * 2; - else - return num_varray; -}) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - const SubgraphCompactParam& params = nnvm::get(attrs.parsed); - std::vector names; - names.reserve(params.num_args); - size_t num_graphs = get_num_graphs(params); - for (size_t i = 0; i < num_graphs; i++) - names.push_back("graph" + std::to_string(i)); - for (size_t i = 0; i < num_graphs; ++i) - names.push_back("varray" + std::to_string(i)); - return names; -}) -.set_attr("FInferStorageType", SubgraphCompactStorageType) -.set_attr("FInferShape", SubgraphCompactShape) -.set_attr("FInferType", SubgraphCompactType) -.set_attr("FComputeEx", SubgraphCompactComputeExCPU) -.set_attr("key_var_num_args", "num_args") -.add_argument("graph_data", "NDArray-or-Symbol[]", "Input graphs and input vertex Ids.") -.add_arguments(SubgraphCompactParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_num_inputs([](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + return params.num_args; + }) + .set_num_outputs([](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = nnvm::get(attrs.parsed); + int num_varray = get_num_graphs(params); + if (params.return_mapping) + return num_varray * 2; + else + return num_varray; + }) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const SubgraphCompactParam& params = + nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + size_t num_graphs = get_num_graphs(params); + for (size_t i = 0; i < num_graphs; i++) + names.push_back("graph" + std::to_string(i)); + for (size_t i = 0; i < num_graphs; ++i) + names.push_back("varray" + std::to_string(i)); + return names; + }) + .set_attr("FInferStorageType", SubgraphCompactStorageType) + 
.set_attr("FInferShape", SubgraphCompactShape) + .set_attr("FInferType", SubgraphCompactType) + .set_attr("FComputeEx", SubgraphCompactComputeExCPU) + .set_attr("key_var_num_args", "num_args") + .add_argument("graph_data", "NDArray-or-Symbol[]", "Input graphs and input vertex Ids.") + .add_arguments(SubgraphCompactParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/dgl_graph.cu b/src/operator/contrib/dgl_graph.cu index 336e9d4f8026..68c1d8b76537 100644 --- a/src/operator/contrib/dgl_graph.cu +++ b/src/operator/contrib/dgl_graph.cu @@ -23,7 +23,7 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_contrib_dgl_adjacency) -.set_attr("FComputeEx", DGLAdjacencyForwardEx); + .set_attr("FComputeEx", DGLAdjacencyForwardEx); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/dynamic_shape_ops-inl.h b/src/operator/contrib/dynamic_shape_ops-inl.h index 1d1aff8a80a6..91e8f1e47f6e 100644 --- a/src/operator/contrib/dynamic_shape_ops-inl.h +++ b/src/operator/contrib/dynamic_shape_ops-inl.h @@ -19,7 +19,7 @@ /*! * Copyright (c) 2018 by Contributors * \file dynamic_shape_ops-inl.h -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_DYNAMIC_SHAPE_OPS_INL_H_ #define MXNET_OPERATOR_CONTRIB_DYNAMIC_SHAPE_OPS_INL_H_ @@ -32,17 +32,17 @@ namespace mxnet { namespace op { -template +template inline void DynamicReshapeForward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); - const NDArray &out = outputs[0]; - const NDArray &idx = inputs[1]; - size_t idx_size = idx.shape()[0]; + const NDArray& out = outputs[0]; + const NDArray& idx = inputs[1]; + size_t idx_size = idx.shape()[0]; mxnet::TShape shape_value = mxnet::TShape(idx_size, 0); std::vector shapev(idx_size, 0); @@ -55,32 +55,33 @@ inline void DynamicReshapeForward(const nnvm::NodeAttrs& attrs, } }); shape_value = InferReshapeShape(mxnet::Tuple(shapev), inputs[0].shape(), false); - const_cast(out).Init(shape_value); - mshadow::Stream *s = ctx.get_stream(); + const_cast(out).Init(shape_value); + mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(out.dtype(), DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].data().Size(), out.data().dptr(), - inputs[0].data().dptr()); + s, inputs[0].data().Size(), out.data().dptr(), inputs[0].data().dptr()); }); }); } -template +template inline void DynamicReshapeBackward(const nnvm::NodeAttrs& attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 2U); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].data().Size(), outputs[0].data().dptr(), + s, + inputs[0].data().Size(), + outputs[0].data().dptr(), inputs[0].data().dptr()); }); }); diff --git a/src/operator/contrib/dynamic_shape_ops.cc b/src/operator/contrib/dynamic_shape_ops.cc index 1b5274cea4fe..8610d345de6d 100644 --- a/src/operator/contrib/dynamic_shape_ops.cc +++ b/src/operator/contrib/dynamic_shape_ops.cc @@ -19,7 
+19,7 @@ /*! * Copyright (c) 2018 by Contributors * \file dynamic_shape_ops.cc -*/ + */ #include "./dynamic_shape_ops-inl.h" #include "../tensor/elemwise_binary_op.h" @@ -29,8 +29,8 @@ namespace mxnet { namespace op { inline bool DynamicReshapeType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]); @@ -41,8 +41,8 @@ inline bool DynamicReshapeType(const nnvm::NodeAttrs& attrs, bool DynamicReshapeStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); for (size_t i = 0; i < in_attrs->size(); ++i) { @@ -58,8 +58,8 @@ bool DynamicReshapeStorageType(const nnvm::NodeAttrs& attrs, bool DynamicReshapeBackwardStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 2); for (size_t i = 0; i < in_attrs->size(); ++i) { @@ -73,7 +73,7 @@ bool DynamicReshapeBackwardStorageType(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_dynamic_reshape) -.describe(R"code( + .describe(R"code( Experimental support for reshape operator with dynamic shape. Accepts 2 inputs - data and shape. @@ -119,26 +119,26 @@ Example:: // out will be of shape (2,75) )code" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "shape"}; - }) -.set_attr("FInferType", DynamicReshapeType) -.set_attr("FInferStorageType", DynamicReshapeStorageType) -.set_attr("FComputeEx", DynamicReshapeForward) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_contrib_dynamic_reshape"}) -.add_argument("data", "NDArray-or-Symbol", "Data") -.add_argument("shape", "NDArray-or-Symbol", "Shape"); - + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "shape"}; + }) + .set_attr("FInferType", DynamicReshapeType) + .set_attr("FInferStorageType", DynamicReshapeStorageType) + .set_attr("FComputeEx", DynamicReshapeForward) + .set_attr("FGradient", + ElemwiseGradUseNone{"_backward_contrib_dynamic_reshape"}) + .add_argument("data", "NDArray-or-Symbol", "Data") + .add_argument("shape", "NDArray-or-Symbol", "Shape"); NNVM_REGISTER_OP(_backward_contrib_dynamic_reshape) -.set_num_inputs(1) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", DynamicReshapeBackwardStorageType) -.set_attr("FComputeEx", DynamicReshapeBackward); + .set_num_inputs(1) + .set_num_outputs(2) + .set_attr("TIsBackward", true) + .set_attr("FInferStorageType", DynamicReshapeBackwardStorageType) + .set_attr("FComputeEx", DynamicReshapeBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/erfinv-inl.h b/src/operator/contrib/erfinv-inl.h index 728a11918bdd..e0bd2ba08e58 100644 --- a/src/operator/contrib/erfinv-inl.h +++ b/src/operator/contrib/erfinv-inl.h @@ -52,7 +52,6 @@ namespace mxnet { namespace op { namespace mshadow_op { - /* * Evaluate polynomial * @@ -96,13 +95,13 @@ namespace mshadow_op { */ MSHADOW_XINLINE static double polevl(double x, const double coef[], 
int N) { - const double *p; + const double* p; double ans; int i; - p = coef; + p = coef; ans = *p++; - i = N; + i = N; do { ans = ans * x + *p++; @@ -112,13 +111,13 @@ MSHADOW_XINLINE static double polevl(double x, const double coef[], int N) { } MSHADOW_XINLINE static double p1evl(double x, const double coef[], int N) { - const double *p; + const double* p; double ans; int i; - p = coef; + p = coef; ans = x + *p++; - i = N - 1; + i = N - 1; do { ans = ans * x + *p++; @@ -127,7 +126,6 @@ MSHADOW_XINLINE static double p1evl(double x, const double coef[], int N) { return (ans); } - /* Inverse of Normal distribution function * * SYNOPSIS: @@ -172,97 +170,97 @@ MSHADOW_XINLINE static double ndtri(double y0) { /* approximation for 0 <= |y - 0.5| <= 3/8 */ double P0[5] = { - -5.99633501014107895267E1, - 9.80010754185999661536E1, - -5.66762857469070293439E1, - 1.39312609387279679503E1, - -1.23916583867381258016E0, + -5.99633501014107895267E1, + 9.80010754185999661536E1, + -5.66762857469070293439E1, + 1.39312609387279679503E1, + -1.23916583867381258016E0, }; double Q0[8] = { - /* 1.00000000000000000000E0, */ - 1.95448858338141759834E0, - 4.67627912898881538453E0, - 8.63602421390890590575E1, - -2.25462687854119370527E2, - 2.00260212380060660359E2, - -8.20372256168333339912E1, - 1.59056225126211695515E1, - -1.18331621121330003142E0, + /* 1.00000000000000000000E0, */ + 1.95448858338141759834E0, + 4.67627912898881538453E0, + 8.63602421390890590575E1, + -2.25462687854119370527E2, + 2.00260212380060660359E2, + -8.20372256168333339912E1, + 1.59056225126211695515E1, + -1.18331621121330003142E0, }; /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8 * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14. */ double P1[9] = { - 4.05544892305962419923E0, - 3.15251094599893866154E1, - 5.71628192246421288162E1, - 4.40805073893200834700E1, - 1.46849561928858024014E1, - 2.18663306850790267539E0, - -1.40256079171354495875E-1, - -3.50424626827848203418E-2, - -8.57456785154685413611E-4, + 4.05544892305962419923E0, + 3.15251094599893866154E1, + 5.71628192246421288162E1, + 4.40805073893200834700E1, + 1.46849561928858024014E1, + 2.18663306850790267539E0, + -1.40256079171354495875E-1, + -3.50424626827848203418E-2, + -8.57456785154685413611E-4, }; double Q1[8] = { - /* 1.00000000000000000000E0, */ - 1.57799883256466749731E1, - 4.53907635128879210584E1, - 4.13172038254672030440E1, - 1.50425385692907503408E1, - 2.50464946208309415979E0, - -1.42182922854787788574E-1, - -3.80806407691578277194E-2, - -9.33259480895457427372E-4, + /* 1.00000000000000000000E0, */ + 1.57799883256466749731E1, + 4.53907635128879210584E1, + 4.13172038254672030440E1, + 1.50425385692907503408E1, + 2.50464946208309415979E0, + -1.42182922854787788574E-1, + -3.80806407691578277194E-2, + -9.33259480895457427372E-4, }; /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64 * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890. 
*/ double P2[9] = { - 3.23774891776946035970E0, - 6.91522889068984211695E0, - 3.93881025292474443415E0, - 1.33303460815807542389E0, - 2.01485389549179081538E-1, - 1.23716634817820021358E-2, - 3.01581553508235416007E-4, - 2.65806974686737550832E-6, - 6.23974539184983293730E-9, + 3.23774891776946035970E0, + 6.91522889068984211695E0, + 3.93881025292474443415E0, + 1.33303460815807542389E0, + 2.01485389549179081538E-1, + 1.23716634817820021358E-2, + 3.01581553508235416007E-4, + 2.65806974686737550832E-6, + 6.23974539184983293730E-9, }; double Q2[8] = { - /* 1.00000000000000000000E0, */ - 6.02427039364742014255E0, - 3.67983563856160859403E0, - 1.37702099489081330271E0, - 2.16236993594496635890E-1, - 1.34204006088543189037E-2, - 3.28014464682127739104E-4, - 2.89247864745380683936E-6, - 6.79019408009981274425E-9, + /* 1.00000000000000000000E0, */ + 6.02427039364742014255E0, + 3.67983563856160859403E0, + 1.37702099489081330271E0, + 2.16236993594496635890E-1, + 1.34204006088543189037E-2, + 3.28014464682127739104E-4, + 2.89247864745380683936E-6, + 6.79019408009981274425E-9, }; double x, y, z, y2, x0, x1; bool code = true; - y = y0; - if (y > (1.0 - 0.13533528323661269189)) { /* 0.135... = exp(-2) */ - y = 1.0 - y; + y = y0; + if (y > (1.0 - 0.13533528323661269189)) { /* 0.135... = exp(-2) */ + y = 1.0 - y; code = false; } if (y > 0.13533528323661269189) { - y = y - 0.5; + y = y - 0.5; y2 = y * y; - x = y + y * (y2 * polevl(y2, P0, 4) / p1evl(y2, Q0, 8)); - x = x * s2pi; + x = y + y * (y2 * polevl(y2, P0, 4) / p1evl(y2, Q0, 8)); + x = x * s2pi; return (x); } - x = sqrt(-2.0 * log(y)); + x = sqrt(-2.0 * log(y)); x0 = x - log(x) / x; z = 1.0 / x; - if (x < 8.0) { /* y > exp(-32) = 1.2664165549e-14 */ + if (x < 8.0) { /* y > exp(-32) = 1.2664165549e-14 */ x1 = z * polevl(z, P1, 8) / p1evl(z, Q1, 8); } else { x1 = z * polevl(z, P2, 8) / p1evl(z, Q2, 8); @@ -275,10 +273,9 @@ MSHADOW_XINLINE static double ndtri(double y0) { return (x); } - /*! \brief inverse of the error function */ struct erfinv : public mxnet_op::tunable { - template + template MSHADOW_XINLINE static DType Map(DType v) { /* Inverse of the error function. 
* Computes the inverse of the error function on the restricted domain @@ -289,7 +286,7 @@ struct erfinv : public mxnet_op::tunable { const double domain_ub = 1; const double thresh = 1e-7; - double y = static_cast(v); + double y = static_cast(v); /* * For small arguments, use the Taylor expansion @@ -302,7 +299,7 @@ struct erfinv : public mxnet_op::tunable { } if ((domain_lb < y) && (y < domain_ub)) { - return DType(ndtri(0.5 * (y+1)) * M_SQRT1_2); + return DType(ndtri(0.5 * (y + 1)) * M_SQRT1_2); } else if (y == domain_lb || y == domain_ub) { return DType(std::copysign(1.0, y) * std::numeric_limits::infinity()); } else { diff --git a/src/operator/contrib/fft-inl.h b/src/operator/contrib/fft-inl.h index 7db8b26a79e2..530412c48016 100644 --- a/src/operator/contrib/fft-inl.h +++ b/src/operator/contrib/fft-inl.h @@ -22,7 +22,7 @@ * \file fft-inl.h * \brief * \author Chen Zhu -*/ + */ #ifndef MXNET_OPERATOR_CONTRIB_FFT_INL_H_ #define MXNET_OPERATOR_CONTRIB_FFT_INL_H_ #include @@ -42,34 +42,35 @@ namespace mxnet { namespace op { namespace fft { -enum fftOpInputs {kData}; -enum fftOpOutputs {kOutComplex}; // seperate the image and real parts at the moment -enum fftOpResource {kTempSpace}; // might be requiered as we need to pad the real matrices -} +enum fftOpInputs { kData }; +enum fftOpOutputs { kOutComplex }; // seperate the image and real parts at the moment +enum fftOpResource { kTempSpace }; // might be requiered as we need to pad the real matrices +} // namespace fft struct FFTParam : public dmlc::Parameter { int compute_size; // the maximum size of sub-batch to be forwarded through FFT in one time DMLC_DECLARE_PARAMETER(FFTParam) { - DMLC_DECLARE_FIELD(compute_size).set_default(128) - .describe("Maximum size of sub-batch to be forwarded at one time"); + DMLC_DECLARE_FIELD(compute_size) + .set_default(128) + .describe("Maximum size of sub-batch to be forwarded at one time"); } }; #if MXNET_USE_CUDA -template +template class FFTOp : public Operator { public: explicit FFTOp(FFTParam p) { this->param_ = p; - init_cufft_ = false; - dim_ = 0; + init_cufft_ = false; + dim_ = 0; } - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { + virtual void Forward(const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data, + const std::vector& aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(in_data.size(), 1); @@ -77,10 +78,10 @@ class FFTOp : public Operator { // the last dimention should be the dimension of fft vector if (!init_cufft_) { - n_ffts = in_data[fft::kData].shape_.ProdShape(0, in_data[fft::kData].ndim()-1); - dim_ = in_data[fft::kData].shape_[in_data[fft::kData].ndim()-1]; + n_ffts = in_data[fft::kData].shape_.ProdShape(0, in_data[fft::kData].ndim() - 1); + dim_ = in_data[fft::kData].shape_[in_data[fft::kData].ndim() - 1]; - stride_ = param_.compute_size*dim_; + stride_ = param_.compute_size * dim_; init_cufft_ = true; @@ -88,79 +89,75 @@ class FFTOp : public Operator { num_compute = n_ffts / param_.compute_size; } - - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); // const mxnet::TShape& oshape = out_data[fft::kOutComplex].shape_; - Tensor data = in_data[fft::kData].get_with_shape( - Shape2(n_ffts, dim_), s); - Tensor out = out_data[fft::kOutComplex].get_with_shape( - Shape2(n_ffts, dim_*2), s); + Tensor data = + in_data[fft::kData].get_with_shape(Shape2(n_ffts, dim_), s); + Tensor out = + 
out_data[fft::kOutComplex].get_with_shape(Shape2(n_ffts, dim_ * 2), s); // need temp space to pad the data into complex numbers due to cufft interface - Tensor workspace = - ctx.requested[fft::kTempSpace].get_space_typed( - Shape1(param_.compute_size*dim_*2), s); - Tensor complex_data = Tensor(workspace.dptr_, - Shape2(param_.compute_size, dim_*2), s); + Tensor workspace = ctx.requested[fft::kTempSpace].get_space_typed( + Shape1(param_.compute_size * dim_ * 2), s); + Tensor complex_data = + Tensor(workspace.dptr_, Shape2(param_.compute_size, dim_ * 2), s); // start fft cufftHandle plan; cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); - for (size_t idx=0; idx < num_compute; ++idx) { - complex_data = complex_pad_imag(data.Slice(idx*param_.compute_size, - idx*param_.compute_size+param_.compute_size)); + for (size_t idx = 0; idx < num_compute; ++idx) { + complex_data = complex_pad_imag( + data.Slice(idx * param_.compute_size, idx * param_.compute_size + param_.compute_size)); - cufftComplex* in_tmp = const_cast( - reinterpret_cast(complex_data.dptr_)); - cufftComplex* out_tmp = reinterpret_cast(out.dptr_ + 2*idx*stride_); + cufftComplex* in_tmp = + const_cast(reinterpret_cast(complex_data.dptr_)); + cufftComplex* out_tmp = reinterpret_cast(out.dptr_ + 2 * idx * stride_); CHECK_EQ(cufftExecC2C(plan, in_tmp, out_tmp, CUFFT_FORWARD), CUFFT_SUCCESS); } cufftDestroy(plan); // handle the remaining samples - size_t remain_num = n_ffts - param_.compute_size*num_compute; + size_t remain_num = n_ffts - param_.compute_size * num_compute; if (remain_num > 0) { cufftHandle plan_remain; - cufftPlanMany(&plan_remain, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, - CUFFT_C2C, remain_num); + cufftPlanMany(&plan_remain, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, remain_num); - complex_data = Tensor(workspace.dptr_, - Shape2(remain_num, dim_*2), s); - complex_data = complex_pad_imag(data.Slice( - num_compute*param_.compute_size, num_compute*param_.compute_size+remain_num)); + complex_data = Tensor(workspace.dptr_, Shape2(remain_num, dim_ * 2), s); + complex_data = complex_pad_imag(data.Slice(num_compute * param_.compute_size, + num_compute * param_.compute_size + remain_num)); - cufftComplex* in_tmp = const_cast( - reinterpret_cast(complex_data.dptr_)); - cufftComplex* out_tmp = reinterpret_cast(out.dptr_ + 2*num_compute*stride_); + cufftComplex* in_tmp = + const_cast(reinterpret_cast(complex_data.dptr_)); + cufftComplex* out_tmp = + reinterpret_cast(out.dptr_ + 2 * num_compute * stride_); CHECK_EQ(cufftExecC2C(plan_remain, in_tmp, out_tmp, CUFFT_FORWARD), CUFFT_SUCCESS); cufftDestroy(plan_remain); } } - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { + virtual void Backward(const OpContext& ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& req, + const std::vector& in_grad, + const std::vector& aux_args) { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); CHECK(in_data.size() == 1 && in_grad.size() == 1); CHECK_EQ(req.size(), 1); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); - Tensor gdata = in_grad[fft::kData].get_with_shape( - Shape2(n_ffts, dim_), s); - Tensor grad = out_grad[fft::kOutComplex].get_with_shape( - Shape2(n_ffts, dim_*2), s); + Tensor gdata = + 
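// The Forward pass above pads each real row into a complex workspace and then runs a
// batched 1-D complex-to-complex transform.  A condensed sketch of that cuFFT call
// pattern, for illustration only: BatchedForwardFFT is hypothetical, `in` and `out`
// are assumed to be device pointers holding batch * dim complex values, and error
// handling is reduced to asserts.
#include <cassert>
#include <cufft.h>

void BatchedForwardFFT(cufftComplex* in, cufftComplex* out, int dim, int batch) {
  cufftHandle plan;
  // rank-1 transform of length `dim`, `batch` rows, tightly packed input and output
  assert(cufftPlanMany(&plan, 1, &dim, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, batch) ==
         CUFFT_SUCCESS);
  assert(cufftExecC2C(plan, in, out, CUFFT_FORWARD) == CUFFT_SUCCESS);
  cufftDestroy(plan);
}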
in_grad[fft::kData].get_with_shape(Shape2(n_ffts, dim_), s); + Tensor grad = + out_grad[fft::kOutComplex].get_with_shape(Shape2(n_ffts, dim_ * 2), s); // need temp space to pad the data into complex numbers due to cufft interface - Tensor workspace = - ctx.requested[fft::kTempSpace].get_space_typed( - Shape1(param_.compute_size*dim_*2), s); - Tensor complex_data = Tensor(workspace.dptr_, - Shape2(param_.compute_size, dim_*2), s); + Tensor workspace = ctx.requested[fft::kTempSpace].get_space_typed( + Shape1(param_.compute_size * dim_ * 2), s); + Tensor complex_data = + Tensor(workspace.dptr_, Shape2(param_.compute_size, dim_ * 2), s); // by default, we think forward is firstly conducted // In this solution, out_grad must comes from a fft of real signal, @@ -170,32 +167,32 @@ class FFTOp : public Operator { cufftPlanMany(&plan, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, param_.compute_size); for (size_t idx = 0; idx < num_compute; ++idx) { cufftComplex* in_tmp = const_cast( - reinterpret_cast(grad.dptr_ + 2*idx*stride_)); + reinterpret_cast(grad.dptr_ + 2 * idx * stride_)); cufftComplex* out_tmp = reinterpret_cast(complex_data.dptr_); CHECK_EQ(cufftExecC2C(plan, in_tmp, out_tmp, CUFFT_INVERSE), CUFFT_SUCCESS); - Assign(gdata.Slice(idx*param_.compute_size, (idx+1)*param_.compute_size), - req[fft::kData], complex_toreal(complex_data)); + Assign(gdata.Slice(idx * param_.compute_size, (idx + 1) * param_.compute_size), + req[fft::kData], + complex_toreal(complex_data)); } cufftDestroy(plan); // handle the remaining samples - size_t remain_num = n_ffts - param_.compute_size*num_compute; + size_t remain_num = n_ffts - param_.compute_size * num_compute; if (remain_num > 0) { cufftHandle plan_remain; - cufftPlanMany(&plan_remain, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, - CUFFT_C2C, remain_num); - complex_data = Tensor(workspace.dptr_, - Shape2(remain_num, dim_*2), s); + cufftPlanMany(&plan_remain, 1, &dim_, nullptr, 0, 0, nullptr, 0, 0, CUFFT_C2C, remain_num); + complex_data = Tensor(workspace.dptr_, Shape2(remain_num, dim_ * 2), s); cufftComplex* in_tmp = const_cast( - reinterpret_cast(grad.dptr_ + 2*num_compute*stride_)); + reinterpret_cast(grad.dptr_ + 2 * num_compute * stride_)); cufftComplex* out_tmp = reinterpret_cast(complex_data.dptr_); CHECK_EQ(cufftExecC2C(plan_remain, in_tmp, out_tmp, CUFFT_INVERSE), CUFFT_SUCCESS); - Assign(gdata.Slice(param_.compute_size*num_compute, - param_.compute_size*num_compute+remain_num), - req[fft::kData], complex_toreal(complex_data)); + Assign(gdata.Slice(param_.compute_size * num_compute, + param_.compute_size * num_compute + remain_num), + req[fft::kData], + complex_toreal(complex_data)); cufftDestroy(plan_remain); } // for bp, we should not divide it @@ -208,11 +205,11 @@ class FFTOp : public Operator { int dim_, stride_, n_ffts; size_t num_compute; bool init_cufft_; -}; // class FFTOp +}; // class FFTOp #endif // MXNET_USE_CUDA // Declare Factory Function, used for dispatch specialization -template +template Operator* CreateOp(FFTParam param, int dtype); #if DMLC_USE_CXX11 @@ -230,27 +227,28 @@ class FFTProp : public OperatorProperty { return param_.__DICT__(); } - bool InferShape(mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape, - mxnet::ShapeVector *aux_shape) const override { + bool InferShape(mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape, + mxnet::ShapeVector* aux_shape) const override { using namespace mshadow; - CHECK_EQ(in_shape->size(), 1) <<"Input:[data]"; - const mxnet::TShape &dshape = 
(*in_shape)[fft::kData]; + CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; + const mxnet::TShape& dshape = (*in_shape)[fft::kData]; // require data to be known - if (mxnet::op::shape_is_none(dshape)) return false; + if (mxnet::op::shape_is_none(dshape)) + return false; out_shape->clear(); if (dshape.ndim() == 4) { - out_shape->push_back(Shape4(dshape[0], dshape[1], dshape[2], dshape[3]*2)); + out_shape->push_back(Shape4(dshape[0], dshape[1], dshape[2], dshape[3] * 2)); } else if (dshape.ndim() == 2) { - out_shape->push_back(Shape2(dshape[0], dshape[1]*2)); + out_shape->push_back(Shape2(dshape[0], dshape[1] * 2)); } return true; } - bool InferType(std::vector *in_type, - std::vector *out_type, - std::vector *aux_type) const override { + bool InferType(std::vector* in_type, + std::vector* out_type, + std::vector* aux_type) const override { CHECK_GE(in_type->size(), 1); int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; @@ -268,7 +266,7 @@ class FFTProp : public OperatorProperty { OperatorProperty* Copy() const override { FFTProp* fft_sym = new FFTProp(); - fft_sym->param_ = this->param_; + fft_sym->param_ = this->param_; return fft_sym; } @@ -277,28 +275,25 @@ class FFTProp : public OperatorProperty { } // declare dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { + std::vector DeclareBackwardDependency(const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data) const override { return {out_grad[fft::kOutComplex], in_data[fft::kData]}; } - std::vector ForwardResource( - const mxnet::ShapeVector &in_shape) const override { + std::vector ForwardResource(const mxnet::ShapeVector& in_shape) const override { return {ResourceRequest::kTempSpace}; } - std::vector BackwardResource( - const mxnet::ShapeVector &in_shape) const override { + std::vector BackwardResource(const mxnet::ShapeVector& in_shape) const override { return {ResourceRequest::kTempSpace}; } std::vector > BackwardInplaceOption( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &in_grad) const override { + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& out_data, + const std::vector& in_grad) const override { return {{in_data[fft::kData], in_grad[fft::kData]}}; } @@ -307,8 +302,9 @@ class FFTProp : public OperatorProperty { return nullptr; } - Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const override; + Operator* CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const override; private: FFTParam param_; diff --git a/src/operator/contrib/fft.cc b/src/operator/contrib/fft.cc index 1262835cbb58..ec51c54476c4 100644 --- a/src/operator/contrib/fft.cc +++ b/src/operator/contrib/fft.cc @@ -22,26 +22,27 @@ * \file fft-inl.h * \brief * \author Chen Zhu -*/ + */ #include "./fft-inl.h" namespace mxnet { namespace op { -template<> -Operator *CreateOp(FFTParam param, int dtype) { +template <> +Operator* CreateOp(FFTParam param, int dtype) { LOG(FATAL) << "fft is only available for GPU."; return nullptr; } -Operator *FFTProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape, - std::vector *in_type) const { +Operator* FFTProp::CreateOperatorEx(Context ctx, + mxnet::ShapeVector* in_shape, + std::vector* in_type) const { DO_BIND_DISPATCH(CreateOp, param_, 
(*in_type)[0]); } DMLC_REGISTER_PARAMETER(FFTParam); MXNET_REGISTER_OP_PROPERTY(_contrib_fft, FFTProp) -.describe(R"code(Apply 1D FFT to input" + .describe(R"code(Apply 1D FFT to input" .. note:: `fft` is only available on GPU. @@ -54,7 +55,7 @@ Example:: out = mx.contrib.ndarray.fft(data = mx.nd.array(data,ctx = mx.gpu(0))) )code" ADD_FILELINE) -.add_argument("data", "NDArray-or-Symbol", "Input data to the FFTOp.") -.add_arguments(FFTParam::__FIELDS__()); + .add_argument("data", "NDArray-or-Symbol", "Input data to the FFTOp.") + .add_arguments(FFTParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/fft.cu b/src/operator/contrib/fft.cu index 8cc56ade83dd..8b00b6c967dc 100644 --- a/src/operator/contrib/fft.cu +++ b/src/operator/contrib/fft.cu @@ -28,12 +28,10 @@ namespace mxnet { namespace op { -template<> +template <> Operator* CreateOp(FFTParam param, int dtype) { - Operator *op = nullptr; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new FFTOp(param); - }) + Operator* op = nullptr; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new FFTOp(param); }) return op; } diff --git a/src/operator/contrib/gradient_multiplier_op.cc b/src/operator/contrib/gradient_multiplier_op.cc index f1664a3eaac4..da46a25b3db8 100644 --- a/src/operator/contrib/gradient_multiplier_op.cc +++ b/src/operator/contrib/gradient_multiplier_op.cc @@ -22,7 +22,7 @@ * \file gradient_multiplier_op.cc * \brief * \author Istvan Fehervari -*/ + */ #include "../tensor/elemwise_unary_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -32,22 +32,22 @@ namespace op { static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { CHECK_EQ(in_attrs->size(), 1); CHECK_EQ(out_attrs->size(), 1); const auto in_stype = in_attrs->at(0); - auto &out_stype = out_attrs->at(0); - bool dispatched = false; + auto& out_stype = out_attrs->at(0); + bool dispatched = false; if (!dispatched && (in_stype == kDefaultStorage)) { // dense -> dense - dispatched = storage_type_assign(&out_stype, kDefaultStorage, - dispatch_mode, DispatchMode::kFCompute); + dispatched = + storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } if (!dispatched && in_stype == kRowSparseStorage) { // row sparse -> row sparse - dispatched = storage_type_assign(&out_stype, kRowSparseStorage, - dispatch_mode, DispatchMode::kFComputeEx); + dispatched = storage_type_assign( + &out_stype, kRowSparseStorage, dispatch_mode, DispatchMode::kFComputeEx); // FComputeEx can handle dns output on cpu, too if (dev_mask == cpu::kDevMask && out_stype == kDefaultStorage) { DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); @@ -56,8 +56,8 @@ static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs, } if (!dispatched && in_stype == kCSRStorage) { // csr -> csr - dispatched = storage_type_assign(&out_stype, kCSRStorage, - dispatch_mode, DispatchMode::kFComputeEx); + dispatched = + storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx); // FComputeEx can handle dns output on cpu, too if (dev_mask == cpu::kDevMask && out_stype == kDefaultStorage) { DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); @@ -71,27 +71,28 @@ static bool BinaryScalarStorageType(const nnvm::NodeAttrs& attrs, } MXNET_OPERATOR_REGISTER_UNARY(_contrib_gradientmultiplier) -.describe(R"code(This operator 
implements the gradient multiplier function. + .describe(R"code(This operator implements the gradient multiplier function. In forward pass it acts as an identity transform. During backpropagation it multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to the preceding layer. )code" ADD_FILELINE) -.set_attr_parser(ParamParser) -.set_attr("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) -.set_attr("FCompute", UnaryOp::IdentityCompute) -.set_attr("FComputeEx", UnaryOp::IdentityComputeEx) -.set_attr("FGradient", ElemwiseGradUseNone{"_contrib_backward_gradientmultiplier"}) -.set_attr("FInplaceIdentity", - [](const NodeAttrs& attrs){ - return std::vector{true}; - }) -.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); + .set_attr_parser(ParamParser) + .set_attr("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) + .set_attr("FCompute", UnaryOp::IdentityCompute) + .set_attr("FComputeEx", UnaryOp::IdentityComputeEx) + .set_attr("FGradient", + ElemwiseGradUseNone{"_contrib_backward_gradientmultiplier"}) + .set_attr("FInplaceIdentity", + [](const NodeAttrs& attrs) { + return std::vector{true}; + }) + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier) -.set_attr("TIsBackward", true) -.set_attr("FInferStorageType", BinaryScalarStorageType) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); + .set_attr("TIsBackward", true) + .set_attr("FInferStorageType", BinaryScalarStorageType) + .set_attr("FCompute", BinaryScalarOp::Compute) + .set_attr("FComputeEx", BinaryScalarOp::ComputeEx); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/gradient_multiplier_op.cu b/src/operator/contrib/gradient_multiplier_op.cu index f519f0db5f49..43d1566f36a4 100644 --- a/src/operator/contrib/gradient_multiplier_op.cu +++ b/src/operator/contrib/gradient_multiplier_op.cu @@ -22,7 +22,7 @@ * \file gradient_multiplier_op.cu * \brief * \author Istvan Fehervari -*/ + */ #include "../tensor/elemwise_unary_op.h" #include "../tensor/elemwise_binary_scalar_op.h" @@ -30,12 +30,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_contrib_gradientmultiplier) -.set_attr("FComputeEx", UnaryOp::IdentityComputeEx) -.set_attr("FCompute", UnaryOp::IdentityCompute); + .set_attr("FComputeEx", UnaryOp::IdentityComputeEx) + .set_attr("FCompute", UnaryOp::IdentityCompute); NNVM_REGISTER_OP(_contrib_backward_gradientmultiplier) -.set_attr("FCompute", BinaryScalarRTCCompute{"mul"}) -.set_attr("(FComputeEx", BinaryScalarRTCCompute{"mul"}); + .set_attr("FCompute", BinaryScalarRTCCompute{"mul"}) + .set_attr("(FComputeEx", BinaryScalarRTCCompute{"mul"}); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/hawkes_ll-inl.h b/src/operator/contrib/hawkes_ll-inl.h index d5e90ad6545d..7be08afe48bd 100644 --- a/src/operator/contrib/hawkes_ll-inl.h +++ b/src/operator/contrib/hawkes_ll-inl.h @@ -37,18 +37,26 @@ namespace mxnet { namespace op { namespace hawkesll { - enum HawkesLLOpInputs {kMu, kAlpha, kBeta, kState, kIATimes, kMarks, - kValidLength, kMaxTime}; - enum HawkesLLGradInputs {kOutGradLL, kOutGradStates, kGradMu, kGradAlpha, - kGradBeta, kGradState, kGradIATimes, kGradMarks, - kGradValidLength, kGradMaxTime}; - enum HawkesLLOpOutputs {kOutLL, kOutStates}; - enum HawkesLLOpResource {kTempSpace}; +enum HawkesLLOpInputs { kMu, kAlpha, kBeta, kState, kIATimes, kMarks, kValidLength, kMaxTime }; +enum 
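// Plain-C++ statement of what the _contrib_gradientmultiplier /
// _contrib_backward_gradientmultiplier pair registered above computes: the forward
// pass is the identity, and the backward pass scales the incoming gradient by the
// scalar lambda.  Free functions and std::vector are used purely for illustration.
#include <cstddef>
#include <vector>

std::vector<float> gradmult_forward(const std::vector<float>& x) {
  return x;  // identity transform
}

std::vector<float> gradmult_backward(const std::vector<float>& out_grad, float lambda) {
  std::vector<float> in_grad(out_grad.size());
  for (std::size_t i = 0; i < out_grad.size(); ++i)
    in_grad[i] = lambda * out_grad[i];  // gradient is multiplied by lambda on the way back
  return in_grad;
}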
HawkesLLGradInputs { + kOutGradLL, + kOutGradStates, + kGradMu, + kGradAlpha, + kGradBeta, + kGradState, + kGradIATimes, + kGradMarks, + kGradValidLength, + kGradMaxTime +}; +enum HawkesLLOpOutputs { kOutLL, kOutStates }; +enum HawkesLLOpResource { kTempSpace }; } // namespace hawkesll inline bool HawkesLLOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { // check dimensions of the type vectors CHECK_EQ(in_attrs->size(), 8U); CHECK_EQ(out_attrs->size(), 2U); @@ -67,8 +75,8 @@ inline bool HawkesLLOpType(const nnvm::NodeAttrs& attrs, } inline bool HawkesLLOpShape(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { + std::vector* in_attrs, + std::vector* out_attrs) { using namespace mshadow; int N, T, K; @@ -76,18 +84,18 @@ inline bool HawkesLLOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 2U); // check ndims - CHECK_EQ(in_attrs->at(hawkesll::kMu).ndim(), 2); // mu (N, K) - CHECK_EQ(in_attrs->at(hawkesll::kAlpha).ndim(), 1); // branching ratio (K,) - CHECK_EQ(in_attrs->at(hawkesll::kBeta).ndim(), 1); // decay exponent (K,) - CHECK_EQ(in_attrs->at(hawkesll::kState).ndim(), 2); // Hawkes states (N, K) - CHECK_EQ(in_attrs->at(hawkesll::kIATimes).ndim(), 2); // i.a. times (N, T) - CHECK_EQ(in_attrs->at(hawkesll::kMarks).ndim(), 2); // marks (N, T) + CHECK_EQ(in_attrs->at(hawkesll::kMu).ndim(), 2); // mu (N, K) + CHECK_EQ(in_attrs->at(hawkesll::kAlpha).ndim(), 1); // branching ratio (K,) + CHECK_EQ(in_attrs->at(hawkesll::kBeta).ndim(), 1); // decay exponent (K,) + CHECK_EQ(in_attrs->at(hawkesll::kState).ndim(), 2); // Hawkes states (N, K) + CHECK_EQ(in_attrs->at(hawkesll::kIATimes).ndim(), 2); // i.a. times (N, T) + CHECK_EQ(in_attrs->at(hawkesll::kMarks).ndim(), 2); // marks (N, T) CHECK_EQ(in_attrs->at(hawkesll::kValidLength).ndim(), 1); // valid len (N,) - CHECK_EQ(in_attrs->at(hawkesll::kMaxTime).ndim(), 1); // max_time (N,) + CHECK_EQ(in_attrs->at(hawkesll::kMaxTime).ndim(), 1); // max_time (N,) N = in_attrs->at(hawkesll::kIATimes)[0]; // number of samples in batch T = in_attrs->at(hawkesll::kIATimes)[1]; // time length - K = in_attrs->at(hawkesll::kMu)[1]; // number of marks + K = in_attrs->at(hawkesll::kMu)[1]; // number of marks // check inputs consistent CHECK_EQ(in_attrs->at(hawkesll::kMu)[0], N); @@ -106,12 +114,12 @@ inline bool HawkesLLOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, hawkesll::kOutStates, Shape2(N, K)) return out_attrs->at(hawkesll::kOutLL).ndim() != 0U && - out_attrs->at(hawkesll::kOutStates).Size() != 0U; + out_attrs->at(hawkesll::kOutStates).Size() != 0U; } -template +template struct hawkesll_forward { - template + template MSHADOW_XINLINE static void Map(int i, DType* out_loglike, DType* out_state, @@ -125,27 +133,26 @@ struct hawkesll_forward { DType* max_time, int K, int T, - DType* temp_register - ) { - int32_t ci; // current mark + DType* temp_register) { + int32_t ci; // current mark DType ll = 0; // log likelihood - DType t = 0; // current time + DType t = 0; // current time DType d, ed, lda, comp; - DType *last_ = &temp_register[i * K]; + DType* last_ = &temp_register[i * K]; - const DType *mu_ = &mu[i * K]; - const DType *lag_ = &lags[i * T]; - const int32_t *mark_ = &marks[i * T]; - DType *state_ = &out_state[i * K]; + const DType* mu_ = &mu[i * K]; + const DType* lag_ = &lags[i * T]; + const int32_t* mark_ = &marks[i * T]; + DType* state_ = &out_state[i * K]; // iterate over points in sequence 
for (index_t j = 0; j < valid_length[i]; ++j) { ci = mark_[j]; t += lag_[j]; - d = t - last_[ci]; + d = t - last_[ci]; ed = expf(-beta[ci] * d); - lda = mu_[ci] + alpha[ci] * beta[ci] * state_[ci] * ed; + lda = mu_[ci] + alpha[ci] * beta[ci] * state_[ci] * ed; comp = mu_[ci] * d + alpha[ci] * state_[ci] * (1 - ed); ll += logf(lda) - comp; @@ -159,9 +166,9 @@ struct hawkesll_forward { } }; -template +template struct hawkesll_forward_compensator { - template + template MSHADOW_XINLINE static void Map(int i, DType* rem_comp, DType* out_state, @@ -170,65 +177,59 @@ struct hawkesll_forward_compensator { const DType* beta, const DType* max_time, const int K, - const DType* last_buffer - ) { + const DType* last_buffer) { DType d, ed; int m = i % K; // mark int j = i / K; // particle // take care of the remaining compensators and state update - d = max_time[j] - last_buffer[i]; + d = max_time[j] - last_buffer[i]; ed = expf(-beta[m] * d); // return the remaining compensator - KERNEL_ASSIGN(rem_comp[i], req, - mu[i] * d + alpha[m] * out_state[i] * (1 - ed)) + KERNEL_ASSIGN(rem_comp[i], req, mu[i] * d + alpha[m] * out_state[i] * (1 - ed)) // update the state KERNEL_ASSIGN(out_state[i], req, ed * out_state[i]) } }; -template +template void HawkesLLForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 2U); const TBlob& out_loglike = outputs[hawkesll::kOutLL]; - const TBlob& out_state = outputs[hawkesll::kOutStates]; + const TBlob& out_state = outputs[hawkesll::kOutStates]; int K = inputs[hawkesll::kMu].shape_[1]; int N = inputs[hawkesll::kIATimes].shape_[0]; int T = inputs[hawkesll::kIATimes].shape_[1]; MSHADOW_TYPE_SWITCH(out_loglike.type_flag_, DType, { - Tensor temp_space = ctx.requested[hawkesll::kTempSpace] - .get_space_typed( - Shape2(2*N, K), - s); + Tensor temp_space = + ctx.requested[hawkesll::kTempSpace].get_space_typed(Shape2(2 * N, K), s); Tensor last_buffer = Tensor(&temp_space.dptr_[0], Shape2(N, K), s); Tensor rem_comp = - Tensor(&temp_space.dptr_[N*K], Shape2(N, K), s); + Tensor(&temp_space.dptr_[N * K], Shape2(N, K), s); - Tensor out_loglike_ts = - out_loglike.get_with_shape(Shape1(N), s); + Tensor out_loglike_ts = out_loglike.get_with_shape(Shape1(N), s); last_buffer = DType(0.0); - rem_comp = DType(0.0); + rem_comp = DType(0.0); - Tensor out_state_ts = - out_state.get_with_shape(Shape2(N, K), s); + Tensor out_state_ts = out_state.get_with_shape(Shape2(N, K), s); Tensor in_state_ts = inputs[hawkesll::kState].get_with_shape(Shape2(N, K), s); @@ -236,99 +237,94 @@ void HawkesLLForward(const nnvm::NodeAttrs& attrs, MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { Kernel, xpu>::Launch( - s, N, - out_loglike.dptr(), - out_state.dptr(), - inputs[hawkesll::kMu].dptr(), // mu - inputs[hawkesll::kAlpha].dptr(), // alpha - inputs[hawkesll::kBeta].dptr(), // beta - inputs[hawkesll::kState].dptr(), // states - inputs[hawkesll::kIATimes].dptr(), // interarrival times - inputs[hawkesll::kMarks].dptr(), // marks - inputs[hawkesll::kValidLength].dptr(), // valid_length - inputs[hawkesll::kMaxTime].dptr(), // max_time - K, - T, - last_buffer.dptr_); + s, + N, + out_loglike.dptr(), + out_state.dptr(), + inputs[hawkesll::kMu].dptr(), // mu 
+ inputs[hawkesll::kAlpha].dptr(), // alpha + inputs[hawkesll::kBeta].dptr(), // beta + inputs[hawkesll::kState].dptr(), // states + inputs[hawkesll::kIATimes].dptr(), // interarrival times + inputs[hawkesll::kMarks].dptr(), // marks + inputs[hawkesll::kValidLength].dptr(), // valid_length + inputs[hawkesll::kMaxTime].dptr(), // max_time + K, + T, + last_buffer.dptr_); }); // in parallel, we take care of the remaining compensators MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { Kernel, xpu>::Launch( - s, N * K, - rem_comp.dptr_, - out_state.dptr(), - inputs[hawkesll::kMu].dptr(), // mu - inputs[hawkesll::kAlpha].dptr(), // alpha - inputs[hawkesll::kBeta].dptr(), // beta - inputs[hawkesll::kMaxTime].dptr(), // max_time - K, - last_buffer.dptr_); + s, + N * K, + rem_comp.dptr_, + out_state.dptr(), + inputs[hawkesll::kMu].dptr(), // mu + inputs[hawkesll::kAlpha].dptr(), // alpha + inputs[hawkesll::kBeta].dptr(), // beta + inputs[hawkesll::kMaxTime].dptr(), // max_time + K, + last_buffer.dptr_); }); out_loglike_ts -= mshadow::expr::sumall_except_dim<0>(rem_comp); }) } -template +template struct hawkesll_backward { - template + template MSHADOW_XINLINE static void Map(int i, // indexes the sample (particle) DType* mu_gbfr, DType* alpha_gbfr, - DType* beta_gbfr, // (N, K) - const DType* mu, // (N, K) - const DType* alpha, // (K,) - const DType* beta, // (K,) - const DType* lags, // (N, T) - const int32_t* marks, // (N, T) + DType* beta_gbfr, // (N, K) + const DType* mu, // (N, K) + const DType* alpha, // (K,) + const DType* beta, // (K,) + const DType* lags, // (N, T) + const int32_t* marks, // (N, T) const DType* valid_length, // (N,) - const DType* max_time, // (N,) + const DType* max_time, // (N,) const int K, const int T, DType* last_buffer, DType* phi_buffer, - DType* phig_buffer - ) { + DType* phig_buffer) { int32_t ci; - int32_t part_ix_K = i*K, part_ix_T = i*T; + int32_t part_ix_K = i * K, part_ix_T = i * T; - DType t = 0, d, lda, ed; - DType* last_ = &last_buffer[part_ix_K]; - DType* state_ = &phi_buffer[part_ix_K]; + DType t = 0, d, lda, ed; + DType* last_ = &last_buffer[part_ix_K]; + DType* state_ = &phi_buffer[part_ix_K]; DType* dstate_ = &phig_buffer[part_ix_K]; - DType* mug_ = &mu_gbfr[part_ix_K]; + DType* mug_ = &mu_gbfr[part_ix_K]; DType* alphag_ = &alpha_gbfr[part_ix_K]; - DType* betag_ = &beta_gbfr[part_ix_K]; + DType* betag_ = &beta_gbfr[part_ix_K]; - const DType* lag_ = &lags[part_ix_T]; + const DType* lag_ = &lags[part_ix_T]; const int32_t* mark_ = &marks[part_ix_T]; // iterate over points - for (index_t j = 0; j < valid_length[i]; ++j){ + for (index_t j = 0; j < valid_length[i]; ++j) { ci = mark_[j]; t += lag_[j]; - d = t - last_[ci]; + d = t - last_[ci]; ed = expf(-beta[ci] * d); lda = mu[part_ix_K + ci] + alpha[ci] * beta[ci] * state_[ci] * ed; KERNEL_ASSIGN(mug_[ci], req, mug_[ci] + (1 / lda) - d) - KERNEL_ASSIGN(alphag_[ci], req, - ( - alphag_[ci] - + (beta[ci] * state_[ci] * ed) / lda - - state_[ci] * (1 - ed) - ) - ) - KERNEL_ASSIGN(betag_[ci], req, - betag_[ci] - + alpha[ci] * ed - * (state_[ci] * (1 - beta[ci] * d) + beta[ci] * dstate_[ci]) - / lda - - alpha[ci] - * (dstate_[ci] * (1 - ed) + state_[ci] * d * ed) - ) + KERNEL_ASSIGN(alphag_[ci], + req, + (alphag_[ci] + (beta[ci] * state_[ci] * ed) / lda - state_[ci] * (1 - ed))) + KERNEL_ASSIGN( + betag_[ci], + req, + betag_[ci] + + alpha[ci] * ed * (state_[ci] * (1 - beta[ci] * d) + beta[ci] * dstate_[ci]) / lda - + alpha[ci] * (dstate_[ci] * (1 - ed) + state_[ci] * d * ed)) KERNEL_ASSIGN(dstate_[ci], req, ed * (-d * 
state_[ci] + dstate_[ci])) KERNEL_ASSIGN(state_[ci], req, 1 + (state_[ci] * ed)) @@ -338,47 +334,39 @@ struct hawkesll_backward { } }; - -template +template struct hawkesll_backward_compensator { - template + template MSHADOW_XINLINE static void Map(int i, DType* mu_gbfr, DType* alpha_gbfr, - DType* beta_gbfr, // (N, K) - DType* out_grad, // read this (N,) - const DType* mu, // (N, K) - const DType* alpha, // (K,) - const DType* beta, // (K,) + DType* beta_gbfr, // (N, K) + DType* out_grad, // read this (N,) + const DType* mu, // (N, K) + const DType* alpha, // (K,) + const DType* beta, // (K,) const DType* max_time, // (N,) const int K, DType* last_buffer, DType* phi_buffer, - DType* phig_buffer - ) { + DType* phig_buffer) { DType d, ed; - int m = i % K; // mark - int j = i / K; // particle - int32_t part_ix_K = j*K; - DType* mug_ = &mu_gbfr[part_ix_K]; - DType* alphag_ = &alpha_gbfr[part_ix_K]; - DType* betag_ = &beta_gbfr[part_ix_K]; + int m = i % K; // mark + int j = i / K; // particle + int32_t part_ix_K = j * K; + DType* mug_ = &mu_gbfr[part_ix_K]; + DType* alphag_ = &alpha_gbfr[part_ix_K]; + DType* betag_ = &beta_gbfr[part_ix_K]; // take care of the remaining compensators and state update - d = max_time[j] - last_buffer[i]; + d = max_time[j] - last_buffer[i]; ed = expf(-beta[m] * d); // take care of the gradients of the remaining compensator KERNEL_ASSIGN(mug_[m], req, mug_[m] - d) - KERNEL_ASSIGN(alphag_[m], req, - alphag_[m] - phi_buffer[i] * (1 - ed) - ) - KERNEL_ASSIGN(betag_[m], req, - betag_[m] - alpha[m] * ( - phig_buffer[i] * (1 - ed) - + phi_buffer[i] * d * ed - ) - ) + KERNEL_ASSIGN(alphag_[m], req, alphag_[m] - phi_buffer[i] * (1 - ed)) + KERNEL_ASSIGN( + betag_[m], req, betag_[m] - alpha[m] * (phig_buffer[i] * (1 - ed) + phi_buffer[i] * d * ed)) // // correct the gradients with respect to output gradients KERNEL_ASSIGN(mug_[m], req, out_grad[j] * mug_[m]) @@ -387,23 +375,23 @@ struct hawkesll_backward_compensator { } }; -template +template void HawkesLLBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { CHECK_EQ(inputs.size(), 10U); CHECK_EQ(outputs.size(), 8U); CHECK_EQ(req.size(), 8U); - mshadow::Stream *s = ctx.get_stream(); + mshadow::Stream* s = ctx.get_stream(); int K = inputs[hawkesll::kGradMu].shape_[1]; // mu data int N = inputs[hawkesll::kGradIATimes].shape_[0]; int T = inputs[hawkesll::kGradIATimes].shape_[1]; - CHECK_EQ(inputs[hawkesll::kOutGradLL].shape_[0], N); // grad of out 0 (LL) + CHECK_EQ(inputs[hawkesll::kOutGradLL].shape_[0], N); // grad of out 0 (LL) CHECK_EQ(inputs[hawkesll::kOutGradStates].shape_[0], N); // grad out 1-states CHECK_EQ(inputs[hawkesll::kOutGradStates].shape_[1], K); @@ -420,29 +408,24 @@ void HawkesLLBackward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, { // allocate gradient buffers Tensor bfr = - ctx.requested[hawkesll::kTempSpace] - .get_space_typed(Shape2(6*N, K), s); + ctx.requested[hawkesll::kTempSpace].get_space_typed(Shape2(6 * N, K), s); - Tensor alpha_gbfr = - Tensor(&bfr.dptr_[N*K], Shape2(N, K), s); - Tensor beta_gbfr = - Tensor(&bfr.dptr_[2*N*K], Shape2(N, K), s); + Tensor alpha_gbfr = Tensor(&bfr.dptr_[N * K], Shape2(N, K), s); + Tensor beta_gbfr = Tensor(&bfr.dptr_[2 * N * K], Shape2(N, K), s); Tensor last_buffer = - Tensor(&bfr.dptr_[3*N*K], Shape2(N, K), s); + 
Tensor(&bfr.dptr_[3 * N * K], Shape2(N, K), s); Tensor phig_buffer = - Tensor(&bfr.dptr_[4*N*K], Shape2(N, K), s); + Tensor(&bfr.dptr_[4 * N * K], Shape2(N, K), s); Tensor phi_buffer = - Tensor(&bfr.dptr_[5*N*K], Shape2(N, K), s); + Tensor(&bfr.dptr_[5 * N * K], Shape2(N, K), s); - alpha_gbfr = DType(0.0); - beta_gbfr = DType(0.0); + alpha_gbfr = DType(0.0); + beta_gbfr = DType(0.0); last_buffer = DType(0.0); phig_buffer = DType(0.0); - mshadow::Copy(phi_buffer, - inputs[hawkesll::kGradState] - .get_with_shape(Shape2(N, K), s), - s); + mshadow::Copy( + phi_buffer, inputs[hawkesll::kGradState].get_with_shape(Shape2(N, K), s), s); // get the gradient to be output Tensor in_grad_mu = @@ -456,47 +439,47 @@ void HawkesLLBackward(const nnvm::NodeAttrs& attrs, MXNET_ASSIGN_REQ_SWITCH(req[hawkesll::kMu], req_type, { Kernel, xpu>::Launch( - s, - N, - in_grad_mu.dptr_, alpha_gbfr.dptr_, beta_gbfr.dptr_, // gradients - inputs[hawkesll::kGradMu].dptr(), // mu data - inputs[hawkesll::kGradAlpha].dptr(), // alpha data - inputs[hawkesll::kGradBeta].dptr(), // beta data - inputs[hawkesll::kGradIATimes].dptr(), // lags data - inputs[hawkesll::kGradMarks].dptr(), // marks data - inputs[hawkesll::kGradValidLength].dptr(), // valid_length data - inputs[hawkesll::kGradMaxTime].dptr(), // max_time data - K, - T, - last_buffer.dptr_, // buffer to keep timestamp of last item - phi_buffer.dptr_, // "states" - phig_buffer.dptr_); // derivatives of "states" + s, + N, + in_grad_mu.dptr_, + alpha_gbfr.dptr_, + beta_gbfr.dptr_, // gradients + inputs[hawkesll::kGradMu].dptr(), // mu data + inputs[hawkesll::kGradAlpha].dptr(), // alpha data + inputs[hawkesll::kGradBeta].dptr(), // beta data + inputs[hawkesll::kGradIATimes].dptr(), // lags data + inputs[hawkesll::kGradMarks].dptr(), // marks data + inputs[hawkesll::kGradValidLength].dptr(), // valid_length data + inputs[hawkesll::kGradMaxTime].dptr(), // max_time data + K, + T, + last_buffer.dptr_, // buffer to keep timestamp of last item + phi_buffer.dptr_, // "states" + phig_buffer.dptr_); // derivatives of "states" }); MXNET_ASSIGN_REQ_SWITCH(req[hawkesll::kMu], req_type, { Kernel, xpu>::Launch( - s, - N * K, - in_grad_mu.dptr_, alpha_gbfr.dptr_, beta_gbfr.dptr_, // gradients - out_grad.dptr(), - inputs[hawkesll::kGradMu].dptr(), // mu data - inputs[hawkesll::kGradAlpha].dptr(), // alpha data - inputs[hawkesll::kGradBeta].dptr(), // beta data - inputs[hawkesll::kGradMaxTime].dptr(), // max_time data - K, - last_buffer.dptr_, // buffer to keep timestamp of last item - phi_buffer.dptr_, // "states" - phig_buffer.dptr_); // derivatives of "states" + s, + N * K, + in_grad_mu.dptr_, + alpha_gbfr.dptr_, + beta_gbfr.dptr_, // gradients + out_grad.dptr(), + inputs[hawkesll::kGradMu].dptr(), // mu data + inputs[hawkesll::kGradAlpha].dptr(), // alpha data + inputs[hawkesll::kGradBeta].dptr(), // beta data + inputs[hawkesll::kGradMaxTime].dptr(), // max_time data + K, + last_buffer.dptr_, // buffer to keep timestamp of last item + phi_buffer.dptr_, // "states" + phig_buffer.dptr_); // derivatives of "states" }); // reduce the gradients - Assign(in_grad_alpha, req[hawkesll::kAlpha], - mshadow::expr::sumall_except_dim<1>(alpha_gbfr) - ) + Assign(in_grad_alpha, req[hawkesll::kAlpha], mshadow::expr::sumall_except_dim<1>(alpha_gbfr)) - Assign(in_grad_beta, req[hawkesll::kBeta], - mshadow::expr::sumall_except_dim<1>(beta_gbfr) - ) + Assign(in_grad_beta, req[hawkesll::kBeta], mshadow::expr::sumall_except_dim<1>(beta_gbfr)) }) } diff --git a/src/operator/contrib/hawkes_ll.cc 
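// Single-sequence CPU sketch of the recursion implemented by the hawkesll_forward and
// hawkesll_forward_compensator kernels above: each event contributes log lambda_m(t)
// minus the compensator accumulated since the previous event of the same mark, and the
// remaining compensators up to max_time are subtracted at the end.
// HawkesLogLikelihoodRef and its flat std::vector arguments are hypothetical; only the
// arithmetic mirrors the kernels.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

double HawkesLogLikelihoodRef(const std::vector<double>& mu,      // (K,)
                              const std::vector<double>& alpha,   // (K,)
                              const std::vector<double>& beta,    // (K,)
                              const std::vector<double>& lags,    // (T,)
                              const std::vector<int32_t>& marks,  // (T,)
                              double max_time) {
  const std::size_t K = mu.size();
  std::vector<double> state(K, 0.0), last(K, 0.0);
  double t = 0.0, ll = 0.0;
  for (std::size_t j = 0; j < lags.size(); ++j) {
    const int32_t c = marks[j];
    t += lags[j];
    const double d   = t - last[c];
    const double ed  = std::exp(-beta[c] * d);
    const double lda = mu[c] + alpha[c] * beta[c] * state[c] * ed;    // intensity at the event
    const double cmp = mu[c] * d + alpha[c] * state[c] * (1.0 - ed);  // compensator since last event of mark c
    ll += std::log(lda) - cmp;
    state[c] = 1.0 + state[c] * ed;  // Hawkes state update
    last[c]  = t;
  }
  for (std::size_t m = 0; m < K; ++m) {  // remaining compensators up to max_time
    const double d  = max_time - last[m];
    const double ed = std::exp(-beta[m] * d);
    ll -= mu[m] * d + alpha[m] * state[m] * (1.0 - ed);
  }
  return ll;
}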
b/src/operator/contrib/hawkes_ll.cc index 1e2fff5c9871..34d4c9e4f501 100644 --- a/src/operator/contrib/hawkes_ll.cc +++ b/src/operator/contrib/hawkes_ll.cc @@ -84,58 +84,40 @@ Example:: )code" ADD_FILELINE) .set_num_inputs(8) .set_num_outputs(2) - .set_attr("FListInputNames", + .set_attr( + "FListInputNames", [](const NodeAttrs& attrs) { - return std::vector{ - "lda", "alpha", "beta", "state", "lags", - "marks", "valid_length", "max_time" - }; - }) + return std::vector{ + "lda", "alpha", "beta", "state", "lags", "marks", "valid_length", "max_time"}; + }) .set_attr("FListOutputNames", - [](const NodeAttrs& attrs) { - return std::vector{"output", "out_state"}; - }) + [](const NodeAttrs& attrs) { + return std::vector{"output", "out_state"}; + }) .set_attr("FInferShape", HawkesLLOpShape) .set_attr("FInferType", HawkesLLOpType) .set_attr("FCompute", HawkesLLForward) - .set_attr( - "FGradient", ElemwiseGradUseIn{"_contrib_backward_hawkesll"} - ) - .set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::Type::kTempSpace}; - }) + .set_attr("FGradient", ElemwiseGradUseIn{"_contrib_backward_hawkesll"}) + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ + ResourceRequest::Type::kTempSpace}; + }) .set_attr("THasDeterministicOutput", true) - .add_argument( - "lda", "NDArray-or-Symbol", - "Shape (N, K) The intensity for each of the K processes, for each sample" - ) - .add_argument( - "alpha", "NDArray-or-Symbol", - "Shape (K,) The infectivity factor (branching ratio) for each process" - ) - .add_argument( - "beta", "NDArray-or-Symbol", - "Shape (K,) The decay parameter for each process" - ) - .add_argument( - "state", "NDArray-or-Symbol", - "Shape (N, K) the Hawkes state for each process" - ) - .add_argument( - "lags", "NDArray-or-Symbol", - "Shape (N, T) the interarrival times" - ) - .add_argument( - "marks", "NDArray-or-Symbol", - "Shape (N, T) the marks (process ids)" - ) - .add_argument( - "valid_length", "NDArray-or-Symbol", - "The number of valid points in the process" - ) - .add_argument( - "max_time", "NDArray-or-Symbol", - "the length of the interval where the processes were sampled"); + .add_argument("lda", + "NDArray-or-Symbol", + "Shape (N, K) The intensity for each of the K processes, for each sample") + .add_argument("alpha", + "NDArray-or-Symbol", + "Shape (K,) The infectivity factor (branching ratio) for each process") + .add_argument("beta", "NDArray-or-Symbol", "Shape (K,) The decay parameter for each process") + .add_argument("state", "NDArray-or-Symbol", "Shape (N, K) the Hawkes state for each process") + .add_argument("lags", "NDArray-or-Symbol", "Shape (N, T) the interarrival times") + .add_argument("marks", "NDArray-or-Symbol", "Shape (N, T) the marks (process ids)") + .add_argument("valid_length", "NDArray-or-Symbol", "The number of valid points in the process") + .add_argument("max_time", + "NDArray-or-Symbol", + "the length of the interval where the processes were sampled"); NNVM_REGISTER_OP(_contrib_backward_hawkesll) .set_num_inputs(10) @@ -143,7 +125,7 @@ NNVM_REGISTER_OP(_contrib_backward_hawkesll) .set_attr("TIsBackward", true) .set_attr("FCompute", HawkesLLBackward) .set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::Type::kTempSpace}; + return std::vector{ResourceRequest::Type::kTempSpace}; }); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/hawkes_ll.cu b/src/operator/contrib/hawkes_ll.cu old mode 100755 new mode 100644 index 
d35d7d0b0c08..9e1e6af655f9 --- a/src/operator/contrib/hawkes_ll.cu +++ b/src/operator/contrib/hawkes_ll.cu @@ -28,11 +28,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_contrib_hawkesll) -.set_attr("FCompute", HawkesLLForward); +NNVM_REGISTER_OP(_contrib_hawkesll).set_attr("FCompute", HawkesLLForward); NNVM_REGISTER_OP(_contrib_backward_hawkesll) -.set_attr("FCompute", HawkesLLBackward); + .set_attr("FCompute", HawkesLLBackward); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h index e280d7661b7c..c47f976cd923 100644 --- a/src/operator/contrib/index_array-inl.h +++ b/src/operator/contrib/index_array-inl.h @@ -29,17 +29,14 @@ namespace mxnet { namespace op { namespace index_array_enum { -enum IndexArrayOpInputs {kIn}; -enum IndexArrayOpOutputs {kOut}; -enum IndexArrayOpResource {kTempSpace}; +enum IndexArrayOpInputs { kIn }; +enum IndexArrayOpOutputs { kOut }; +enum IndexArrayOpResource { kTempSpace }; } // namespace index_array_enum -template +template struct IndexArrayKernel { - MSHADOW_XINLINE static void Map(int i, - int64_t* out_data, - const int n, - const int64_t* workspace) { + MSHADOW_XINLINE static void Map(int i, int64_t* out_data, const int n, const int64_t* workspace) { for (ptrdiff_t j = 0; j < n; j++) { int64_t upper = workspace[ptrdiff_t(2) * j]; int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)]; @@ -48,12 +45,9 @@ struct IndexArrayKernel { } }; -template +template struct IndexArrayDefaultKernel { - MSHADOW_XINLINE static void Map(int i, - int64_t* out_data, - const int ndim, - const dim_t* shape) { + MSHADOW_XINLINE static void Map(int i, int64_t* out_data, const int ndim, const dim_t* shape) { int64_t index = i; for (ptrdiff_t j = ndim - 1; j >= 0; j--) { KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]); @@ -62,7 +56,7 @@ struct IndexArrayDefaultKernel { } }; -inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape) { +inline std::vector IndexArrayComputeIndexProducts(const TShape& inshape) { const int ndim = inshape.ndim(); std::vector index_products(static_cast(ndim + 1)); @@ -76,15 +70,15 @@ inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape return index_products; } -inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple &axes, - const std::vector &index_products, +inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple& axes, + const std::vector& index_products, int64_t* workspace, const int ndim) { for (int i = 0; i < axes.ndim(); i++) { // Make sure that the axis is between 0 and ndim. const int axis = ((axes[i] % ndim) + ndim) % ndim; - workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis]; + workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis]; workspace[ptrdiff_t(2) * ptrdiff_t(i) + ptrdiff_t(1)] = index_products[axis + 1]; } } @@ -92,8 +86,9 @@ inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple &axes, struct IndexArrayParam : public dmlc::Parameter { dmlc::optional> axes; DMLC_DECLARE_PARAMETER(IndexArrayParam) { - DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional>()) - .describe("The axes to include in the index array. Supports negative values."); + DMLC_DECLARE_FIELD(axes) + .set_default(dmlc::optional>()) + .describe("The axes to include in the index array. 
Supports negative values."); } }; // struct IndexArrayParam diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index ef4f030863f2..82cd468f81e3 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -19,34 +19,33 @@ #include #include "./index_array-inl.h" - namespace mxnet { namespace op { -void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void IndexArrayForwardCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const TBlob& in_data = inputs[0]; + const TBlob& in_data = inputs[0]; const TBlob& out_data = outputs[0]; const IndexArrayParam& param = nnvm::get(attrs.parsed); const TShape inshape = in_data.shape_; - const int ndim = inshape.ndim(); + const int ndim = inshape.ndim(); - Stream *stream = ctx.get_stream(); + Stream* stream = ctx.get_stream(); using namespace mxnet_op; if (param.axes.has_value()) { const mxnet::Tuple& axes = param.axes.value(); - const int naxes = axes.ndim(); + const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); @@ -56,13 +55,13 @@ void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs, IndexArrayBuildSelectedAxesWorkspace(axes, index_products, workspace.dptr_, ndim); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, cpu>::Launch(stream, in_data.Size(), - out_data.dptr(), naxes, workspace.dptr_); + Kernel, cpu>::Launch( + stream, in_data.Size(), out_data.dptr(), naxes, workspace.dptr_); }); } else { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, cpu>::Launch(stream, in_data.Size(), - out_data.dptr(), ndim, inshape.data()); + Kernel, cpu>::Launch( + stream, in_data.Size(), out_data.dptr(), ndim, inshape.data()); }); } } @@ -70,7 +69,7 @@ void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs, DMLC_REGISTER_PARAMETER(IndexArrayParam); NNVM_REGISTER_OP(_contrib_index_array) -.describe(R"code(Returns an array of indexes of the input array. + .describe(R"code(Returns an array of indexes of the input array. 
For an input array with shape :math:`(d_1, d_2, ..., d_n)`, `index_array` returns a :math:`(d_1, d_2, ..., d_n, n)` array `idx`, where @@ -116,58 +115,62 @@ Examples:: [1 2]]]] )code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs &attrs) { - return std::vector{ "data" }; - }) -.set_attr("FListOutputNames", - [](const NodeAttrs &attrs) { - return std::vector{ "output" }; + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) + .set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; + }) + .set_attr_parser(ParamParser) + .set_attr("FInferShape", + [](const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const IndexArrayParam& param = + nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 1U); + CHECK_EQ(out_shape->size(), 1U); + const mxnet::TShape& inshape = + (*in_shape)[index_array_enum::kIn]; + if (!mxnet::ndim_is_known(inshape)) + return false; + + mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0); + + for (int i = 0; i < inshape.ndim(); i++) { + oshape[i] = inshape[i]; + } + if (param.axes.has_value()) { + oshape[inshape.ndim()] = param.axes.value().ndim(); + } else { + oshape[inshape.ndim()] = inshape.ndim(); + } + + SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); + return shape_is_known(oshape); }) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", [](const nnvm::NodeAttrs &attrs, - mxnet::ShapeVector *in_shape, - mxnet::ShapeVector *out_shape) { - const IndexArrayParam ¶m = nnvm::get(attrs.parsed); - CHECK_EQ(in_shape->size(), 1U); - CHECK_EQ(out_shape->size(), 1U); - const mxnet::TShape &inshape = (*in_shape)[index_array_enum::kIn]; - if (!mxnet::ndim_is_known(inshape)) return false; - - mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0); - - for (int i = 0; i < inshape.ndim(); i++) { - oshape[i] = inshape[i]; - } - if (param.axes.has_value()) { - oshape[inshape.ndim()] = param.axes.value().ndim(); - } else { - oshape[inshape.ndim()] = inshape.ndim(); - } - - SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); - return shape_is_known(oshape); -}) -.set_attr("FInferType", [](const nnvm::NodeAttrs &attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); - return out_attrs->at(0) != -1; -}) -.set_attr("FCompute", IndexArrayForwardCPU) -.set_attr("FGradient", MakeZeroGradNodes) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kTempSpace}; -}) -.set_attr("THasDeterministicOutput", true) -.add_argument("data", "NDArray-or-Symbol", "Input data") -.add_arguments(IndexArrayParam::__FIELDS__()); - + .set_attr("FInferType", + [](const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); + return out_attrs->at(0) != -1; + }) + .set_attr("FCompute", IndexArrayForwardCPU) + .set_attr("FGradient", MakeZeroGradNodes) + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("THasDeterministicOutput", true) + .add_argument("data", "NDArray-or-Symbol", "Input data") + .add_arguments(IndexArrayParam::__FIELDS__()); } // namespace op } // namespace mxnet - diff --git 
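// CPU reference for the default (axes unset) branch of _contrib_index_array registered
// above, mirroring IndexArrayDefaultKernel: every element of an input with shape
// (d_1, ..., d_n) maps to its own n-dimensional index.  IndexArrayRef and the flat
// std::vector output are for illustration only.
#include <cstdint>
#include <vector>

std::vector<int64_t> IndexArrayRef(const std::vector<int64_t>& shape) {
  const int ndim = static_cast<int>(shape.size());
  int64_t size   = 1;
  for (int64_t d : shape)
    size *= d;
  std::vector<int64_t> out(size * ndim);
  for (int64_t i = 0; i < size; ++i) {
    int64_t index = i;
    for (int j = ndim - 1; j >= 0; --j) {  // peel off trailing dimensions
      out[i * ndim + j] = index % shape[j];
      index /= shape[j];
    }
  }
  return out;  // logically shaped (d_1, ..., d_n, n)
}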
a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index dae61ca71b6d..482cbf6b8150 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -24,31 +24,31 @@ namespace op { using namespace mshadow::cuda; -void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void IndexArrayForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); - const TBlob& in_data = inputs[0]; + const TBlob& in_data = inputs[0]; const TBlob& out_data = outputs[0]; const IndexArrayParam& param = nnvm::get(attrs.parsed); const TShape inshape = in_data.shape_; - const int ndim = inshape.ndim(); + const int ndim = inshape.ndim(); - Stream *s = ctx.get_stream(); + Stream* s = ctx.get_stream(); cudaStream_t stream = Stream::GetStream(s); using namespace mxnet_op; if (param.axes.has_value()) { const mxnet::Tuple& axes = param.axes.value(); - const int naxes = axes.ndim(); + const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); @@ -58,29 +58,31 @@ void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs, Tensor workspace = ctx.requested[0].get_space_typed(Shape1(2 * naxes), s); - CUDA_CALL(cudaMemcpyAsync(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes), - cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaMemcpyAsync(workspace.dptr_, + cpu_workspace.data(), + sizeof(int64_t) * (2 * naxes), + cudaMemcpyHostToDevice, + stream)); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, gpu>::Launch(s, in_data.Size(), - out_data.dptr(), naxes, workspace.dptr_); + Kernel, gpu>::Launch( + s, in_data.Size(), out_data.dptr(), naxes, workspace.dptr_); }); } else { Tensor workspace = ctx.requested[0].get_space_typed(Shape1(ndim), s); - CUDA_CALL(cudaMemcpyAsync(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim, - cudaMemcpyHostToDevice, stream)); + CUDA_CALL(cudaMemcpyAsync( + workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim, cudaMemcpyHostToDevice, stream)); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - Kernel, gpu>::Launch(s, in_data.Size(), - out_data.dptr(), ndim, workspace.dptr_); + Kernel, gpu>::Launch( + s, in_data.Size(), out_data.dptr(), ndim, workspace.dptr_); }); } } -NNVM_REGISTER_OP(_contrib_index_array) -.set_attr("FCompute", IndexArrayForwardGPU); +NNVM_REGISTER_OP(_contrib_index_array).set_attr("FCompute", IndexArrayForwardGPU); } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/index_copy-inl.h b/src/operator/contrib/index_copy-inl.h index 35bfcd0e77b6..21418b324be7 100644 --- a/src/operator/contrib/index_copy-inl.h +++ b/src/operator/contrib/index_copy-inl.h @@ -37,14 +37,14 @@ namespace mxnet { namespace op { -template +template void IndexCopyForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs); -template +template void IndexCopyBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -52,8 +52,8 @@ void IndexCopyBackward(const nnvm::NodeAttrs& attrs, const std::vector& outputs); inline bool IndexCopyShape(const nnvm::NodeAttrs& attrs, - mxnet::ShapeVector *in_attrs, - mxnet::ShapeVector *out_attrs) { + mxnet::ShapeVector* in_attrs, + 
mxnet::ShapeVector* out_attrs) { // inputs[0]: original tensor // inputs[1]: index vector // inputs[2]: copied tensor diff --git a/src/operator/contrib/index_copy.cc b/src/operator/contrib/index_copy.cc index 2543f2deca6d..79cb3005f538 100644 --- a/src/operator/contrib/index_copy.cc +++ b/src/operator/contrib/index_copy.cc @@ -27,13 +27,13 @@ namespace mxnet { namespace op { struct index_copy_fwd_cpu { - template + template static void Map(index_t i, const DType* new_tensor, const IType* idx, DType* out_tensor, int dim_size) { - DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; + DType* out_ptr = out_tensor + static_cast(idx[i]) * dim_size; const DType* new_ptr = new_tensor + i * dim_size; #pragma GCC diagnostic push #if __GNUC__ >= 8 @@ -44,7 +44,7 @@ struct index_copy_fwd_cpu { } }; -template<> +template <> void IndexCopyForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -56,27 +56,31 @@ void IndexCopyForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); CHECK(req[0] != kAddTo); - if (req[0] == kNullOp) return; - mshadow::Stream *s = ctx.get_stream(); - const TBlob& out = outputs[0]; + if (req[0] == kNullOp) + return; + mshadow::Stream* s = ctx.get_stream(); + const TBlob& out = outputs[0]; const TBlob& original_tensor = inputs[0]; - const TBlob& idx_vector = inputs[1]; - const TBlob& copied_tensor = inputs[2]; - int dim_size = inputs[2].Size() / inputs[1].Size(); + const TBlob& idx_vector = inputs[1]; + const TBlob& copied_tensor = inputs[2]; + int dim_size = inputs[2].Size() / inputs[1].Size(); // copy original tensor to output copy(s, out, original_tensor); // index copy MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { MSHADOW_TYPE_SWITCH(idx_vector.type_flag_, IType, { - Kernel::Launch( - s, idx_vector.Size(), copied_tensor.dptr(), - idx_vector.dptr(), out.dptr(), dim_size); + Kernel::Launch(s, + idx_vector.Size(), + copied_tensor.dptr(), + idx_vector.dptr(), + out.dptr(), + dim_size); }); }); } struct index_copy_bwd_cpu { - template + template static void Map(int i, const DType* out_tensor_grad, DType* orig_tensor_grad, @@ -86,9 +90,9 @@ struct index_copy_bwd_cpu { int idx_size, OpReqType orig_req, OpReqType new_req) { - const int index = idx[i]; - DType* new_ptr = new_tensor_grad + i * dim_size; - DType* orig_ptr = orig_tensor_grad + index * dim_size; + const int index = idx[i]; + DType* new_ptr = new_tensor_grad + i * dim_size; + DType* orig_ptr = orig_tensor_grad + index * dim_size; const DType* src_ptr = out_tensor_grad + index * dim_size; for (int iter = 0; iter < dim_size; ++iter) { KERNEL_ASSIGN(new_ptr[iter], new_req, src_ptr[iter]); @@ -110,7 +114,7 @@ struct index_copy_bwd_cpu { } }; -template<> +template <> void IndexCopyBackward(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -120,15 +124,15 @@ void IndexCopyBackward(const nnvm::NodeAttrs& attrs, using namespace mxnet_op; CHECK_EQ(inputs.size(), 4U); CHECK_EQ(outputs.size(), 3U); - Stream *s = ctx.get_stream(); - const TBlob& out_grad = inputs[0]; - const TBlob& index = inputs[2]; + Stream* s = ctx.get_stream(); + const TBlob& out_grad = inputs[0]; + const TBlob& index = inputs[2]; const TBlob& in_grad_1 = outputs[0]; const TBlob& in_grad_2 = outputs[2]; - int dim_size = inputs[3].Size() / inputs[2].Size(); - int index_size = inputs[2].Size(); - OpReqType orig_req = req[0]; - OpReqType new_req = req[2]; + int dim_size = inputs[3].Size() / inputs[2].Size(); + int index_size = inputs[2].Size(); + 
OpReqType orig_req = req[0];
+  OpReqType new_req  = req[2];
   // index_copy_backward
   MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
     MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
@@ -141,20 +145,29 @@ void IndexCopyBackward<cpu>(const nnvm::NodeAttrs& attrs,
           break;
         case kAddTo:
           Kernel<op_with_req<mshadow_op::plus, kWriteTo>, cpu>::Launch(
-            s, out_grad.Size(), in_grad_1.dptr<DType>(),
-            out_grad.dptr<DType>(), in_grad_1.dptr<DType>());
+              s,
+              out_grad.Size(),
+              in_grad_1.dptr<DType>(),
+              out_grad.dptr<DType>(),
+              in_grad_1.dptr<DType>());
       }
-      Kernel::Launch(
-          s, index_size, out_grad.dptr<DType>(),
-          in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
-          index.dptr<IType>(), dim_size, index_size, orig_req, new_req);
+      Kernel::Launch(s,
+                     index_size,
+                     out_grad.dptr<DType>(),
+                     in_grad_1.dptr<DType>(),
+                     in_grad_2.dptr<DType>(),
+                     index.dptr<IType>(),
+                     dim_size,
+                     index_size,
+                     orig_req,
+                     new_req);
     });
   });
 }
 
 static bool IndexCopyType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int> *in_attrs,
-                          std::vector<int> *out_attrs) {
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), 1U);
   TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
@@ -163,7 +176,7 @@ static bool IndexCopyType(const nnvm::NodeAttrs& attrs,
 }
 
 NNVM_REGISTER_OP(_contrib_index_copy)
-.describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor`.
+    .describe(R"code(Copies the elements of a `new_tensor` into the `old_tensor`.
 
 This operator copies the elements by selecting the indices in the order given in `index`.
 The output will be a new tensor containing the rest elements of old tensor and
@@ -191,25 +204,26 @@ Examples::
 
 
 )code" ADD_FILELINE)
-.set_num_inputs(3)
-.set_num_outputs(1)
-.set_attr<mxnet::FInferShape>("FInferShape", IndexCopyShape)
-.set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
-.set_attr<FCompute>("FCompute", IndexCopyForward<cpu>)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"old_tensor", "index_vector", "new_tensor"};
-  })
-.add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
-.add_argument("index_vector", "NDArray-or-Symbol", "Index vector")
-.add_argument("new_tensor", "NDArray-or-Symbol", "New tensor to be copied");
+    .set_num_inputs(3)
+    .set_num_outputs(1)
+    .set_attr<mxnet::FInferShape>("FInferShape", IndexCopyShape)
+    .set_attr<nnvm::FInferType>("FInferType", IndexCopyType)
+    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_contrib_backward_index_copy"})
+    .set_attr<FCompute>("FCompute", IndexCopyForward<cpu>)
+    .set_attr<nnvm::FListInputNames>(
+        "FListInputNames",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::string>{"old_tensor", "index_vector", "new_tensor"};
+        })
+    .add_argument("old_tensor", "NDArray-or-Symbol", "Old tensor")
+    .add_argument("index_vector", "NDArray-or-Symbol", "Index vector")
+    .add_argument("new_tensor", "NDArray-or-Symbol", "New tensor to be copied");
 
 NNVM_REGISTER_OP(_contrib_backward_index_copy)
-.set_num_inputs(4)
-.set_num_outputs(3)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<FCompute>("FCompute", IndexCopyBackward<cpu>);
+    .set_num_inputs(4)
+    .set_num_outputs(3)
+    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
+    .set_attr<FCompute>("FCompute", IndexCopyBackward<cpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/index_copy.cu b/src/operator/contrib/index_copy.cu
index 53f2600aba06..131c389a8cf5 100644
--- a/src/operator/contrib/index_copy.cu
+++ b/src/operator/contrib/index_copy.cu
@@ -27,18 +27,18 @@
 namespace mxnet {
 namespace op {
 struct index_copy_fwd_gpu {
-  template<typename DType, typename IType>
+  template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i,
                                   const DType* new_tensor,
                                   const IType* idx,
                                   DType* out_tensor,
                                   int dim_size) {
-    int index = static_cast<int>(idx[i / dim_size]);
+    int index                                   = static_cast<int>(idx[i / dim_size]);
     out_tensor[index * dim_size + i % dim_size] = new_tensor[i];
   }
 };
 
-template<>
+template <>
 void IndexCopyForward<gpu>(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<TBlob>& inputs,
@@ -50,27 +50,31 @@ void IndexCopyForward<gpu>(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
   CHECK(req[0] != kAddTo);
-  if (req[0] == kNullOp) return;
-  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
-  const TBlob& out = outputs[0];
+  if (req[0] == kNullOp)
+    return;
+  mshadow::Stream<gpu>* s      = ctx.get_stream<gpu>();
+  const TBlob& out             = outputs[0];
   const TBlob& original_tensor = inputs[0];
-  const TBlob& idx_vector = inputs[1];
-  const TBlob& copied_tensor = inputs[2];
-  int dim_size = inputs[2].Size() / inputs[1].Size();
+  const TBlob& idx_vector      = inputs[1];
+  const TBlob& copied_tensor   = inputs[2];
+  int dim_size                 = inputs[2].Size() / inputs[1].Size();
   // copy original tensor to output
   copy(s, out, original_tensor);
   // index copy
   MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
     MSHADOW_TYPE_SWITCH(idx_vector.type_flag_, IType, {
-      Kernel<index_copy_fwd_gpu, gpu>::Launch(
-          s, copied_tensor.Size(), copied_tensor.dptr<DType>(),
-          idx_vector.dptr<IType>(), out.dptr<DType>(), dim_size);
+      Kernel<index_copy_fwd_gpu, gpu>::Launch(s,
+                                              copied_tensor.Size(),
+                                              copied_tensor.dptr<DType>(),
+                                              idx_vector.dptr<IType>(),
+                                              out.dptr<DType>(),
+                                              dim_size);
     });
   });
 }
 
 struct index_copy_bwd_gpu {
-  template<typename DType, typename IType>
+  template <typename DType, typename IType>
   MSHADOW_XINLINE static void Map(int i,
                                   const DType* out_grad,
                                   DType* orig_grad,
@@ -92,7 +96,7 @@ struct index_copy_bwd_gpu {
   }
 };
 
-template<>
+template <>
 void IndexCopyBackward<gpu>(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
@@ -102,15 +106,15 @@ void IndexCopyBackward<gpu>(const nnvm::NodeAttrs& attrs,
   using namespace mxnet_op;
   CHECK_EQ(inputs.size(), 4U);
   CHECK_EQ(outputs.size(), 3U);
-  Stream<gpu> *s = ctx.get_stream<gpu>();
-  const TBlob& out_grad = inputs[0];
-  const TBlob& index = inputs[2];
+  Stream<gpu>* s         = ctx.get_stream<gpu>();
+  const TBlob& out_grad  = inputs[0];
+  const TBlob& index     = inputs[2];
   const TBlob& in_grad_1 = outputs[0];
   const TBlob& in_grad_2 = outputs[2];
-  int dim_size = inputs[3].Size() / inputs[2].Size();
-  int index_size = inputs[2].Size();
-  OpReqType orig_req = req[0];
-  OpReqType new_req = req[2];
+  int dim_size           = inputs[3].Size() / inputs[2].Size();
+  int index_size         = inputs[2].Size();
+  OpReqType orig_req     = req[0];
+  OpReqType new_req      = req[2];
   // index_copy_backward
   MSHADOW_TYPE_SWITCH(out_grad.type_flag_, DType, {
     MSHADOW_TYPE_SWITCH(index.type_flag_, IType, {
@@ -123,22 +127,30 @@ void IndexCopyBackward<gpu>(const nnvm::NodeAttrs& attrs,
           break;
         case kAddTo:
           Kernel<op_with_req<mshadow_op::plus, kWriteTo>, gpu>::Launch(
-            s, out_grad.Size(), in_grad_1.dptr<DType>(),
-            out_grad.dptr<DType>(), in_grad_1.dptr<DType>());
+              s,
+              out_grad.Size(),
+              in_grad_1.dptr<DType>(),
+              out_grad.dptr<DType>(),
+              in_grad_1.dptr<DType>());
       }
-      Kernel<index_copy_bwd_gpu, gpu>::Launch(
-          s, in_grad_2.Size(), out_grad.dptr<DType>(),
-          in_grad_1.dptr<DType>(), in_grad_2.dptr<DType>(),
-          index.dptr<IType>(), dim_size, index_size, orig_req, new_req);
+      Kernel<index_copy_bwd_gpu, gpu>::Launch(s,
+                                              in_grad_2.Size(),
+                                              out_grad.dptr<DType>(),
+                                              in_grad_1.dptr<DType>(),
+                                              in_grad_2.dptr<DType>(),
+                                              index.dptr<IType>(),
+                                              dim_size,
+                                              index_size,
+                                              orig_req,
+                                              new_req);
     });
   });
 }
 
-NNVM_REGISTER_OP(_contrib_index_copy)
-.set_attr<FCompute>("FCompute", IndexCopyForward<gpu>);
+NNVM_REGISTER_OP(_contrib_index_copy).set_attr<FCompute>("FCompute", IndexCopyForward<gpu>);
 
 NNVM_REGISTER_OP(_contrib_backward_index_copy)
-.set_attr<FCompute>("FCompute", IndexCopyBackward<gpu>);
+    .set_attr<FCompute>("FCompute", IndexCopyBackward<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/contrib/intgemm/intgemm_fully_connected_op.cc b/src/operator/contrib/intgemm/intgemm_fully_connected_op.cc
index 216f5ce47ecc..610f0f257507 100644
--- a/src/operator/contrib/intgemm/intgemm_fully_connected_op.cc
+++ b/src/operator/contrib/intgemm/intgemm_fully_connected_op.cc
@@ -42,18 +42,18 @@ struct IntgemmFullyConnectedParam : public dmlc::Parameter<IntgemmFullyConnectedParam> {