|
30 | 30 | namespace mxnet { |
31 | 31 |
|
32 | 32 | MXNET_REGISTER_GLOBAL("cached_op.invoke") |
33 | | -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
34 | | - CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle); |
35 | | - // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX |
36 | | - // was called with thread_safe=true |
37 | | - CachedOp* op = dynamic_cast<CachedOp*>(op_shared.get()); |
| 33 | + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
| 34 | + CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle); |
| 35 | + // CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX |
| 36 | + // was called with thread_safe=true |
| 37 | + CachedOp* op = dynamic_cast<CachedOp*>(op_shared.get()); |
38 | 38 |
|
39 | | - int num_inputs = args[1]; |
40 | | - int args_size = args.size(); |
41 | | - std::vector<NDArray*> ndinputs; |
42 | | - ndinputs.reserve(num_inputs); |
43 | | - for (int i = 2; i < num_inputs + 2; ++i) { |
44 | | - ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i])); |
45 | | - } |
| 39 | + int num_inputs = args[1]; |
| 40 | + int args_size = args.size(); |
| 41 | + std::vector<NDArray*> ndinputs; |
| 42 | + ndinputs.reserve(num_inputs); |
| 43 | + for (int i = 2; i < num_inputs + 2; ++i) { |
| 44 | + ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i])); |
| 45 | + } |
46 | 46 |
|
47 | | - std::vector<NDArray*> ndoutputs; |
48 | | - ndoutputs.reserve(op->num_outputs()); |
49 | | - if (args[num_inputs + 4].type_code() == kNull) { |
50 | | - for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray()); |
51 | | - } else { |
52 | | - int array_size = args_size - num_inputs - 4; |
53 | | - CHECK_EQ(array_size, op->num_outputs()) |
54 | | - << "CachedOp expects " << op->num_outputs() << " outputs, but " |
55 | | - << array_size << " was given."; |
56 | | - for (int i = num_inputs + 4; i < array_size; ++i) { |
57 | | - ndoutputs.push_back(args[i].operator mxnet::NDArray*()); |
58 | | - } |
59 | | - } |
| 47 | + std::vector<NDArray*> ndoutputs; |
| 48 | + ndoutputs.reserve(op->num_outputs()); |
| 49 | + if (args[num_inputs + 4].type_code() == kNull) { |
| 50 | + for (int i = 0; i < op->num_outputs(); ++i) |
| 51 | + ndoutputs.push_back(new NDArray()); |
| 52 | + } else { |
| 53 | + int array_size = args_size - num_inputs - 4; |
| 54 | + CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs() |
| 55 | + << " outputs, but " << array_size << " was given."; |
| 56 | + for (int i = num_inputs + 4; i < array_size; ++i) { |
| 57 | + ndoutputs.push_back(args[i].operator mxnet::NDArray*()); |
| 58 | + } |
| 59 | + } |
60 | 60 |
|
61 | | - int default_dev_type; |
62 | | - int default_dev_id; |
63 | | - if (args[num_inputs + 2].type_code() != kNull) { |
64 | | - default_dev_type = args[num_inputs + 2]; |
65 | | - default_dev_id = args[num_inputs + 3]; |
66 | | - } else { |
67 | | - const Context &ctx = ndinputs[0]->ctx(); |
68 | | - default_dev_type = ctx.dev_type; |
69 | | - default_dev_id = ctx.dev_id; |
70 | | - } |
| 61 | + int default_dev_type; |
| 62 | + int default_dev_id; |
| 63 | + if (args[num_inputs + 2].type_code() != kNull) { |
| 64 | + default_dev_type = args[num_inputs + 2]; |
| 65 | + default_dev_id = args[num_inputs + 3]; |
| 66 | + } else { |
| 67 | + const Context& ctx = ndinputs[0]->ctx(); |
| 68 | + default_dev_type = ctx.dev_type; |
| 69 | + default_dev_id = ctx.dev_id; |
| 70 | + } |
71 | 71 |
|
72 | | - // construct default context |
73 | | - Context ctx = Context::Create(static_cast<Context::DeviceType>(default_dev_type), |
74 | | - default_dev_id); |
75 | | - op->Forward(op_shared, ndinputs, ndoutputs, ctx); |
| 72 | + // construct default context |
| 73 | + Context ctx = |
| 74 | + Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id); |
| 75 | + op->Forward(op_shared, ndinputs, ndoutputs, ctx); |
76 | 76 |
|
77 | | - if (op->num_outputs() == 1) { |
78 | | - *ret = ndoutputs[0]; |
79 | | - } else { |
80 | | - std::vector<ObjectRef> outputs; |
81 | | - outputs.reserve(op->num_outputs()); |
82 | | - for (int i = 0; i < op->num_outputs(); ++i) { |
83 | | - ObjectRef out = NDArrayHandle(ndoutputs[i]); |
84 | | - outputs.push_back(out); |
85 | | - delete ndoutputs[i]; |
86 | | - } |
87 | | - *ret = runtime::ADT(0, outputs.begin(), outputs.end()); |
88 | | - } |
89 | | -}); |
| 77 | + if (op->num_outputs() == 1) { |
| 78 | + *ret = ndoutputs[0]; |
| 79 | + } else { |
| 80 | + std::vector<ObjectRef> outputs; |
| 81 | + outputs.reserve(op->num_outputs()); |
| 82 | + for (int i = 0; i < op->num_outputs(); ++i) { |
| 83 | + ObjectRef out = NDArrayHandle(ndoutputs[i]); |
| 84 | + outputs.push_back(out); |
| 85 | + delete ndoutputs[i]; |
| 86 | + } |
| 87 | + *ret = runtime::ADT(0, outputs.begin(), outputs.end()); |
| 88 | + } |
| 89 | + }); |
90 | 90 |
|
91 | 91 | MXNET_REGISTER_GLOBAL("cached_op.create") |
92 | | -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
93 | | - nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle); |
94 | | - Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle); |
95 | | - auto* n = static_cast<const runtime::MapObj*>(flags_ptr); |
96 | | - int num_flags = static_cast<int>(n->size()); |
97 | | - bool thread_safe = args[2]; |
98 | | - std::vector<std::pair<std::string, std::string> > flags; |
99 | | - flags.reserve(num_flags); |
100 | | - for (const auto& kv : *n) { |
101 | | - flags.emplace_back(std::string(runtime::Downcast<runtime::String>(kv.first)), |
102 | | - std::string(runtime::Downcast<runtime::String>(kv.second))); |
103 | | - } |
104 | | - mxnet::CachedOpPtr *out; |
105 | | - if (!thread_safe) { |
106 | | - out = new CachedOpPtr(new CachedOp(*sym, flags)); |
107 | | - } else { |
108 | | - out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags)); |
109 | | - } |
110 | | - *ret = static_cast<void*>(out); |
111 | | -}); |
| 92 | + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
| 93 | + nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle); |
| 94 | + Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle); |
| 95 | + auto* n = static_cast<const runtime::MapObj*>(flags_ptr); |
| 96 | + int num_flags = static_cast<int>(n->size()); |
| 97 | + bool thread_safe = args[2]; |
| 98 | + std::vector<std::pair<std::string, std::string> > flags; |
| 99 | + flags.reserve(num_flags); |
| 100 | + for (const auto& kv : *n) { |
| 101 | + flags.emplace_back(std::string(runtime::Downcast<runtime::String>(kv.first)), |
| 102 | + std::string(runtime::Downcast<runtime::String>(kv.second))); |
| 103 | + } |
| 104 | + mxnet::CachedOpPtr* out; |
| 105 | + if (!thread_safe) { |
| 106 | + out = new CachedOpPtr(new CachedOp(*sym, flags)); |
| 107 | + } else { |
| 108 | + out = new CachedOpPtr(new CachedOpThreadSafe(*sym, flags)); |
| 109 | + } |
| 110 | + *ret = static_cast<void*>(out); |
| 111 | + }); |
112 | 112 |
|
113 | 113 | MXNET_REGISTER_GLOBAL("cached_op.free") |
114 | | -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
115 | | - CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle); |
116 | | - delete g; |
117 | | -}); |
| 114 | + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
| 115 | + CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle); |
| 116 | + delete g; |
| 117 | + }); |
118 | 118 |
|
119 | 119 | MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol") |
120 | | -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
121 | | - auto s = new nnvm::Symbol(); |
122 | | - CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle); |
123 | | - *s = op->GetOptimizedSymbol(); |
124 | | - *ret = static_cast<void*>(static_cast<SymbolHandle>(s)); |
125 | | -}); |
| 120 | + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
| 121 | + auto s = new nnvm::Symbol(); |
| 122 | + CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle); |
| 123 | + *s = op->GetOptimizedSymbol(); |
| 124 | + *ret = static_cast<void*>(static_cast<SymbolHandle>(s)); |
| 125 | + }); |
126 | 126 |
|
127 | 127 | MXNET_REGISTER_GLOBAL("cached_op.register_op_hook") |
128 | | -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
129 | | - CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle); |
130 | | - CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>( |
131 | | - reinterpret_cast<void (*)(const char *, const char *, void *)>(args[1].value().v_handle)); |
132 | | - bool monitor_all = args[2]; |
133 | | - CachedOpMonitorCallback callback_temp = nullptr; |
134 | | - std::function<void(const char *, const char *, void*)> clbk; |
135 | | - if (callback) { |
136 | | - callback_temp = callback; |
137 | | - clbk = [callback_temp](const char *name, const char *opr_name, |
138 | | - void *handle) { |
139 | | - callback_temp(name, opr_name, handle); |
140 | | - }; |
141 | | - } else { |
142 | | - clbk = nullptr; |
143 | | - } |
144 | | - CachedOpPtr op = *static_cast<CachedOpPtr *>(handle); |
145 | | - op->RegisterOpHook(clbk, monitor_all); |
146 | | -}); |
| 128 | + .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { |
| 129 | + CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle); |
| 130 | + CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>( |
| 131 | + reinterpret_cast<void (*)(const char*, const char*, void*)>(args[1].value().v_handle)); |
| 132 | + bool monitor_all = args[2]; |
| 133 | + CachedOpMonitorCallback callback_temp = nullptr; |
| 134 | + std::function<void(const char*, const char*, void*)> clbk; |
| 135 | + if (callback) { |
| 136 | + callback_temp = callback; |
| 137 | + clbk = [callback_temp](const char* name, const char* opr_name, void* handle) { |
| 138 | + callback_temp(name, opr_name, handle); |
| 139 | + }; |
| 140 | + } else { |
| 141 | + clbk = nullptr; |
| 142 | + } |
| 143 | + CachedOpPtr op = *static_cast<CachedOpPtr*>(handle); |
| 144 | + op->RegisterOpHook(clbk, monitor_all); |
| 145 | + }); |
147 | 146 |
|
148 | 147 | } // namespace mxnet |
0 commit comments