Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
110 changes: 51 additions & 59 deletions src/api/_api_internal/_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,80 +35,72 @@
namespace mxnet {

// Convert the first argument (must be an int, kDLInt) into a runtime::Integer
// object and return it. Any other argument type is a fatal error.
MXNET_REGISTER_GLOBAL("_Integer")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      using namespace runtime;
      if (args[0].type_code() == kDLInt) {
        *ret = Integer(args[0].operator int64_t());
      } else {
        LOG(FATAL) << "only accept int";
      }
    });

MXNET_REGISTER_GLOBAL("_Float")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
if (args[0].type_code() == kDLFloat) {
*ret = Float(args[0].operator double());
} else {
LOG(FATAL) << "only accept float";
}
MXNET_REGISTER_GLOBAL("_Float").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
if (args[0].type_code() == kDLFloat) {
*ret = Float(args[0].operator double());
} else {
LOG(FATAL) << "only accept float";
}
});

MXNET_REGISTER_GLOBAL("_ADT")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() == kNDArrayHandle) {
mxnet::NDArray *array = args[i].operator mxnet::NDArray*();
ObjectRef input = NDArrayHandle(array);
data.push_back(input);
} else if (args[i].type_code() != kNull) {
ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String()
: args[i].operator ObjectRef();
data.push_back(input);
} else {
data.emplace_back(nullptr);
}
MXNET_REGISTER_GLOBAL("_ADT").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() == kNDArrayHandle) {
mxnet::NDArray* array = args[i].operator mxnet::NDArray*();
ObjectRef input = NDArrayHandle(array);
data.push_back(input);
} else if (args[i].type_code() != kNull) {
ObjectRef input = String::CanConvertFrom(args[i]) ? args[i].operator String()
: args[i].operator ObjectRef();
data.push_back(input);
} else {
data.emplace_back(nullptr);
}
*ret = ADT(0, data.begin(), data.end());
}
*ret = ADT(0, data.begin(), data.end());
});

MXNET_REGISTER_GLOBAL("_Map")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
CHECK_EQ(args.size() % 2, 0);
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> data;
for (int i = 0; i < args.num_args; i += 2) {
ObjectRef k =
String::CanConvertFrom(args[i]) ? args[i].operator String()
: args[i].operator ObjectRef();
ObjectRef v;
if (args[i + 1].type_code() == kNDArrayHandle) {
mxnet::NDArray *array = args[i + 1].operator mxnet::NDArray*();
v = NDArrayHandle(array);
} else {
v = args[i + 1];
}
data.emplace(std::move(k), std::move(v));
MXNET_REGISTER_GLOBAL("_Map").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
CHECK_EQ(args.size() % 2, 0);
std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> data;
for (int i = 0; i < args.num_args; i += 2) {
ObjectRef k =
String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
ObjectRef v;
if (args[i + 1].type_code() == kNDArrayHandle) {
mxnet::NDArray* array = args[i + 1].operator mxnet::NDArray*();
v = NDArrayHandle(array);
} else {
v = args[i + 1];
}
*ret = Map<ObjectRef, ObjectRef>(data);
data.emplace(std::move(k), std::move(v));
}
*ret = Map<ObjectRef, ObjectRef>(data);
});

MXNET_REGISTER_GLOBAL("_String")
.set_body([] (runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
MXNET_REGISTER_GLOBAL("_String").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
std::string str = args[0].operator std::string();
*ret = String(std::move(str));
*ret = String(std::move(str));
});

MXNET_REGISTER_GLOBAL("_echo")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
MXNET_REGISTER_GLOBAL("_echo").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
*ret = args[0];
});

// No-op API endpoint: accepts any arguments and does nothing (used for
// measuring FFI call overhead).
MXNET_REGISTER_API("_nop").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {});

} // namespace mxnet
201 changes: 100 additions & 101 deletions src/api/cached_op_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,119 +30,118 @@
namespace mxnet {

// Invoke a CachedOp.
// Argument layout: args[0] = CachedOpPtr* handle, args[1] = num_inputs,
// args[2 .. num_inputs+1] = input NDArray*s, args[num_inputs+2/3] = optional
// default device type/id (kNull to derive from the first input's context),
// args[num_inputs+4 ...] = optional pre-allocated output NDArray*s (kNull to
// allocate fresh outputs). Returns a single NDArray handle, or an ADT of
// handles when the op has multiple outputs.
MXNET_REGISTER_GLOBAL("cached_op.invoke")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
      // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX
      // was called with thread_safe=true
      CachedOp* op = dynamic_cast<CachedOp*>(op_shared.get());

      int num_inputs = args[1];
      int args_size = args.size();
      std::vector<NDArray*> ndinputs;
      ndinputs.reserve(num_inputs);
      for (int i = 2; i < num_inputs + 2; ++i) {
        ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
      }

