
[fx] Add unit test and fix bugs for transform_mlp_pass#1299

Merged
FrankLeeeee merged 3 commits into hpcaitech:main from Itok2000u:feature/transform_mlp_pass
Jul 15, 2022

Conversation

@Itok2000u
Contributor

Still needs to handle special cases.

Comment thread colossalai/fx/passes/shard_1d_pass.py Outdated
from torch.fx.passes.split_module import split_module

import colossalai
# from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
Contributor


Remove this line.

Comment thread colossalai/fx/passes/shard_1d_pass.py Outdated
Comment on lines +24 to +37
#TODO: This func temporarily works with no materialization
# Append a Tensor spec to target_module.weight.shard
# Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
assert isinstance(weight, torch.Tensor), \
    f'The type of the input tensor should be torch.Tensor. ' \
    f'Your input tensor is {type(weight)}'
# assert isinstance(weight, torch.nn.parameter.Parameter), \
#     f'The type of the input tensor should be torch.nn.parameter.Parameter. ' \
#     f'Your input tensor is {type(weight)}'

# FIXME() A ProcessGroup is initialized for this tensor with only a TP comm group;
# we only consider the TP-only case.
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
# world_size = torch.distributed.get_world_size()
# pg = ProcessGroup(tp_degree=world_size)

spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# As you have constructed a Spec, why not directly convert the tensor to ColoTensor?
setattr(weight, "fx_attr", spec)
# spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
# # As you have constructed a Spec, why not directly convert the tensor to ColoTensor?
Contributor


Remove unnecessary comments and code.
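For context, the annotation pattern this thread is reviewing can be sketched without ColossalAI: a shard spec object is attached to the weight as an `fx_attr` attribute, and a later pass reads it back to decide how to materialize the sharded tensor. The `ShardSpec`, `FakeWeight`, and `annotate_weight` names below are illustrative stand-ins, not the actual colossalai API.

```python
from dataclasses import dataclass

# Stand-in for colossalai's ColoTensorSpec; names here are illustrative only.
@dataclass
class ShardSpec:
    dims: list            # dimensions along which the tensor is sharded
    num_partitions: list  # partition count per sharded dimension

class FakeWeight:
    """Minimal stand-in for a torch parameter so the sketch is self-contained."""
    pass

def annotate_weight(weight, dim, tp_world_size):
    """Attach a shard spec to the weight; a later pass reads `fx_attr`
    to decide how to shard the tensor (mirrors the pass under review)."""
    spec = ShardSpec(dims=[dim], num_partitions=[tp_world_size])
    setattr(weight, "fx_attr", spec)
    return weight

w = annotate_weight(FakeWeight(), dim=0, tp_world_size=4)
print(w.fx_attr.dims, w.fx_attr.num_partitions)  # prints: [0] [4]
```

The annotation-only approach keeps the pass free of materialization, which matches the `#TODO` note in the diff about working "with no materialization".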

Comment thread colossalai/fx/passes/shard_1d_pass.py Outdated
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_normal"))
else:
setattr(weight, "fx_attr", (dim, "SHARD", "TP", "col_needs_many_outputs"))
# weight.data = ColoTensor(data=weight.data, spec=spec)
Contributor


Remove this line.
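The reviewer's earlier question, why not convert the tensor directly once a spec exists, would look roughly like this sketch. `ColoLikeTensor` is a hypothetical stand-in for a distributed tensor wrapper, not the real `ColoTensor.from_torch_tensor` API.

```python
class ColoLikeTensor:
    """Hypothetical stand-in for a distributed tensor wrapper that carries
    its shard spec with it, rather than stashing the spec in an attribute."""

    def __init__(self, data, spec):
        self.data = data
        self.spec = spec

    @classmethod
    def from_plain(cls, data, spec):
        # Mirrors the spirit of the commented-out
        # ColoTensor.from_torch_tensor(tensor, spec) call in the diff.
        return cls(data, spec)

# Tuple form used by the pass: (dim, shard kind, parallel mode, column mode)
spec = (0, "SHARD", "TP", "col_normal")
t = ColoLikeTensor.from_plain([1.0, 2.0], spec)
print(t.spec[3])  # prints: col_normal
```

Direct conversion would make the spec travel with the tensor itself, instead of relying on a downstream pass to find and interpret `fx_attr`.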

@FrankLeeeee FrankLeeeee merged commit ca2d3f2 into hpcaitech:main Jul 15, 2022
