Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ce56127

Browse files
committed
Initial commit: only src/ - clang-formater
1 parent 68dc5a9 commit ce56127

File tree

941 files changed

+84080
-75352
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

941 files changed

+84080
-75352
lines changed

src/api/_api_internal/_api_internal.cc

Lines changed: 51 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -35,80 +35,72 @@
3535
namespace mxnet {
3636

3737
MXNET_REGISTER_GLOBAL("_Integer")
38-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
39-
using namespace runtime;
40-
if (args[0].type_code() == kDLInt) {
41-
*ret = Integer(args[0].operator int64_t());
42-
} else {
43-
LOG(FATAL) << "only accept int";
44-
}
45-
});
38+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
39+
using namespace runtime;
40+
if (args[0].type_code() == kDLInt) {
41+
*ret = Integer(args[0].operator int64_t());
42+
} else {
43+
LOG(FATAL) << "only accept int";
44+
}
45+
});
4646

47-
MXNET_REGISTER_GLOBAL("_Float")
48-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
49-
using namespace runtime;
50-
if (args[0].type_code() == kDLFloat) {
51-
*ret = Float(args[0].operator double());
52-
} else {
53-
LOG(FATAL) << "only accept float";
54-
}
47+
MXNET_REGISTER_GLOBAL("_Float").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
48+
using namespace runtime;
49+
if (args[0].type_code() == kDLFloat) {
50+
*ret = Float(args[0].operator double());
51+
} else {
52+
LOG(FATAL) << "only accept float";
53+
}
5554
});
5655

57-
MXNET_REGISTER_GLOBAL("_ADT")
58-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
59-
using namespace runtime;
60-
std::vector<ObjectRef> data;
61-
for (int i = 0; i < args.size(); ++i) {
62-
if (args[i].type_code() == kNDArrayHandle) {
63-
mxnet::NDArray *array = args[i].operator mxnet::NDArray*();
64-
ObjectRef input = NDArrayHandle(array);
65-
data.push_back(input);
66-
} else if (args[i].type_code() != kNull) {
67-
ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String()
68-
: args[i].operator ObjectRef();
69-
data.push_back(input);
70-
} else {
71-
data.emplace_back(nullptr);
72-
}
56+
MXNET_REGISTER_GLOBAL("_ADT").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
57+
using namespace runtime;
58+
std::vector<ObjectRef> data;
59+
for (int i = 0; i < args.size(); ++i) {
60+
if (args[i].type_code() == kNDArrayHandle) {
61+
mxnet::NDArray* array = args[i].operator mxnet::NDArray*();
62+
ObjectRef input = NDArrayHandle(array);
63+
data.push_back(input);
64+
} else if (args[i].type_code() != kNull) {
65+
ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String()
66+
: args[i].operator ObjectRef();
67+
data.push_back(input);
68+
} else {
69+
data.emplace_back(nullptr);
7370
}
74-
*ret = ADT(0, data.begin(), data.end());
71+
}
72+
*ret = ADT(0, data.begin(), data.end());
7573
});
7674

77-
MXNET_REGISTER_GLOBAL("_Map")
78-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
79-
using namespace runtime;
80-
CHECK_EQ(args.size() % 2, 0);
81-
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> data;
82-
for (int i = 0; i < args.num_args; i += 2) {
83-
ObjectRef k =
84-
String::CanConvertFrom(args[i]) ? args[i].operator String()
85-
: args[i].operator ObjectRef();
86-
ObjectRef v;
87-
if (args[i + 1].type_code() == kNDArrayHandle) {
88-
mxnet::NDArray *array = args[i + 1].operator mxnet::NDArray*();
89-
v = NDArrayHandle(array);
90-
} else {
91-
v = args[i + 1];
92-
}
93-
data.emplace(std::move(k), std::move(v));
75+
MXNET_REGISTER_GLOBAL("_Map").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
76+
using namespace runtime;
77+
CHECK_EQ(args.size() % 2, 0);
78+
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> data;
79+
for (int i = 0; i < args.num_args; i += 2) {
80+
ObjectRef k =
81+
String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
82+
ObjectRef v;
83+
if (args[i + 1].type_code() == kNDArrayHandle) {
84+
mxnet::NDArray* array = args[i + 1].operator mxnet::NDArray*();
85+
v = NDArrayHandle(array);
86+
} else {
87+
v = args[i + 1];
9488
}
95-
*ret = Map<ObjectRef, ObjectRef>(data);
89+
data.emplace(std::move(k), std::move(v));
90+
}
91+
*ret = Map<ObjectRef, ObjectRef>(data);
9692
});
9793

98-
MXNET_REGISTER_GLOBAL("_String")
99-
.set_body([] (runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
94+
MXNET_REGISTER_GLOBAL("_String").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
10095
using namespace runtime;
10196
std::string str = args[0].operator std::string();
102-
*ret = String(std::move(str));
97+
*ret = String(std::move(str));
10398
});
10499

105-
MXNET_REGISTER_GLOBAL("_echo")
106-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
100+
MXNET_REGISTER_GLOBAL("_echo").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
107101
*ret = args[0];
108102
});
109103

