44 changes: 44 additions & 0 deletions src/transformers/modeling_utils.py
@@ -532,6 +532,41 @@ def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
setattr(submodule, param_name, new_val)


def _load_unloaded_params(
model,
dtype,
load_in_8bit=False,
):
r"""
This function loads un-loaded params that can occur in some corner cases. For some models, certain keys (such as
`lm_head.xx.weight` are in the `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save`. This leads to a bug
when loading a saved model with `device_map=auto` since these weights will not be in the `state_dict`- so not
loaded at all by `set_module_tensor_to_device`. Therefore this function aims to load these missing params, by
randomly initializing them and put them on `CPU`.

Args:
`model`, (`torch.nn.Module`):
Input model.
`dtype`, (`torch.dtype`):
Torch dtype of the input model.
`load_in_8bit`, (`bool`, *optional*):
Whether the model has to be loaded in 8-bit or not.
"""
if model._keys_to_ignore_on_save is None:
return model

if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device

for param_name, param in model.named_parameters():
if param.device.type == "meta" and param_name in model._keys_to_ignore_on_save:
if not load_in_8bit:
set_module_tensor_to_device(model, param_name, "cpu", torch.empty(*param.size(), dtype=dtype))
else:
set_module_8bit_tensor_to_device(model, param_name, "cpu", torch.empty(*param.size(), dtype=dtype))
return model
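
A minimal sketch of the corner case this helper covers, for reviewers. `TinyLM` below is a hypothetical stand-in, not a real transformers model: its `lm_head.weight` is listed in `_keys_to_ignore_on_save`, so after a save and a `device_map="auto"` load it is left on the meta device and must be re-materialized on the CPU, which is what the loop in `_load_unloaded_params` does for the non-8bit path.

# Hypothetical sketch of the bug scenario; TinyLM is illustrative only.
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    # Mirrors a model whose lm_head weight is ignored on save and on missing-load.
    _keys_to_ignore_on_save = ["lm_head.weight"]

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)


model = TinyLM()
# Simulate a load where lm_head.weight was absent from the checkpoint:
# the parameter stays on the meta device and holds no data.
model.lm_head.weight = nn.Parameter(torch.empty(8, 4, device="meta"))

# What _load_unloaded_params does for the non-8bit path:
for name, param in model.named_parameters():
    if param.device.type == "meta" and name in model._keys_to_ignore_on_save:
        module_path, _, weight_name = name.rpartition(".")
        submodule = model.get_submodule(module_path)
        # Materialize the param on CPU (uninitialized memory, like torch.empty above).
        setattr(submodule, weight_name, nn.Parameter(torch.empty(param.size(), dtype=torch.float32)))

assert model.lm_head.weight.device.type == "cpu"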


def _load_state_dict_into_meta_model(
model,
state_dict,
@@ -2560,13 +2595,22 @@ def _find_mismatched_keys(
load_in_8bit=load_in_8bit,
)
error_msgs += new_error_msgs

else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

# force memory release
del state_dict
gc.collect()

if low_cpu_mem_usage:
# At this point all the modules should have a device set, except for the modules that are
# in `_keys_to_ignore_on_load_missing` and `_keys_to_ignore_on_save` - let's randomly initialize them!
if hasattr(model_to_load, "_keys_to_ignore_on_load_missing") and hasattr(
model_to_load, "_keys_to_ignore_on_save"
):
model_to_load = _load_unloaded_params(model_to_load, dtype, load_in_8bit)

if offload_index is not None and len(offload_index) > 0:
if model != model_to_load:
# We need to add the prefix of the base model
1 change: 1 addition & 0 deletions src/transformers/models/lilt/modeling_lilt.py
@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
config_class = LiltConfig
base_model_prefix = "lilt"
supports_gradient_checkpointing = True
_no_split_modules = ["LiltLayer", "LiltLayoutEmbeddings"]

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
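
For context on this one-line addition: `_no_split_modules` tells accelerate's device-map inference which blocks must stay together on a single device when the model is sharded. A minimal sketch of what the attribute enables (the checkpoint name is illustrative, and accelerate plus more than one device are assumed):

# Sketch only: assumes accelerate is installed and multiple devices are available.
from transformers import LiltModel

# With _no_split_modules = ["LiltLayer", "LiltLayoutEmbeddings"], device_map="auto"
# can shard the model without cutting through a LiltLayer or the layout embeddings.
model = LiltModel.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base", device_map="auto")
print(model.hf_device_map)  # e.g. {"embeddings": 0, "encoder.layer.0": 0, ...}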
1 change: 1 addition & 0 deletions src/transformers/models/roberta/modeling_roberta.py
@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = ["RobertaLayer"]

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
config_class = XLMRobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = ["XLMRobertaLayer"]

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
180 changes: 179 additions & 1 deletion tests/models/lilt/test_modeling_lilt.py
@@ -14,10 +14,19 @@
# limitations under the License.


import tempfile
import unittest

from transformers import LiltConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import (
is_accelerate_available,
require_accelerate,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)

from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -35,6 +44,9 @@
)
from transformers.models.lilt.modeling_lilt import LILT_PRETRAINED_MODEL_ARCHIVE_LIST

if is_accelerate_available():
from accelerate.utils import compute_module_sizes


class LiltModelTester:
def __init__(
@@ -264,6 +276,172 @@ def test_model_from_pretrained(self):
model = LiltModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@require_accelerate
Contributor Author: These tests could go on the super class, but not sure.

@require_torch_multi_gpu
def test_model_parallelism(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue

inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)

ignore_lm_head = (
hasattr(model, "lm_head")
and hasattr(model, "_keys_to_ignore_on_load_missing")
and hasattr(model, "_keys_to_ignore_on_save")
)
if ignore_lm_head:
    ignore_lm_head = any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing)
inputs_dict_model_class["output_attentions"] = True

torch.manual_seed(0)
base_output = model(**inputs_dict_model_class)

model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)

for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}

new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
new_output = new_model(**inputs_dict_model_class)

# Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check
# the final logits but only the intermediate attentions for all layers
if ignore_lm_head:
for i in range(len(base_output["attentions"])):
self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0]))

@require_accelerate
@require_torch_gpu
def test_disk_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue

inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)

ignore_lm_head = (
hasattr(model, "lm_head")
and hasattr(model, "_keys_to_ignore_on_load_missing")
and hasattr(model, "_keys_to_ignore_on_save")
)
if ignore_lm_head:
    ignore_lm_head = any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing)
inputs_dict_model_class["output_attentions"] = True

base_output = model(**inputs_dict_model_class)

model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)

max_memory = {0: max_size, "cpu": max_size}
with self.assertRaises(ValueError):
# This errors out because no offload folder was provided
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

new_model = model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)

self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict_model_class)

# Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check
# the final logits but only the intermediate attentions for all layers
if ignore_lm_head:
for i in range(len(base_output["attentions"])):
self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0]))

@require_accelerate
@require_torch_gpu
def test_cpu_offload(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

for model_class in self.all_model_classes:
if model_class._no_split_modules is None:
continue

inputs_dict_model_class = self._prepare_for_class(inputs_dict, model_class)

model = model_class(config).eval()
model = model.to(torch_device)

ignore_lm_head = (
hasattr(model, "lm_head")
and hasattr(model, "_keys_to_ignore_on_load_missing")
and hasattr(model, "_keys_to_ignore_on_save")
)
if ignore_lm_head:
    ignore_lm_head = any("lm_head" in keys for keys in model._keys_to_ignore_on_load_missing)
inputs_dict_model_class["output_attentions"] = True

torch.manual_seed(0)
base_output = model(**inputs_dict_model_class)

model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)

for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
new_output = new_model(**inputs_dict_model_class)

# Since the `lm_head.decoder.weight` is in `_keys_to_ignore_on_save` we don't check
# the final logits but only the intermediate attentions for all layers
if ignore_lm_head:
for i in range(len(base_output["attentions"])):
self.assertTrue(torch.allclose(base_output["attentions"][i], new_output["attentions"][i]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0]))


@require_torch
@slow
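
The three tests above all follow the same pattern: measure the model with accelerate's `compute_module_sizes`, then cap device 0 at a fraction of that size so part of the model is forced onto another GPU, the CPU, or disk. A standalone sketch of that pattern with a toy module (sizes are in bytes; the fractions stand in for `self.model_split_percents`):

# Minimal sketch of the size-splitting logic used by the tests; toy model, not LiLT.
import torch.nn as nn
from accelerate.utils import compute_module_sizes

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
sizes = compute_module_sizes(model)
total = sizes[""]  # the empty key holds the whole model's size in bytes

for fraction in (0.5, 0.7, 0.9):  # stand-ins for self.model_split_percents
    # Cap device 0 below the full model size so the remainder must go elsewhere.
    max_memory = {0: int(fraction * total), "cpu": total * 2}
    print(max_memory)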