This repository was archived by the owner on Nov 17, 2023. It is now read-only.

DataParallelExecutorGroup: layout handling for symbols #6736

@leezu

When merge_multi_context=True in a call to get_outputs of a DataParallelExecutorGroup, _merge_multi_context will try to concatenate the outputs from the different devices along the major axis.

The major axis is computed based on [DataDesc.get_batch_axis(self.symbol[name].attr('__layout__')) for name in self.output_names] in the initializer of DataParallelExecutorGroup.
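A minimal sketch of how a layout string could map to a batch axis, assuming the behaviour described in this issue: a missing layout falls back to axis 0, otherwise the position of 'N' in the layout is used. The function below is a standalone stand-in written for illustration, not the MXNet source.

```python
def get_batch_axis(layout):
    """Hypothetical stand-in for DataDesc.get_batch_axis.

    Assumption: a layout of None means "batch axis is 0"; otherwise the
    batch axis is the index of 'N' in the layout string.
    """
    if layout is None:
        return 0
    return layout.find('N')


# No layout attribute set on the symbol: axis 0 is assumed.
axis_without_layout = get_batch_axis(None)

# Batch-major layout: 'N' is the first character.
axis_batch_major = get_batch_axis('NCHW')

# Time-major layout such as 'TNC': the batch axis is 1.
axis_time_major = get_batch_axis('TNC')
```

Under this assumption, an output whose batch dimension is not the first one only merges correctly if its layout is known.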

What is the recommended way to set the attr('__layout__') of a symbol? Simply pass attr={'__layout__': layout} when constructing the symbol? Can the attribute be set automatically during module binding?

Setting attr('__layout__') is necessary because it is otherwise None. In that case _merge_multi_context concatenates along dim=0, which fails when the batch size is not divisible by the number of devices and the symbol outputs a shape of (1, batch_size_per_device, X).

For example, with 3 devices and a batch size of 128, concatenating outputs of shape (1, 43), (1, 43), (1, 42) along dim=0 fails, because the sizes along dim=1 differ.
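The failure can be reproduced at the level of shapes alone. The sketch below uses plain Python and two hypothetical helpers, split_batch and concat_shape, which are illustrative and not part of MXNet. It assumes the usual concatenation rule that all dimensions other than the concatenation axis must match.

```python
def split_batch(batch_size, num_devices):
    """Split a batch as evenly as possible across devices.

    Illustrative only; not MXNet's actual slicing code.
    """
    base, extra = divmod(batch_size, num_devices)
    return [base + 1 if i < extra else base for i in range(num_devices)]


def concat_shape(shapes, axis):
    """Return the shape produced by concatenating arrays along `axis`.

    Raises ValueError if any dimension other than `axis` differs,
    mirroring the usual rule for array concatenation.
    """
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same number of dimensions")
        for dim, (a, b) in enumerate(zip(first, shape)):
            if dim != axis and a != b:
                raise ValueError(
                    "dimension %d differs (%d vs %d) while concatenating "
                    "along axis %d" % (dim, a, b, axis)
                )
    merged = list(first)
    merged[axis] = sum(shape[axis] for shape in shapes)
    return tuple(merged)


# 3 devices, batch size 128: the per-device batch sizes are uneven.
sizes = split_batch(128, 3)
shapes = [(1, n) for n in sizes]

# Concatenating along the real batch axis (dim=1) works.
merged = concat_shape(shapes, axis=1)

# Concatenating along dim=0, the fallback when no layout is set, does not.
try:
    concat_shape(shapes, axis=0)
    failed_on_axis0 = False
except ValueError:
    failed_on_axis0 = True
```

With an evenly divisible batch size the dim=0 concatenation would not raise, but it would yield a shape like (3, batch_size_per_device) rather than a merged batch, so the layout matters in that case too.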