110-
MXNET_REGISTER_API("_nop")
111-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
112-
});
104+
MXNET_REGISTER_API("_nop").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {});
113105

114106
} // namespace mxnet

src/api/cached_op_api.cc

Lines changed: 100 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -30,119 +30,118 @@
3030
namespace mxnet {
3131

3232
MXNET_REGISTER_GLOBAL("cached_op.invoke")
33-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
34-
CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
35-
// CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX
36-
// was called with thread_safe=true
37-
CachedOp* op = dynamic_cast<CachedOp*>(op_shared.get());
33+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
34+
CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
35+
// CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX
36+
// was called with thread_safe=true
37+
CachedOp* op = dynamic_cast<CachedOp*>(op_shared.get());
3838

39-
int num_inputs = args[1];
40-
int args_size = args.size();
41-
std::vector<NDArray*> ndinputs;
42-
ndinputs.reserve(num_inputs);
43-
for (int i = 2; i < num_inputs + 2; ++i) {
44-
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
45-
}
39+
int num_inputs = args[1];
40+
int args_size = args.size();
41+
std::vector<NDArray*> ndinputs;
42+
ndinputs.reserve(num_inputs);
43+
for (int i = 2; i < num_inputs + 2; ++i) {
44+
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
45+
}
4646

47-
std::vector<NDArray*> ndoutputs;
48-
ndoutputs.reserve(op->num_outputs());
49-
if (args[num_inputs + 4].type_code() == kNull) {
50-
for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray());
51-
} else {
52-
int array_size = args_size - num_inputs - 4;
53-
CHECK_EQ(array_size, op->num_outputs())
54-
<< "CachedOp expects " << op->num_outputs() << " outputs, but "
55-
<< array_size << " was given.";
56-
for (int i = num_inputs + 4; i < array_size; ++i) {
57-
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
58-
}
59-
}
47+
std::vector<NDArray*> ndoutputs;
48+
ndoutputs.reserve(op->num_outputs());
49+
if (args[num_inputs + 4].type_code() == kNull) {
50+
for (int i = 0; i < op->num_outputs(); ++i)
51+
ndoutputs.push_back(new NDArray());
52+
} else {
53+
int array_size = args_size - num_inputs - 4;
54+
CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
55+
<< " outputs, but " << array_size << " was given.";
56+
for (int i = num_inputs + 4; i < array_size; ++i) {
57+
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
58+
}
59+
}
6060

61-
int default_dev_type;
62-
int default_dev_id;
63-
if (args[num_inputs + 2].type_code() != kNull) {
64-
default_dev_type = args[num_inputs + 2];
65-
default_dev_id = args[num_inputs + 3];
66-
} else {
67-
const Context &ctx = ndinputs[0]->ctx();
68-
default_dev_type = ctx.dev_type;
69-
default_dev_id = ctx.dev_id;
70-
}
61+
int default_dev_type;
62+
int default_dev_id;
63+
if (args[num_inputs + 2].type_code() != kNull) {
64+
default_dev_type = args[num_inputs + 2];
65+
default_dev_id = args[num_inputs + 3];
66+
} else {
67+
const Context& ctx = ndinputs[0]->ctx();
68+
default_dev_type = ctx.dev_type;
69+
default_dev_id = ctx.dev_id;
70+
}
7171

72-
// construct default context
73-
Context ctx = Context::Create(static_cast<Context::DeviceType>(default_dev_type),
74-
default_dev_id);
75-
op->Forward(op_shared, ndinputs, ndoutputs, ctx);
72+
// construct default context
73+
Context ctx =
74+
Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
75+
op->Forward(op_shared, ndinputs, ndoutputs, ctx);
7676

77-
if (op->num_outputs() == 1) {
78-
*ret = ndoutputs[0];
79-
} else {
80-
std::vector<ObjectRef> outputs;
81-
outputs.reserve(op->num_outputs());
82-
for (int i = 0; i < op->num_outputs(); ++i) {
83-
ObjectRef out = NDArrayHandle(ndoutputs[i]);
84-
outputs.push_back(out);
85-
delete ndoutputs[i];
86-
}
87-
*ret = runtime::ADT(0, outputs.begin(), outputs.end());
88-
}
89-
});
77+
if (op->num_outputs() == 1) {
78+
*ret = ndoutputs[0];
79+
} else {
80+
std::vector<ObjectRef> outputs;
81+
outputs.reserve(op->num_outputs());
82+
for (int i = 0; i < op->num_outputs(); ++i) {
83+
ObjectRef out = NDArrayHandle(ndoutputs[i]);
84+
outputs.push_back(out);
85+
delete ndoutputs[i];
86+
}
87+
*ret = runtime::ADT(0, outputs.begin(), outputs.end());
88+
}
89+
});
9090

