diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 3f65e48ac2c9..12562095c153 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): - if 'diffusers' in name: + if any(element in name for element in [ + 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', + 'torchvision_inception_v3' + ]): continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -58,12 +61,6 @@ def run_dist(rank, world_size, port): check_torch_fsdp_plugin() -# FIXME: this test is not working - - -@pytest.mark.skip( - "ValueError: expected to be in states [, ] but current state is TrainingState_.IDLE" -) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_plugin():