From f7b6317ec884785daa9f4feb2a4bc55b4aab3322 Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Mon, 20 Mar 2023 15:11:17 +0800
Subject: [PATCH 1/4] [test] fixed torchrec registration in model zoo

---
 tests/kit/model_zoo/torchrec/torchrec.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py
index 03d95a06a89b..169f37bf492d 100644
--- a/tests/kit/model_zoo/torchrec/torchrec.py
+++ b/tests/kit/model_zoo/torchrec/torchrec.py
@@ -42,7 +42,22 @@ def get_ebc():
     # EmbeddingBagCollection
     eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
     eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-    return EmbeddingBagCollection(tables=[eb1_config, eb2_config])
+    return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu'))
+
+
+def sparse_arch_model_fn():
+    ebc = get_ebc()
+    return deepfm.SparseArch(ebc)
+
+
+def simple_deep_fmnn_model_fn():
+    ebc = get_ebc()
+    return deepfm.SimpleDeepFMNN(SHAPE, ebc, SHAPE, SHAPE)
+
+
+def dlrm_model_fn():
+    ebc = get_ebc()
+    return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
 
 
 model_zoo.register(name='deepfm_densearch',
@@ -61,17 +76,17 @@ def get_ebc():
                    output_transform_fn=output_transform_fn)
 
 model_zoo.register(name='deepfm_simpledeepfmnn',
-                   model_fn=partial(deepfm.SimpleDeepFMNN, SHAPE, get_ebc(), SHAPE, SHAPE),
+                   model_fn=simple_deep_fmnn_model_fn,
                    data_gen_fn=simple_dfm_data_gen_fn,
                    output_transform_fn=output_transform_fn)
 
 model_zoo.register(name='deepfm_sparsearch',
-                   model_fn=partial(deepfm.SparseArch, get_ebc()),
+                   model_fn=sparse_arch_model_fn,
                    data_gen_fn=sparse_arch_data_gen_fn,
                    output_transform_fn=output_transform_fn)
 
 model_zoo.register(name='dlrm',
-                   model_fn=partial(dlrm.DLRM, get_ebc(), SHAPE, [SHAPE, SHAPE], [5, 1]),
+                   model_fn=dlrm_model_fn,
                    data_gen_fn=simple_dfm_data_gen_fn,
                    output_transform_fn=output_transform_fn)

From 4db8784387e8c3a379b45234741f6ca7f557a56e Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Mon, 20 Mar 2023 15:14:18 +0800
Subject: [PATCH 2/4] polish code

---
 .../test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py | 1 -
 tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py | 1 -
 2 files changed, 2 deletions(-)

diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
index a30139f26d29..a4e847dbcfcd 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
@@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
     ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
 
 
-@pytest.mark.skip('unknown error')
 def test_torchrec_deepfm_models():
     deepfm_models = model_zoo.get_sub_registry('deepfm')
     torch.backends.cudnn.deterministic = True
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
index 27a88291397e..810be41d0ccc 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
@@ -47,7 +47,6 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
     ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
 
 
-@pytest.mark.skip('unknown error')
 def test_torchrec_dlrm_models():
     torch.backends.cudnn.deterministic = True
     dlrm_models = model_zoo.get_sub_registry('dlrm')
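Note on the fix in PATCH 1/4: partial(deepfm.SparseArch, get_ebc()) evaluates get_ebc() at import time, so a single EmbeddingBagCollection was built eagerly when the model zoo module loaded and then shared by every test that requested the model. Wrapping construction in a plain function defers the work until a test actually calls model_fn(), and passing device=torch.device('cpu') makes the starting device explicit. A minimal sketch of the difference, with a hypothetical build_module standing in for get_ebc (not the model zoo's real API):

    from functools import partial

    def build_module():
        print("building")    # side effect happens at construction time
        return object()      # stand-in for an EmbeddingBagCollection

    # Eager: build_module() runs right now, once, while this line executes,
    # and every call to eager_model_fn() returns that same shared object.
    eager_model_fn = partial(lambda module: module, build_module())

    # Lazy: nothing is built until a test calls lazy_model_fn(), and each
    # call constructs a fresh module.
    def lazy_model_fn():
        return build_module()

This is presumably also what allows PATCH 2/4 to drop the @pytest.mark.skip('unknown error') markers: each test now traces a freshly constructed model instead of one stale shared instance.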
From 9281356f936e15289ec34e04fd28ec2ec1088ef3 Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Mon, 20 Mar 2023 15:48:33 +0800
Subject: [PATCH 3/4] polish code

---
 tests/kit/model_zoo/torchrec/torchrec.py | 49 ++++++++++++++---
 tests/test_temp/test_dlrm_model.py       | 70 ++++++++++++++++++++++++
 tests/test_temp/test_fp16_torch.py       | 29 ++++++++++
 3 files changed, 139 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_temp/test_dlrm_model.py
 create mode 100644 tests/test_temp/test_fp16_torch.py

diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py
index 169f37bf492d..dda563155fca 100644
--- a/tests/kit/model_zoo/torchrec/torchrec.py
+++ b/tests/kit/model_zoo/torchrec/torchrec.py
@@ -11,21 +11,47 @@
 BATCH = 2
 SHAPE = 10
 
-# KeyedTensor
-KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+
+
+def gen_kt():
+    KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+    return KT
+
 
 # KeyedJaggedTensor
-KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
-                                          values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
-                                          offsets=torch.tensor([0, 2, 4, 6, 8]))
+def gen_kjt():
+    KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
+                                              values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
+                                              offsets=torch.tensor([0, 2, 4, 6, 8]))
+    return KJT
+
 
 data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
 
-interaction_arch_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
 
-simple_dfm_data_gen_fn = lambda: dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
+def interaction_arch_data_gen_fn():
+    KT = gen_kt()
+    return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
+
 
-sparse_arch_data_gen_fn = lambda: dict(features=KJT)
+def simple_dfm_data_gen_fn():
+    KJT = gen_kjt()
+    return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
+
+
+def sparse_arch_data_gen_fn():
+    KJT = gen_kjt()
+    return dict(features=KJT)
+
+
+def output_transform_fn(x):
+    if isinstance(x, KeyedTensor):
+        output = dict()
+        for key in x.keys():
+            output[key] = x[key]
+        return output
+    else:
+        return dict(output=x)
 
 
 def output_transform_fn(x):
@@ -60,6 +86,11 @@ def dlrm_model_fn():
     return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
 
 
+def dlrm_sparsearch_model_fn():
+    ebc = get_ebc()
+    return dlrm.SparseArch(ebc)
+
+
 model_zoo.register(name='deepfm_densearch',
                    model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
                    data_gen_fn=data_gen_fn,
@@ -106,6 +137,6 @@ def dlrm_model_fn():
                    output_transform_fn=output_transform_fn)
 
 model_zoo.register(name='dlrm_sparsearch',
-                   model_fn=partial(dlrm.SparseArch, get_ebc()),
+                   model_fn=dlrm_sparsearch_model_fn,
                    data_gen_fn=sparse_arch_data_gen_fn,
                    output_transform_fn=output_transform_fn)
diff --git a/tests/test_temp/test_dlrm_model.py b/tests/test_temp/test_dlrm_model.py
new file mode 100644
index 000000000000..810be41d0ccc
--- /dev/null
+++ b/tests/test_temp/test_dlrm_model.py
@@ -0,0 +1,70 @@
+import pytest
+import torch
+
+from colossalai.fx import symbolic_trace
+from tests.kit.model_zoo import model_zoo
+
+BATCH = 2
+SHAPE = 10
+
+
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+    # trace
+    model = model_cls()
+
+    # convert to eval mode for inference
+    # it is important to set it to eval mode before tracing
+    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+    model.eval()
+
+    gm = symbolic_trace(model, meta_args=meta_args)
+    gm.eval()
+    # run forward
+    with torch.no_grad():
+        fx_out = gm(**data)
+        non_fx_out = model(**data)
+
+    # compare output
+    transformed_fx_out = output_transform_fn(fx_out)
+    transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+    assert len(transformed_fx_out) == len(transformed_non_fx_out)
+    if torch.is_tensor(fx_out):
+        assert torch.allclose(
+            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    else:
+        assert torch.allclose(
+            fx_out.values(),
+            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+    for key in transformed_fx_out.keys():
+        fx_output_val = transformed_fx_out[key]
+        non_fx_output_val = transformed_non_fx_out[key]
+        if torch.is_tensor(fx_output_val):
+            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+        else:
+            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+                                 ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+def test_torchrec_dlrm_models():
+    torch.backends.cudnn.deterministic = True
+    dlrm_models = model_zoo.get_sub_registry('dlrm')
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+        data = data_gen_fn()
+
+        # dlrm_interactionarch is not supported
+        if name == 'dlrm_interactionarch':
+            continue
+
+        if attribute is not None and attribute.has_control_flow:
+            meta_args = {k: v.to('meta') for k, v in data.items()}
+        else:
+            meta_args = None
+
+        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
+
+
+if __name__ == "__main__":
+    test_torchrec_dlrm_models()
diff --git a/tests/test_temp/test_fp16_torch.py b/tests/test_temp/test_fp16_torch.py
new file mode 100644
index 000000000000..98d00cd2caca
--- /dev/null
+++ b/tests/test_temp/test_fp16_torch.py
@@ -0,0 +1,29 @@
+import torch
+from torch.optim import Adam
+
+from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
+from tests.kit.model_zoo import model_zoo
+
+
+def test_torch_amp():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+        # dlrm_interactionarch has no parameters, so skip it
+        if name == 'dlrm_interactionarch':
+            continue
+
+        model = model_fn().cuda()
+        optimizer = Adam(model.parameters(), lr=1e-3)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+        mixed_precision = FP16TorchMixedPrecision()
+        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+        optimizer.backward(loss)
+        optimizer.clip_grad_by_norm(1.0)
+        optimizer.step()
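Note on the data generators reworked in PATCH 3/4: the module-level KT/KJT tensors were created once at import and shared by every test, so gen_kt() and gen_kjt() rebuild them per call, presumably for the same reason the model_fn wrappers exist. It is also worth spelling out what gen_kjt constructs, since a KeyedJaggedTensor packs everything into two flat tensors: offsets [0, 2, 4, 6, 8] cut values [1, ..., 8] into four bags of two, and the layout is key-major, so with BATCH = 2 the first two bags belong to "f1" and the last two to "f2". A small sketch of that slicing in plain torch (the decomposition done by hand, not torchrec's own API):

    import torch

    values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
    offsets = torch.tensor([0, 2, 4, 6, 8])

    # offsets define len(offsets) - 1 = 4 jagged bags
    offs = offsets.tolist()
    bags = [values[offs[i]:offs[i + 1]] for i in range(len(offs) - 1)]

    # key-major layout: first BATCH bags are "f1", next BATCH bags are "f2"
    f1, f2 = bags[:2], bags[2:]    # f1 -> [1, 2], [3, 4]; f2 -> [5, 6], [7, 8]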
From 5a34a82fa5f08e99b6d621fbbe1cf40962800440 Mon Sep 17 00:00:00 2001
From: FrankLeeeee
Date: Mon, 20 Mar 2023 15:50:41 +0800
Subject: [PATCH 4/4] polish code

---
 tests/test_temp/test_dlrm_model.py | 70 ------------------------------
 tests/test_temp/test_fp16_torch.py | 29 -------------
 2 files changed, 99 deletions(-)
 delete mode 100644 tests/test_temp/test_dlrm_model.py
 delete mode 100644 tests/test_temp/test_fp16_torch.py

diff --git a/tests/test_temp/test_dlrm_model.py b/tests/test_temp/test_dlrm_model.py
deleted file mode 100644
index 810be41d0ccc..000000000000
--- a/tests/test_temp/test_dlrm_model.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import pytest
-import torch
-
-from colossalai.fx import symbolic_trace
-from tests.kit.model_zoo import model_zoo
-
-BATCH = 2
-SHAPE = 10
-
-
-def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
-    # trace
-    model = model_cls()
-
-    # convert to eval mode for inference
-    # it is important to set it to eval mode before tracing
-    # without this statement, the torch.nn.functional.batch_norm will always be in training mode
-    model.eval()
-
-    gm = symbolic_trace(model, meta_args=meta_args)
-    gm.eval()
-    # run forward
-    with torch.no_grad():
-        fx_out = gm(**data)
-        non_fx_out = model(**data)
-
-    # compare output
-    transformed_fx_out = output_transform_fn(fx_out)
-    transformed_non_fx_out = output_transform_fn(non_fx_out)
-
-    assert len(transformed_fx_out) == len(transformed_non_fx_out)
-    if torch.is_tensor(fx_out):
-        assert torch.allclose(
-            fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-    else:
-        assert torch.allclose(
-            fx_out.values(),
-            non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-    for key in transformed_fx_out.keys():
-        fx_output_val = transformed_fx_out[key]
-        non_fx_output_val = transformed_non_fx_out[key]
-        if torch.is_tensor(fx_output_val):
-            assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
-                f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
-        else:
-            assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
-                                 ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-
-
-def test_torchrec_dlrm_models():
-    torch.backends.cudnn.deterministic = True
-    dlrm_models = model_zoo.get_sub_registry('dlrm')
-
-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
-        data = data_gen_fn()
-
-        # dlrm_interactionarch is not supported
-        if name == 'dlrm_interactionarch':
-            continue
-
-        if attribute is not None and attribute.has_control_flow:
-            meta_args = {k: v.to('meta') for k, v in data.items()}
-        else:
-            meta_args = None
-
-        trace_and_compare(model_fn, data, output_transform_fn, meta_args)
-
-
-if __name__ == "__main__":
-    test_torchrec_dlrm_models()
diff --git a/tests/test_temp/test_fp16_torch.py b/tests/test_temp/test_fp16_torch.py
deleted file mode 100644
index 98d00cd2caca..000000000000
--- a/tests/test_temp/test_fp16_torch.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import torch
-from torch.optim import Adam
-
-from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
-from tests.kit.model_zoo import model_zoo
-
-
-def test_torch_amp():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
-        # dlrm_interactionarch has no parameters, so skip it
-        if name == 'dlrm_interactionarch':
-            continue
-
-        model = model_fn().cuda()
-        optimizer = Adam(model.parameters(), lr=1e-3)
-        criterion = lambda x: x.mean()
-        data = data_gen_fn()
-        data = {
-            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
-        }
-        mixed_precision = FP16TorchMixedPrecision()
-        model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
-        output = model(**data)
-        output = output_transform_fn(output)
-        output_key = list(output.keys())[0]
-        loss = criterion(output[output_key])
-        optimizer.backward(loss)
-        optimizer.clip_grad_by_norm(1.0)
-        optimizer.step()
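A closing note on the meta_args plumbing in the deleted test_dlrm_model.py (the same pattern remains in tests/test_fx/test_tracer/test_torchrec_model): for models flagged with has_control_flow, the inputs are moved to the 'meta' device before tracing. Meta tensors carry shape and dtype but no storage, which lets shape-dependent code run without touching real data. A minimal sketch of that property in plain torch, independent of colossalai.fx:

    import torch

    x = torch.rand(2, 10).to('meta')   # same move the test does: v.to('meta')
    print(x.shape, x.dtype)            # torch.Size([2, 10]) torch.float32
    y = x + 1                          # ops propagate shapes, result stays on meta
    # x.tolist() would raise here: meta tensors have no underlying data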