From aefea27619b43cbb3e518ea972009e525676dc8d Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Sun, 23 Oct 2022 21:19:44 +0000
Subject: [PATCH 1/2] v1

---
 src/transformers/modeling_utils.py            |  37 ++++
 src/transformers/models/lilt/modeling_lilt.py |   1 +
 .../models/roberta/modeling_roberta.py        |   1 +
 .../xlm_roberta/modeling_xlm_roberta.py       |   1 +
 tests/models/lilt/test_modeling_lilt.py       | 180 ++++++++++++++++-
 tests/models/roberta/test_modeling_roberta.py | 181 +++++++++++++++++-
 tests/models/xlm/test_modeling_xlm.py         | 180 ++++++++++++++++-
 7 files changed, 578 insertions(+), 3 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 58ef08bde419..802c09be3d13 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -532,6 +532,38 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
         setattr(submodule, param_name, new_val)
 
 
+def _load_unloaded_params(
+    model,
+    dtype,
+    load_in_8bit=False,
+):
+    r"""
+    This function loads params that were left un-loaded in some corner cases. For some models, certain keys (such as
+    `lm_head.xx.weight`) are in both `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save`. This leads to a
+    bug when loading a saved model with `device_map="auto"`, since these weights are not in the `state_dict` and are
+    therefore never loaded by `set_module_tensor_to_device`. This function loads these missing params by randomly
+    initializing them and putting them on CPU.
+
+    Args:
+        `model`, (`torch.nn.Module`):
+            Input model.
+        `dtype`, (`torch.dtype`):
+            Torch dtype of the input model.
+        `load_in_8bit`, (`bool`, *optional*):
+            Whether the model has to be loaded in 8-bit or not.
+    """
+    if load_in_8bit:
+        from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+
+    for param_name, param in model.named_parameters():
+        if param.device.type == "meta":
+            if not load_in_8bit:
+                set_module_tensor_to_device(model, param_name, "cpu", torch.empty(*param.size(), dtype=dtype))
+            else:
+                set_module_8bit_tensor_to_device(model, param_name, "cpu", torch.empty(*param.size(), dtype=dtype))
+    return model
+
+
 def _load_state_dict_into_meta_model(
     model,
     state_dict,
@@ -2560,6 +2592,11 @@ def _find_mismatched_keys(
                         load_in_8bit=load_in_8bit,
                     )
                     error_msgs += new_error_msgs
+
+                    # At this point all the modules should have a device set, except for the modules that are
+                    # in `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save` - let's randomly initialize them!
+                    if hasattr(model, "_keys_to_ignore_on_load_missing") and hasattr(model, "_keys_to_ignore_on_save"):
+                        model = _load_unloaded_params(model, dtype, load_in_8bit)
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
 
diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py
index c78490f4b439..664192ef7ff8 100644
--- a/src/transformers/models/lilt/modeling_lilt.py
+++ b/src/transformers/models/lilt/modeling_lilt.py
@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
     config_class = LiltConfig
     base_model_prefix = "lilt"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LiltLayer", "LiltLayoutEmbeddings"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index d27fc5fa3425..cf543afc85b8 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["RobertaLayer"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index 9e5bdfe4fcdf..1e7d40e6b41a 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     config_class = XLMRobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["XLMRobertaLayer"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
diff --git a/tests/models/lilt/test_modeling_lilt.py b/tests/models/lilt/test_modeling_lilt.py
index 718d2bd287fb..96f05a6c2138 100644
--- a/tests/models/lilt/test_modeling_lilt.py
+++ b/tests/models/lilt/test_modeling_lilt.py
@@ -14,10 +14,19 @@
 # limitations under the License.
+import tempfile import unittest from transformers import LiltConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import ( + is_accelerate_available, + require_accelerate, + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -35,6 +44,9 @@ ) from transformers.models.lilt.modeling_lilt import LILT_PRETRAINED_MODEL_ARCHIVE_LIST +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + class LiltModelTester: def __init__( @@ -264,6 +276,172 @@ def test_model_from_pretrained(self): model = LiltModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_accelerate + @require_torch_multi_gpu + def test_model_parallelism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + torch.manual_seed(0) + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. 
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_disk_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + torch.manual_seed(0) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + max_size = int(self.model_split_percents[0] * model_size) + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + max_memory = {0: max_size, "cpu": max_size} + with self.assertRaises(ValueError): + # This errors out cause it's missing an offload folder + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + new_model = model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + ) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_cpu_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and 
hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + torch.manual_seed(0) + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + @require_torch @slow diff --git a/tests/models/roberta/test_modeling_roberta.py b/tests/models/roberta/test_modeling_roberta.py index 7163a357021e..aa70eaa51bd9 100644 --- a/tests/models/roberta/test_modeling_roberta.py +++ b/tests/models/roberta/test_modeling_roberta.py @@ -14,11 +14,21 @@ # limitations under the License. 
+import tempfile import unittest from copy import deepcopy from transformers import RobertaConfig, is_torch_available -from transformers.testing_utils import TestCasePlus, require_torch, slow, torch_device +from transformers.testing_utils import ( + TestCasePlus, + is_accelerate_available, + require_accelerate, + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -43,6 +53,9 @@ create_position_ids_from_input_ids, ) +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + ROBERTA_TINY = "sshleifer/tiny-distilroberta-base" @@ -483,6 +496,172 @@ def test_create_position_ids_from_inputs_embeds(self): self.assertEqual(position_ids.shape, expected_positions.shape) self.assertTrue(torch.all(torch.eq(position_ids, expected_positions))) + @require_accelerate + @require_torch_multi_gpu + def test_model_parallelism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + torch.manual_seed(0) + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. 
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_disk_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + torch.manual_seed(0) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + max_size = int(self.model_split_percents[0] * model_size) + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + max_memory = {0: max_size, "cpu": max_size} + with self.assertRaises(ValueError): + # This errors out cause it's missing an offload folder + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + new_model = model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + ) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_cpu_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and 
hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + torch.manual_seed(0) + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. + max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, "cpu": model_size * 2} + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + @require_torch class RobertaModelIntegrationTest(TestCasePlus): diff --git a/tests/models/xlm/test_modeling_xlm.py b/tests/models/xlm/test_modeling_xlm.py index 8f56ed8472ea..dc32b9697b45 100644 --- a/tests/models/xlm/test_modeling_xlm.py +++ b/tests/models/xlm/test_modeling_xlm.py @@ -13,10 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import tempfile import unittest from transformers import XLMConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import ( + is_accelerate_available, + require_accelerate, + require_torch, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) from ...generation.test_generation_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -37,6 +46,9 @@ ) from transformers.models.xlm.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_LIST +if is_accelerate_available(): + from accelerate.utils import compute_module_sizes + class XLMModelTester: def __init__( @@ -453,6 +465,172 @@ def test_model_from_pretrained(self): model = XLMModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_accelerate + @require_torch_multi_gpu + def test_model_parallelism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + torch.manual_seed(0) + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + # We test several splits of sizes to make sure it works. 
+ max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents] + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + for max_size in max_gpu_sizes: + max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2} + + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + # Making sure part of the model will actually end up offloaded + self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_disk_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config).eval() + model = model.to(torch_device) + torch.manual_seed(0) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and hasattr(model, "_keys_to_ignore_on_save") + ) + if ignore_lm_head: + ignore_lm_head = ( + ignore_lm_head + if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing) + else not ignore_lm_head + ) + inputs_dict_model_class["output_attentions"] = True + + base_output = model(**inputs_dict_model_class) + + model_size = compute_module_sizes(model)[""] + max_size = int(self.model_split_percents[0] * model_size) + with tempfile.TemporaryDirectory() as tmp_dir: + model.cpu().save_pretrained(tmp_dir) + + max_memory = {0: max_size, "cpu": max_size} + with self.assertRaises(ValueError): + # This errors out cause it's missing an offload folder + new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) + + new_model = model_class.from_pretrained( + tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir + ) + + self.check_device_map_is_respected(new_model, new_model.hf_device_map) + torch.manual_seed(0) + new_output = new_model(**inputs_dict_model_class) + + # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check + # the final logits but only the intermediate attentions for all layers + if ignore_lm_head: + for i in range(len(base_output["attentions"])): + self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i])) + else: + self.assertTrue(torch.allclose(base_output[0], new_output[0])) + + @require_accelerate + @require_torch_gpu + def test_cpu_offload(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if model_class._no_split_modules is None: + continue + + inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class) + + model = model_class(config).eval() + model = model.to(torch_device) + + ignore_lm_head = ( + hasattr(model, "lm_head") + and hasattr(model, "_keys_to_ignore_on_load_missing") + and 
hasattr(model, "_keys_to_ignore_on_save")
+            )
+            if ignore_lm_head:
+                ignore_lm_head = (
+                    ignore_lm_head
+                    if any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing)
+                    else not ignore_lm_head
+                )
+                inputs_dict_model_class["output_attentions"] = True
+
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_model_class)
+
+            model_size = compute_module_sizes(model)[""]
+            # We test several splits of sizes to make sure it works.
+            max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                for max_size in max_gpu_sizes:
+                    max_memory = {0: max_size, "cpu": model_size * 2}
+                    new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                    # Making sure part of the model will actually end up offloaded
+                    self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+                    self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+                    torch.manual_seed(0)
+                    new_output = new_model(**inputs_dict_model_class)
+
+                    # Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check
+                    # the final logits but only the intermediate attentions for all layers
+                    if ignore_lm_head:
+                        for i in range(len(base_output["attentions"])):
+                            self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i]))
+                    else:
+                        self.assertTrue(torch.allclose(base_output[0], new_output[0]))
+
 
 @require_torch
 class XLMModelLanguageGenerationTest(unittest.TestCase):

From caf88fe337a5940c746d10e57e1de59d9c618bf7 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Mon, 24 Oct 2022 17:57:03 +0000
Subject: [PATCH 2/2] add more conditions

---
 src/transformers/modeling_utils.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 802c09be3d13..aa86a4e867f3 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -552,11 +552,14 @@ def _load_unloaded_params(
         `load_in_8bit`, (`bool`, *optional*):
             Whether the model has to be loaded in 8-bit or not.
     """
+    if model._keys_to_ignore_on_save is None:
+        return model
+
     if load_in_8bit:
         from .utils.bitsandbytes import set_module_8bit_tensor_to_device
 
     for param_name, param in model.named_parameters():
-        if param.device.type == "meta":
+        if param.device.type == "meta" and param_name in model._keys_to_ignore_on_save:
             if not load_in_8bit:
                 set_module_tensor_to_device(model, param_name, "cpu", torch.empty(*param.size(), dtype=dtype))
             else:
@@ -2593,10 +2596,6 @@ def _find_mismatched_keys(
                     )
                     error_msgs += new_error_msgs
 
-                    # At this point all the modules should have a device set, except for the modules that are
-                    # in `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save` - let's randomly initialize them!
-                    if hasattr(model, "_keys_to_ignore_on_load_missing") and hasattr(model, "_keys_to_ignore_on_save"):
-                        model = _load_unloaded_params(model, dtype, load_in_8bit)
                 else:
                     error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
 
@@ -2604,6 +2603,14 @@ def _find_mismatched_keys(
                 del state_dict
                 gc.collect()
 
+        if low_cpu_mem_usage:
+            # At this point all the modules should have a device set, except for the modules that are
+            # in `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save` - let's randomly initialize them!
+            if hasattr(model_to_load, "_keys_to_ignore_on_load_missing") and hasattr(
+                model_to_load, "_keys_to_ignore_on_save"
+            ):
+                model_to_load = _load_unloaded_params(model_to_load, dtype, load_in_8bit)
+
         if offload_index is not None and len(offload_index) > 0:
             if model != model_to_load:
                 # We need to add the prefix of the base model
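
For context, and not part of the patch itself: a minimal sketch of the corner case the first commit describes, assuming a stock `roberta-base` checkpoint and `accelerate` installed. `RobertaForMaskedLM` lists `lm_head.decoder.weight` (and `lm_head.decoder.bias`) in `_keys_to_ignore_on_save`, so those keys are absent from a locally saved checkpoint; per the PR description, reloading such a checkpoint with `device_map="auto"` would otherwise leave those parameters on the meta device.

    import tempfile

    from transformers import RobertaForMaskedLM

    with tempfile.TemporaryDirectory() as tmp_dir:
        model = RobertaForMaskedLM.from_pretrained("roberta-base")
        # keys in `_keys_to_ignore_on_save` are skipped when the state dict is written
        model.save_pretrained(tmp_dir)

        # reload through the accelerate-backed path this PR touches (needs `_no_split_modules`,
        # which the first commit adds for Roberta-based models)
        reloaded = RobertaForMaskedLM.from_pretrained(tmp_dir, device_map="auto")

        # with `_load_unloaded_params` in place, no parameter should be left on the meta device:
        # the ignored keys are randomly initialized on CPU instead of being skipped entirely
        left_on_meta = [name for name, p in reloaded.named_parameters() if p.device.type == "meta"]
        print(left_on_meta)  # expected: []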
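Likewise, a rough illustration (not from the patch) of what the new `test_model_parallelism`-style tests exercise, assuming this patch is applied (it adds the `_no_split_modules` attribute, without which `device_map="auto"` is not supported for these models) and a machine with two GPUs; the checkpoint name and the memory split are placeholders.

    from accelerate.utils import compute_module_sizes
    from transformers import RobertaModel

    model = RobertaModel.from_pretrained("roberta-base")
    model_size = compute_module_sizes(model)[""]  # total size in bytes of the whole model

    # cap GPU 0 at roughly half the model so that `device_map="auto"` has to place the remaining
    # `RobertaLayer` blocks (kept whole thanks to `_no_split_modules`) on GPU 1
    max_memory = {0: model_size // 2, 1: model_size * 2, "cpu": model_size * 2}
    sharded = RobertaModel.from_pretrained("roberta-base", device_map="auto", max_memory=max_memory)
    print(sharded.hf_device_map)  # module name -> device, e.g. {"embeddings": 0, ..., "encoder.layer.11": 1}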