      std::vector<NDArray*> ndoutputs;
      ndoutputs.reserve(op->num_outputs());
      if (args[num_inputs + 4].type_code() == kNull) {
        // No caller-supplied outputs: allocate one fresh NDArray per output.
        for (int i = 0; i < op->num_outputs(); ++i)
          ndoutputs.push_back(new NDArray());
      } else {
        int array_size = args_size - num_inputs - 4;
        CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
                                                << " outputs, but " << array_size << " was given.";
        // NOTE(review): the upper bound here is `array_size`, not `args_size`;
        // with the start index `num_inputs + 4` this loop looks like it may
        // never execute — verify the intended bound against the callers.
        for (int i = num_inputs + 4; i < array_size; ++i) {
          ndoutputs.push_back(args[i].operator mxnet::NDArray*());
        }
      }

      int default_dev_type;
      int default_dev_id;
      if (args[num_inputs + 2].type_code() != kNull) {
        default_dev_type = args[num_inputs + 2];
        default_dev_id = args[num_inputs + 3];
      } else {
        // Fall back to the context of the first input array.
        const Context& ctx = ndinputs[0]->ctx();
        default_dev_type = ctx.dev_type;
        default_dev_id = ctx.dev_id;
      }

      // construct default context
      Context ctx =
          Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
      op->Forward(op_shared, ndinputs, ndoutputs, ctx);

      if (op->num_outputs() == 1) {
        *ret = ndoutputs[0];
      } else {
        // Multiple outputs: wrap each in an NDArrayHandle (which takes over
        // the payload), free the temporary NDArray*, and return them as an ADT.
        std::vector<ObjectRef> outputs;
        outputs.reserve(op->num_outputs());
        for (int i = 0; i < op->num_outputs(); ++i) {
          ObjectRef out = NDArrayHandle(ndoutputs[i]);
          outputs.push_back(out);
          delete ndoutputs[i];
        }
        *ret = runtime::ADT(0, outputs.begin(), outputs.end());
      }
    });

// Create a CachedOp from a Symbol handle (args[0]) and a runtime Map of
// string flags (args[1]); args[2] selects the thread-safe variant.
// Returns an owning CachedOpPtr* as an opaque void* handle — the caller must
// release it via "cached_op.free".
MXNET_REGISTER_GLOBAL("cached_op.create")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle);
      Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle);
      auto* n = static_cast<const runtime::MapObj*>(flags_ptr);
      int num_flags = static_cast<int>(n->size());
      bool thread_safe = args[2];
      // Flatten the runtime Map into (key, value) string pairs for CachedOp.
      std::vector<std::pair<std::string, std::string> > flags;
      flags.reserve(num_flags);
      for (const auto& kv : *n) {
        flags.emplace_back(std::string(runtime::Downcast<runtime::String>(kv.first)),
                           std::string(runtime::Downcast<runtime::String>(kv.second)));
      }
      mxnet::CachedOpPtr* out;
      if (!thread_safe) {
        out = new CachedOpPtr(new CachedOp(*sym, flags));
      } else {
        out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags));
      }
      *ret = static_cast<void*>(out);
    });

// Release a CachedOpPtr handle previously returned by "cached_op.create".
// Deleting the heap-allocated shared_ptr drops one reference to the op.
MXNET_REGISTER_GLOBAL("cached_op.free")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle);
      delete g;
    });

// Return a copy of the CachedOp's optimized symbol as a new SymbolHandle.
// Ownership of the heap-allocated Symbol transfers to the caller.
MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      auto s = new nnvm::Symbol();
      CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
      *s = op->GetOptimizedSymbol();
      *ret = static_cast<void*>(static_cast<SymbolHandle>(s));
    });

// Register a monitor callback on a CachedOp.
// args[0] = CachedOp handle, args[1] = C callback pointer (may be null),
// args[2] = monitor_all flag. A null callback clears the hook.
MXNET_REGISTER_GLOBAL("cached_op.register_op_hook")
    .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
      CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle);
      CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>(
          reinterpret_cast<void (*)(const char*, const char*, void*)>(args[1].value().v_handle));
      bool monitor_all = args[2];
      CachedOpMonitorCallback callback_temp = nullptr;
      std::function<void(const char*, const char*, void*)> clbk;
      if (callback) {
        // Wrap the raw function pointer in a std::function; the pointer is
        // captured by value so the closure stays valid after this call.
        callback_temp = callback;
        clbk = [callback_temp](const char* name, const char* opr_name, void* handle) {
          callback_temp(name, opr_name, handle);
        };
      } else {
        clbk = nullptr;
      }
      CachedOpPtr op = *static_cast<CachedOpPtr*>(handle);
      op->RegisterOpHook(clbk, monitor_all);
    });

} // namespace mxnet
17 changes: 8 additions & 9 deletions src/api/operator/numpy/linalg/np_det.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -27,17 +27,16 @@

namespace mxnet {

MXNET_REGISTER_API("_npi.det")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
MXNET_REGISTER_API("_npi.det").set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_det");
nnvm::NodeAttrs attrs;
attrs.op = op;
int num_inputs = 1;
int num_outputs = 0;
attrs.op = op;
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = NDArrayHandle(ndoutputs[0]);
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = NDArrayHandle(ndoutputs[0]);
});

} // namespace mxnet
Loading