9191
MXNET_REGISTER_GLOBAL("cached_op.create")
92-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
93-
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle);
94-
Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle);
95-
auto* n = static_cast<const runtime::MapObj*>(flags_ptr);
96-
int num_flags = static_cast<int>(n->size());
97-
bool thread_safe = args[2];
98-
std::vector<std::pair<std::string, std::string> > flags;
99-
flags.reserve(num_flags);
100-
for (const auto& kv : *n) {
101-
flags.emplace_back(std::string(runtime::Downcast<runtime::String>(kv.first)),
102-
std::string(runtime::Downcast<runtime::String>(kv.second)));
103-
}
104-
mxnet::CachedOpPtr *out;
105-
if (!thread_safe) {
106-
out = new CachedOpPtr(new CachedOp(*sym, flags));
107-
} else {
108-
out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags));
109-
}
110-
*ret = static_cast<void*>(out);
111-
});
92+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
93+
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle);
94+
Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle);
95+
auto* n = static_cast<const runtime::MapObj*>(flags_ptr);
96+
int num_flags = static_cast<int>(n->size());
97+
bool thread_safe = args[2];
98+
std::vector<std::pair<std::string, std::string> > flags;
99+
flags.reserve(num_flags);
100+
for (const auto& kv : *n) {
101+
flags.emplace_back(std::string(runtime::Downcast<runtime::String>(kv.first)),
102+
std::string(runtime::Downcast<runtime::String>(kv.second)));
103+
}
104+
mxnet::CachedOpPtr* out;
105+
if (!thread_safe) {
106+
out = new CachedOpPtr(new CachedOp(*sym, flags));
107+
} else {
108+
out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags));
109+
}
110+
*ret = static_cast<void*>(out);
111+
});
112112

113113
MXNET_REGISTER_GLOBAL("cached_op.free")
114-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
115-
CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle);
116-
delete g;
117-
});
114+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
115+
CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle);
116+
delete g;
117+
});
118118

119119
MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol")
120-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
121-
auto s = new nnvm::Symbol();
122-
CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
123-
*s = op->GetOptimizedSymbol();
124-
*ret = static_cast<void*>(static_cast<SymbolHandle>(s));
125-
});
120+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
121+
auto s = new nnvm::Symbol();
122+
CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
123+
*s = op->GetOptimizedSymbol();
124+
*ret = static_cast<void*>(static_cast<SymbolHandle>(s));
125+
});
126126

127127
MXNET_REGISTER_GLOBAL("cached_op.register_op_hook")
128-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
129-
CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle);
130-
CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>(
131-
reinterpret_cast<void (*)(const char *, const char *, void *)>(args[1].value().v_handle));
132-
bool monitor_all = args[2];
133-
CachedOpMonitorCallback callback_temp = nullptr;
134-
std::function<void(const char *, const char *, void*)> clbk;
135-
if (callback) {
136-
callback_temp = callback;
137-
clbk = [callback_temp](const char *name, const char *opr_name,
138-
void *handle) {
139-
callback_temp(name, opr_name, handle);
140-
};
141-
} else {
142-
clbk = nullptr;
143-
}
144-
CachedOpPtr op = *static_cast<CachedOpPtr *>(handle);
145-
op->RegisterOpHook(clbk, monitor_all);
146-
});
128+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
129+
CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle);
130+
CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>(
131+
reinterpret_cast<void (*)(const char*, const char*, void*)>(args[1].value().v_handle));
132+
bool monitor_all = args[2];
133+
CachedOpMonitorCallback callback_temp = nullptr;
134+
std::function<void(const char*, const char*, void*)> clbk;
135+
if (callback) {
136+
callback_temp = callback;
137+
clbk = [callback_temp](const char* name, const char* opr_name, void* handle) {
138+
callback_temp(name, opr_name, handle);
139+
};
140+
} else {
141+
clbk = nullptr;
142+
}
143+
CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
144+
op->RegisterOpHook(clbk, monitor_all);
145+
});
147146

148147
} // namespace mxnet

src/api/operator/numpy/linalg/np_det.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
* to you under the Apache License, Version 2.0 (the
77
* "License"); you may not use this file except in compliance
88
* with the License. You may obtain a copy of the License at
9-
*
9+
*
1010
* http://www.apache.org/licenses/LICENSE-2.0
11-
*
11+
*
1212
* Unless required by applicable law or agreed to in writing,
1313
* software distributed under the License is distributed on an
1414
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,17 +27,16 @@
2727

2828
namespace mxnet {
2929

30-
MXNET_REGISTER_API("_npi.det")
31-
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
30+
MXNET_REGISTER_API("_npi.det").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
3231
using namespace runtime;
3332
const nnvm::Op* op = Op::Get("_npi_det");
3433
nnvm::NodeAttrs attrs;
35-
attrs.op = op;
36-
int num_inputs = 1;
37-
int num_outputs = 0;
34+
attrs.op = op;
35+
int num_inputs = 1;
36+
int num_outputs = 0;
3837
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
39-
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
40-
*ret = NDArrayHandle(ndoutputs[0]);
38+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
39+
*ret = NDArrayHandle(ndoutputs[0]);
4140
});
4241

4342
} // namespace mxnet

0 commit comments

Comments
 (0)