@@ -41,8 +41,10 @@ void GradientCompression::SetParams(const std::vector<std::pair<std::string, std
4141 & kwargs) {
4242 GradientCompressionParam params;
4343 params.InitAllowUnknown (kwargs);
44- CHECK_GT (params.threshold , 0 ) << " threshold must be greater than 0" ;
45- if (params.type == " 2bit" ) {
44+ if (params.type == " 1bit" ) {
45+ SetOneBitCompression (params.threshold );
46+ } else if (params.type == " 2bit" ) {
47+ CHECK_GT (params.threshold , 0 ) << " threshold must be greater than 0 for two bit compression" ;
4648 SetTwoBitCompression (params.threshold );
4749 } else {
4850 LOG (FATAL) << " Unknown type for gradient compression " << params.type ;
@@ -57,6 +59,11 @@ std::string GradientCompression::get_type_str() {
5759 return std::to_string (static_cast <int >(type_));
5860}
5961
// Switches this GradientCompression instance to one-bit mode: records
// CompressionType::kOneBit in type_ and stores `threshold` in threshold_.
// The threshold is stored exactly as given; no range check is performed here.
// NOTE(review): how the one-bit kernels interpret the threshold is not visible
// from this function - confirm against Quantize1BitImpl / Dequantize1BitImpl.
62+ void GradientCompression::SetOneBitCompression (const float threshold) {
63+ type_ = CompressionType::kOneBit ;
64+ threshold_ = threshold;
65+ }
66+
6067void GradientCompression::SetTwoBitCompression (const float threshold) {
6168 type_ = CompressionType::kTwoBit ;
6269 threshold_ = threshold;
@@ -83,7 +90,9 @@ void GradientCompression::DecodeParams(const std::string &s) {
8390}
8491
8592int GradientCompression::GetCompressionFactor () {
86- if (type_ == CompressionType::kTwoBit ) {
93+ if (type_ == CompressionType::kOneBit ) {
94+ return 32 ;
95+ } else if (type_ == CompressionType::kTwoBit ) {
8796 return 16 ;
8897 } else {
8998 LOG (FATAL) << " Unsupported compression type: " << get_type_str ();
@@ -106,16 +115,34 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t
106115 const int a = from.ctx ().dev_mask ();
107116 const int b = to->ctx ().dev_mask ();
108117 const float threshold = threshold_;
109- if (type_ == CompressionType::kTwoBit ) {
110- if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask ) {
118+ if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask ) {
119+ if (type_ == CompressionType::kOneBit ) {
120+ mxnet::Engine::Get ()->PushSync ([from, to, residual, threshold](mxnet::RunContext ctx) {
121+ std::vector<mxnet::TBlob> inputs = {from.data (), residual->data (), to->data ()};
122+ Quantize1BitImpl (ctx.get_stream <mshadow::cpu>(), inputs, threshold);
123+ }, from.ctx (), {from.var ()}, {to->var (), residual->var ()},
124+ mxnet::FnProperty::kNormal , priority, " QuantizeCPU" );
125+ } else if (type_ == CompressionType::kTwoBit ) {
111126 mxnet::Engine::Get ()->PushSync ([from, to, residual, threshold](mxnet::RunContext ctx) {
112127 std::vector<mxnet::TBlob> inputs = {from.data (), residual->data (), to->data ()};
113128 Quantize2BitImpl (ctx.get_stream <mshadow::cpu>(), inputs, threshold);
114129 }, from.ctx (), {from.var ()}, {to->var (), residual->var ()},
115130 mxnet::FnProperty::kNormal , priority, " QuantizeCPU" );
116131 } else {
132+ LOG (FATAL) << " Unsupported quantization of type " << get_type_str ();
133+ }
134+ } else {
135+ if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask ) {
117136#if MXNET_USE_CUDA
118- if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask ) {
137+ if (type_ == CompressionType::kOneBit ) {
138+ mxnet::Engine::Get ()->PushSync ([from, to, residual, threshold](mxnet::RunContext ctx) {
139+ std::vector<mxnet::TBlob> inputs = {from.data (), residual->data (), to->data ()};
140+ Quantize1BitImpl (ctx.get_stream <mshadow::gpu>(), inputs, threshold);
141+ // Wait GPU kernel to complete
142+ ctx.get_stream <mshadow::gpu>()->Wait ();
143+ }, from.ctx (), {from.var ()}, {to->var (), residual->var ()},
144+ mxnet::FnProperty::kNormal , priority, " QuantizeGPU" );
145+ } else if (type_ == CompressionType::kTwoBit ) {
119146 mxnet::Engine::Get ()->PushSync ([from, to, residual, threshold](mxnet::RunContext ctx) {
120147 std::vector<mxnet::TBlob> inputs = {from.data (), residual->data (), to->data ()};
121148 Quantize2BitImpl (ctx.get_stream <mshadow::gpu>(), inputs, threshold);
@@ -124,14 +151,14 @@ void GradientCompression::Quantize(const mxnet::NDArray &from, mxnet::NDArray *t
124151 }, from.ctx (), {from.var ()}, {to->var (), residual->var ()},
125152 mxnet::FnProperty::kNormal , priority, " QuantizeGPU" );
126153 } else {
127- LOG (FATAL) << " unknown device mask " ;
154+ LOG (FATAL) << " Unsupported quantization of type " << get_type_str () ;
128155 }
129156#else
130157 LOG (FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
131158#endif
159+ } else {
160+ LOG (FATAL) << " Unknown device mask, from device mask " << a << " to device mask " << b;
132161 }
133- } else {
134- LOG (FATAL) << " Unsupported quantization of type " << get_type_str ();
135162 }
136163}
137164
@@ -142,35 +169,52 @@ void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray
142169 const int a = from.ctx ().dev_mask ();
143170 const int b = to->ctx ().dev_mask ();
144171 const float threshold = threshold_;
145- if (type_ == CompressionType::kTwoBit ) {
146- if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask ) {
172+ if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask ) {
173+ if (type_ == CompressionType::kOneBit ) {
174+ mxnet::Engine::Get ()->PushSync ([from, to, threshold](mxnet::RunContext ctx) {
175+ std::vector<mxnet::TBlob> inputs = {from.data (), to->data ()};
176+ Dequantize1BitImpl (ctx.get_stream <mshadow::cpu>(), inputs, threshold);
177+ }, from.ctx (), {from.var ()}, {to->var ()},
178+ mxnet::FnProperty::kNormal , priority, " DequantizeCPU" );
179+ } else if (type_ == CompressionType::kTwoBit ) {
147180 mxnet::Engine::Get ()->PushSync ([from, to, threshold](mxnet::RunContext ctx) {
148181 std::vector<mxnet::TBlob> inputs = {from.data (), to->data ()};
149182 Dequantize2BitImpl (ctx.get_stream <mshadow::cpu>(), inputs, threshold);
150183 }, from.ctx (), {from.var ()}, {to->var ()},
151184 mxnet::FnProperty::kNormal , priority, " DequantizeCPU" );
152185 } else {
186+ LOG (FATAL) << " Unsupported dequantization of type " << get_type_str ();
187+ }
188+ } else {
189+ if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask ) {
153190#if MXNET_USE_CUDA
154- if (a == mshadow::gpu:: kDevMask && b == mshadow::gpu:: kDevMask ) {
191+ if (type_ == CompressionType:: kOneBit ) {
155192 mxnet::Engine::Get ()->PushSync ([from, to, threshold](mxnet::RunContext ctx) {
156193 std::vector<mxnet::TBlob> inputs = {from.data (), to->data ()};
157- Dequantize2BitImpl (ctx.get_stream <mshadow::gpu>(), inputs, threshold);
194+ Dequantize1BitImpl (ctx.get_stream <mshadow::gpu>(), inputs, threshold);
158195 // Wait GPU kernel to complete
159196 ctx.get_stream <mshadow::gpu>()->Wait ();
160197 }, from.ctx (), {from.var ()}, {to->var ()},
161198 mxnet::FnProperty::kNormal , priority, " DequantizeGPU" );
199+ } else if (type_ == CompressionType::kTwoBit ) {
200+ mxnet::Engine::Get ()->PushSync ([from, to, threshold](mxnet::RunContext ctx) {
201+ std::vector<mxnet::TBlob> inputs = {from.data (), to->data ()};
202+ Dequantize2BitImpl (ctx.get_stream <mshadow::gpu>(), inputs, threshold);
203+ // Wait GPU kernel to complete
204+ ctx.get_stream <mshadow::gpu>()->Wait ();
205+ }, from.ctx (), {from.var ()}, {to->var ()},
206+ mxnet::FnProperty::kNormal , priority, " DequantizeGPU" );
162207 } else {
163- LOG (FATAL) << " unknown device mask " ;
208+ LOG (FATAL) << " Unsupported dequantization of type " << get_type_str () ;
164209 }
165210#else
166- LOG (FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
211+ LOG (FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
167212#endif
213+ } else {
214+ LOG (FATAL) << " Unknown device mask, from device mask " << a << " to device mask " << b;
168215 }
169- } else {
170- LOG (FATAL) << " Unsupported dequantization of type " << get_type_str ();
171216 }
172217}
173-
174218} // namespace kvstore
175219} // namespace mxnet
176220
0 commit comments