When `merge_multi_context=True` in a call to `get_outputs` of a `DataParallelExecutorGroup`, `_merge_multi_context` will try to concatenate the outputs from the different devices along the major axis.
The major axis is computed from `[DataDesc.get_batch_axis(self.symbol[name].attr('__layout__')) for name in self.output_names]` in the initializer of `DataParallelExecutorGroup`.
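For concreteness, here is a minimal sketch (assuming a standard MXNet installation) of how `DataDesc.get_batch_axis` maps a layout string to the concatenation axis:

```python
import mxnet as mx

# get_batch_axis returns the index of 'N' in the layout string,
# and falls back to 0 when the layout is None/unset.
print(mx.io.DataDesc.get_batch_axis('NCHW'))  # 0 -> concatenate along axis 0
print(mx.io.DataDesc.get_batch_axis('TNC'))   # 1 -> concatenate along axis 1
print(mx.io.DataDesc.get_batch_axis(None))    # 0 -> default when __layout__ is not set
```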
What is the recommended way to set `attr('__layout__')` on a symbol? Simply pass `attr={'__layout__': layout}` when constructing the symbol? Can the attribute be set automatically during module binding?
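For reference, a sketch of the two mechanisms I am aware of for attaching the attribute (the `FullyConnected` operator and the `'TNC'` layout string are placeholders; whether either is the recommended route is exactly the question):

```python
import mxnet as mx

data = mx.sym.Variable('data')

# Option 1 (assumed): pass attr= when constructing the output symbol.
out = mx.sym.FullyConnected(data=data, num_hidden=10, name='out',
                            attr={'__layout__': 'TNC'})
print(out.attr('__layout__'))  # 'TNC'

# Option 2 (assumed): use an AttrScope so every symbol created inside it
# receives the attribute.
with mx.AttrScope(__layout__='TNC'):
    out2 = mx.sym.FullyConnected(data=data, num_hidden=10, name='out2')
print(out2.attr('__layout__'))  # 'TNC'
```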
Setting `attr('__layout__')` is necessary because it is otherwise `None`, in which case `_merge_multi_context` tries to concatenate along `dim=0`. That fails if the batch size is not divisible by the number of devices and the symbol outputs a shape of `(1, batch_size_per_device, X)`.
For example, with 3 devices and a batch size of 128, concatenating `(1, 43, X)`, `(1, 43, X)` and `(1, 42, X)` along `dim=0` fails.
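To make the failure concrete, here is a small reproduction of the merge step; the shapes are the ones from the example above (with `X = 8` chosen arbitrarily), and `mx.nd.concat` stands in for the concatenation performed inside `_merge_multi_context`:

```python
import mxnet as mx

# Per-device outputs of shape (1, batch_per_device, X) for 3 devices and
# a total batch size of 128.
parts = [mx.nd.zeros((1, 43, 8)), mx.nd.zeros((1, 43, 8)), mx.nd.zeros((1, 42, 8))]

# With __layout__ unset, the batch axis defaults to 0, so the merge concatenates
# along dim=0 and shape inference rejects the mismatched 43/43/42 dimension.
try:
    mx.nd.concat(*parts, dim=0)
except mx.base.MXNetError as err:
    print('concat along dim=0 fails:', type(err).__name__)

# Concatenating along the true batch axis (dim=1) works and recovers the full batch.
print(mx.nd.concat(*parts, dim=1).shape)  # (1, 128, 8)
```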