Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/kit/model_zoo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Callable

__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']


@dataclass
Expand Down Expand Up @@ -37,7 +37,7 @@ def register(self,
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnresnet18_data_gen() # do not input any argument
>>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
# the index of bn node in computation graph
node_index = 1
# the total number of bn strategies without sync bn mode
# TODO: add sync bn stategies after related passes ready
# TODO: add sync bn strategies after related passes ready
strategy_number = 4
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node)

# build handler
otuput_handler = OutputHandler(node=output_node,
output_handler = OutputHandler(node=output_node,
device_mesh=device_mesh,
strategies_vector=output_strategies_vector,
output_option=output_option)

otuput_handler.register_strategy(compute_resharding_cost=False)
output_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = otuput_handler.get_operation_data_mapping()
mapping = output_handler.get_operation_data_mapping()

for name, op_data in mapping.items():
op_data: OperationData
Expand All @@ -59,7 +59,7 @@ def test_output_handler(output_option):

assert mapping['output'].name == "output"
assert mapping['output'].type == OperationDataType.OUTPUT
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
strategy_name_list = [val.name for val in output_handler.strategies_vector]
if output_option == 'distributed':
assert "Distributed Output" in strategy_name_list
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor/test_dtensor/test_layout_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1

# checkout chached_spec_pairs_transform_path
# checkout cached_spec_pairs_transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence

Expand Down