From 45bf0b4e6aac82e23a427b4eb9292c8d8c6455a6 Mon Sep 17 00:00:00 2001 From: Mannat Singh Date: Fri, 17 Apr 2020 20:14:34 -0700 Subject: [PATCH] Remove redundant passing of head.unique_id (#484) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/484 When setting heads, we do the following currently - ``` model.set_heads({"block3-2": {head.unique_id: head}}) ``` This is a redundant structure, since the unique_id is already available in `head`. I was writing a tutorial and I want to show a cleaner API there - ``` model.set_heads({"block3-2": [head]}) ``` Differential Revision: D21096621 fbshipit-source-id: 4dfc87e172390c43e9cd7fce98ad987099542ca2 --- classy_vision/models/__init__.py | 4 ++-- classy_vision/models/classy_model.py | 19 ++++++++++++------- classy_vision/tasks/fine_tuning_task.py | 4 ++-- .../manual/models_classy_vision_model_test.py | 4 ++-- test/models_classy_block_test.py | 17 +++++++---------- test/models_classy_model_test.py | 2 +- tutorials/fine_tuning.ipynb | 4 ++-- 7 files changed, 28 insertions(+), 26 deletions(-) diff --git a/classy_vision/models/__init__.py b/classy_vision/models/__init__.py index 97c841fded..ecaa728cc7 100644 --- a/classy_vision/models/__init__.py +++ b/classy_vision/models/__init__.py @@ -69,7 +69,7 @@ def build_model(config): assert config["name"] in MODEL_REGISTRY, "unknown model" model = MODEL_REGISTRY[config["name"]].from_config(config) if "heads" in config: - heads = defaultdict(dict) + heads = defaultdict(list) for head_config in config["heads"]: assert "fork_block" in head_config, "Expect fork_block in config" fork_block = head_config["fork_block"] @@ -77,7 +77,7 @@ def build_model(config): del updated_config["fork_block"] head = build_head(updated_config) - heads[fork_block][head.unique_id] = head + heads[fork_block].append(head) model.set_heads(heads) return model diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index de925cfd59..1cc65839d1 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -234,7 +234,7 @@ def get_classy_state(self, deep_copy=False): head_state_dict = {} for block, heads in attached_heads.items(): head_state_dict[block] = { - head_name: head.state_dict() for head_name, head in heads.items() + head.unique_id: head.state_dict() for head in heads } model_state_dict = { "model": {"trunk": trunk_state_dict, "heads": head_state_dict} @@ -342,7 +342,7 @@ def _make_module_attachable(self, module, module_name): found = found or found_in_child return found - def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]): + def set_heads(self, heads: Dict[str, List[ClassyHead]]): """Attach all the heads to corresponding blocks. A head is expected to be a ClassyHead object. For more @@ -350,7 +350,7 @@ def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]): Args: heads (Dict): a mapping between attachable block name - and a dictionary of heads attached to that block. For + and a list of heads attached to that block. For example, if you have two different teams that want to attach two different heads for downstream classifiers to the 15th block, then they would use: @@ -358,7 +358,7 @@ def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]): .. code-block:: python heads = {"block15": - {"team1": classifier_head1, "team2": classifier_head2} + [classifier_head1, classifier_head2] } """ self.clear_heads() @@ -367,11 +367,13 @@ def set_heads(self, heads: Dict[str, Dict[str, ClassyHead]]): for block_name, block_heads in heads.items(): if not self._make_module_attachable(self, block_name): raise KeyError(f"{block_name} not found in the model") - for head in block_heads.values(): + for head in block_heads: if head.unique_id in head_ids: raise ValueError("head id {} already exists".format(head.unique_id)) head_ids.add(head.unique_id) - self._heads[block_name] = nn.ModuleDict(block_heads) + self._heads[block_name] = nn.ModuleDict( + {head.unique_id: head for head in block_heads} + ) def get_heads(self): """Returns the heads on the model @@ -381,7 +383,10 @@ def get_heads(self): attached to that block. """ - return {block_name: dict(heads) for block_name, heads in self._heads.items()} + return { + block_name: list(heads.values()) + for block_name, heads in self._heads.items() + } @property def head_outputs(self): diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py index 5da2b382ce..3cac44d50d 100644 --- a/classy_vision/tasks/fine_tuning_task.py +++ b/classy_vision/tasks/fine_tuning_task.py @@ -61,7 +61,7 @@ def _set_model_train_mode(self): # convert all the sub-modules to the eval mode, except the heads self.base_model.eval() for heads in self.base_model.get_heads().values(): - for h in heads.values(): + for h in heads: h.train(phase["train"]) else: self.base_model.train(phase["train"]) @@ -91,7 +91,7 @@ def prepare( for param in self.base_model.parameters(): param.requires_grad = False for heads in self.base_model.get_heads().values(): - for h in heads.values(): + for h in heads: for param in h.parameters(): param.requires_grad = True # re-create ddp model diff --git a/test/manual/models_classy_vision_model_test.py b/test/manual/models_classy_vision_model_test.py index effbf3dc9d..ea89912607 100644 --- a/test/manual/models_classy_vision_model_test.py +++ b/test/manual/models_classy_vision_model_test.py @@ -94,10 +94,10 @@ def test_get_set_head_states(self): model = build_model(config["model"]) trunk_state = model.get_classy_state() - heads = defaultdict(dict) + heads = defaultdict(list) for head_config in head_configs: head = build_head(head_config) - heads[head_config["fork_block"]][head.unique_id] = head + heads[head_config["fork_block"]].append(head) model.set_heads(heads) model_state = model.get_classy_state() diff --git a/test/models_classy_block_test.py b/test/models_classy_block_test.py index 7040def911..a187466f25 100644 --- a/test/models_classy_block_test.py +++ b/test/models_classy_block_test.py @@ -55,7 +55,7 @@ def test_head_execution(self): self.DummyTestModel.wrapper_cls = wrapper_class model = self.DummyTestModel() head = self.DummyTestHead() - model.set_heads({"dummy_block2": {head.unique_id: head}}) + model.set_heads({"dummy_block2": [head]}) input = torch.randn(1, 2) output = model(input) head_output = model.execute_heads() @@ -66,7 +66,7 @@ def test_head_execution(self): self.DummyTestModel.wrapper_cls = ClassyModelHeadExecutorWrapper model = self.DummyTestModel() head = self.DummyTestHead() - model.set_heads({"dummy_block2": {head.unique_id: head}}) + model.set_heads({"dummy_block2": [head]}) input = torch.randn(1, 2) output = model(input) head_output = model.execute_heads() @@ -79,10 +79,7 @@ def test_duplicated_head_ids(self): model = self.DummyTestModel() head1 = self.DummyTestHead() head2 = self.DummyTestHead() - heads = { - "dummy_block": {head1.unique_id: head1}, - "dummy_block2": {head2.unique_id: head2}, - } + heads = {"dummy_block": [head1], "dummy_block2": [head2]} with self.assertRaises(ValueError): model.set_heads(heads) @@ -92,13 +89,13 @@ def test_duplicated_head_ids(self): def test_duplicated_block_names(self): model = self.DummyTestModelDuplicatedBlockNames() head = self.DummyTestHead() - heads = {"dummy_block2": {head.unique_id: head}} + heads = {"dummy_block2": [head]} with self.assertRaises(Exception): # there are two modules with the name "dummy_block2" # which is not supported model.set_heads(heads) # can still attach to a module with a unique id - heads = {"features": {head.unique_id: head}} + heads = {"features": [head]} model.set_heads(heads) def test_set_heads(self): @@ -107,7 +104,7 @@ def test_set_heads(self): self.assertEqual( len(model.get_heads()), 0, "heads should be empty before set_heads" ) - model.set_heads({"dummy_block2": {head.unique_id: head}}) + model.set_heads({"dummy_block2": [head]}) input = torch.randn(1, 2) model(input) head_outputs = model.execute_heads() @@ -119,4 +116,4 @@ def test_set_heads(self): # try a non-existing module with self.assertRaises(Exception): - model.set_heads({"unknown_block": {head.unique_id: head}}) + model.set_heads({"unknown_block": [head]}) diff --git a/test/models_classy_model_test.py b/test/models_classy_model_test.py index 40f0d79541..3f305dbcdb 100644 --- a/test/models_classy_model_test.py +++ b/test/models_classy_model_test.py @@ -250,6 +250,6 @@ def test_heads(self): head = FullyConnectedHead( unique_id="default", in_plane=2048, num_classes=num_classes ) - classy_model.set_heads({"layer4": {head.unique_id: head}}) + classy_model.set_heads({"layer4": [head]}) input = torch.ones((1, 3, 224, 224)) self.assertEqual(classy_model(input).shape, (1, num_classes)) diff --git a/tutorials/fine_tuning.ipynb b/tutorials/fine_tuning.ipynb index 6a2f85e0b7..0ab3a450a9 100644 --- a/tutorials/fine_tuning.ipynb +++ b/tutorials/fine_tuning.ipynb @@ -149,7 +149,7 @@ }, "outputs": [], "source": [ - "model.set_heads({\"block3-2\": {head.unique_id: head}})" + "model.set_heads({\"block3-2\": [head]})" ] }, { @@ -443,7 +443,7 @@ }, "outputs": [], "source": [ - "model.set_heads({\"block3-2\": {head.unique_id: head}})" + "model.set_heads({\"block3-2\": [head]})" ] }, {