Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions colossalai/builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from colossalai.registry import *



def build_from_config(module, config: dict):
"""Returns an object of :class:`module` constructed from `config`.

Expand Down Expand Up @@ -46,23 +45,20 @@ def build_from_registry(config, registry: Registry):
Raises:
Exception: Raises an Exception if an error occurred when building from registry.
"""
config_ = config.copy() # keep the original config untouched
assert isinstance(
registry, Registry), f'Expected type Registry but got {type(registry)}'
config_ = config.copy() # keep the original config untouched
assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}'

mod_type = config_.pop('type')
assert registry.has(
mod_type), f'{mod_type} is not found in registry {registry.name}'
assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}'
try:
obj = registry.get_module(mod_type)(**config_)
except Exception as e:
print(
f'An error occurred when building {mod_type} from registry {registry.name}',
flush=True)
print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
raise e

return obj


def build_gradient_handler(config, model, optimizer):
"""Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
`model` and `optimizer`.
Expand Down