Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,4 @@ You can now create a pull request on the GitHub webpage of your repository. The

Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved.

In case of code conflict, you should rebase your branch and resolve the conflicts manually.
In case of code conflict, you should rebase your branch and resolve the conflicts manually.
10 changes: 6 additions & 4 deletions colossalai/amp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from torch.optim import Optimizer

from colossalai.context import Config

from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp

__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']

Expand Down
11 changes: 6 additions & 5 deletions colossalai/auto_parallel/offload/amp_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
from typing import Dict, Tuple
from enum import Enum
from typing import Dict, Tuple

import torch
from torch.optim import Optimizer

from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device

from .base_offload_module import BaseOffloadModule
from .region_manager import RegionManager
from .region import Region
from .region_manager import RegionManager


class OptimState(Enum):
SCALED = 0
UNSCALED = 1

class AMPOptimizer(ColossalaiOptimizer):

class AMPOptimizer(ColossalaiOptimizer):
"""
A wrapper for Optimizer.
Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
Expand Down Expand Up @@ -174,4 +175,4 @@ def __init__optimizer(self):

# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optim.load_state_dict(self.optim.state_dict())
self.optim.load_state_dict(self.optim.state_dict())
10 changes: 4 additions & 6 deletions colossalai/auto_parallel/offload/base_offload_module.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional, Set
from functools import partial
from typing import Optional, Set

import torch
import torch.nn as nn

from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.gemini.tensor_utils import free_storage
from colossalai.nn.parallel.data_parallel import _cast_float

from .region_manager import RegionManager
from .util import GlobalRuntimeInfo
Expand All @@ -20,10 +21,7 @@ class BaseOffloadModule:
is_sync (bool): synchronous mode or not.
"""

def __init__(self,
model: nn.Module,
region_manager: RegionManager,
is_sync=True):
def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True):

self.model = model
self.region_manager = region_manager
Expand Down
19 changes: 11 additions & 8 deletions colossalai/auto_parallel/offload/mem_optimize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict

import torch
import torch.fx
from torch.fx import GraphModule
Expand All @@ -7,10 +8,11 @@
from colossalai.fx import ColoTracer, is_compatible_with_meta
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

from .region_manager import RegionManager
from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
from .base_offload_module import BaseOffloadModule
from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
from .region_manager import RegionManager
from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass
from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem


def memory_optimize(model: torch.nn.Module,
inps: Dict[str, torch.Tensor],
Expand All @@ -31,11 +33,12 @@ def memory_optimize(model: torch.nn.Module,
region_manager._build_regions()
GlobalRuntimeInfo.region_list = region_manager.region_list

act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2
max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2
total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2
print(
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}"
)

if solver_name == 'syn':
gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
Expand All @@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module,
raise TypeError(f"Unknown solver name {solver_name}!")

gm.recompile()
optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn')
return optimized_model
13 changes: 7 additions & 6 deletions colossalai/auto_parallel/offload/region.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Dict, Tuple
from typing import Dict, List, Tuple

import torch
from torch.fx import Node

from colossalai.gemini.tensor_utils import alloc_storage, free_storage


class Region:
"""
Region: A container owning a piece of contiguous nodes in the DNN computing graph.
Expand Down Expand Up @@ -52,15 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
Map the parameters in the region to a contiguous memory space.
"""

self.fp16_data = torch.zeros(
self.param_num, dtype=torch.half, device='cuda')
self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
offset = 0
for param in self.fp16_params:
param.data = param.data.cuda()
p_num = param.data.numel()
self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
param.data = self.fp16_data[offset:offset +
p_num].view(param.data.shape)
param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
self.param_to_range[param] = (offset, offset + p_num)
offset += p_num

Expand Down Expand Up @@ -141,4 +142,4 @@ def split(self, cut_node_idx: int, cut_param_idx: int):
def __update_params_ptr(self) -> None:
for param in self.fp16_params:
begin, end = self.param_to_range[param]
param.data = self.fp16_data[begin:end].view(param.data.shape)
param.data = self.fp16_data[begin:end].view(param.data.shape)
73 changes: 28 additions & 45 deletions colossalai/auto_parallel/offload/region_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Any, Dict, Tuple
from typing import Any, Dict, List, Tuple

import torch
from torch.fx import Graph, Node

from .region import Region
from .solver import SolverFactory
from .training_simulator import TrainingSimulator
from .region import Region
from .util import NodeInfo


Expand All @@ -19,11 +20,7 @@ class RegionManager:
cnode (List[str], optional): Common node List, should be the subset of input.
"""

def __init__(self,
graph: Graph,
solver_name: str = 'asyn',
memory_budget: float = -1.0,
cnode: List[str] = None):
def __init__(self, graph: Graph, solver_name: str = 'asyn', memory_budget: float = -1.0, cnode: List[str] = None):

self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
Expand Down Expand Up @@ -65,8 +62,7 @@ def _pre_process(self):
init_region_list = self._linearize_graph()

if len(self.shared_region_pairs) > 1:
raise NotImplementedError(
'The current version only considers at most one pair of parameter sharing.')
raise NotImplementedError('The current version only considers at most one pair of parameter sharing.')

elif len(self.shared_region_pairs) == 1:
shared_regs = self.shared_region_pairs[0]
Expand Down Expand Up @@ -122,21 +118,17 @@ def _early_region_placement(self, ts: TrainingSimulator):
it may not find a suitable region placement strategy for the given execution flow.
"""

reg_flow = torch.cat(
[ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(
torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(
ts.fwd_reg_flow, ts.bwd_reg_flow)
reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow)

block_to_regs = {}
for block_idx in range(mem_block_num):
block_to_regs[block_idx] = []
for reg in self.region_list:
if reg.r_id in self.rid_in_pool:
cur_reg_appears = coexist_matrix[:, reg.r_id]
cur_reg_coexists = torch.sum(
coexist_matrix[cur_reg_appears], dim=0).bool()
cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool()
for block_idx in range(mem_block_num):
if not any(cur_reg_coexists[block_to_regs[block_idx]]):
block_to_regs[block_idx].append(reg.r_id)
Expand All @@ -146,8 +138,10 @@ def _early_region_placement(self, ts: TrainingSimulator):
if reg.r_id not in self.reg_to_block:
raise NotImplementedError(
f'can not find a block from the memory pool to store parameters of the region')
self.memory_pool = torch.chunk(torch.zeros(int(
mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
self.memory_pool = torch.chunk(torch.zeros(int(mem_block_num * self.mem_block_size / 2),
dtype=torch.half,
device='cuda'),
chunks=int(mem_block_num))

def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
"""
Expand Down Expand Up @@ -181,7 +175,7 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
def _search_block_size(self,
region_list: List[Region],
search_interval_byte: int = 1024,
search_range_byte: int = 128 * 1024 ** 2) -> int:
search_range_byte: int = 128 * 1024**2) -> int:
"""
Search for a suitable memory block size.

Expand All @@ -208,8 +202,7 @@ def _get_wasted_mem(size_list: List[int], blk_size: int):
acc_wasted += blk_size - left
return acc_wasted

param_size_list = [
region.param_size for region in region_list if region.r_id == region.shared_rid]
param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid]

start_size = max(param_size_list)
min_mem_waste = float('+inf')
Expand Down Expand Up @@ -244,8 +237,7 @@ def _init_region_data(self):
region.fp16_data = shared_region.fp16_data
region.fp32_data = shared_region.fp32_data
region.param_to_range = shared_region.param_to_range
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
)
region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach()

torch.cuda.empty_cache()

Expand Down Expand Up @@ -343,10 +335,8 @@ def _maybe_param_comp_start() -> bool:
elif n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
if (len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0):
label = True

return label and not sum([v for _, v in param_op_deps.items()])
Expand All @@ -368,19 +358,16 @@ def _is_inplace(n: Node):
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(
n.target), "inplace", False)
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace

label = False

if n.op == "call_module":
target = n.target
submod = self.root_module.get_submodule(target)
if (
len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0
):
if (len(list(submod.named_parameters(recurse=False))) != 0
or len(list(submod.named_buffers(recurse=False))) != 0):
label = True

elif n.op == "call_function":
Expand Down Expand Up @@ -449,18 +436,16 @@ def _exception_node_handling():

# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
]) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len(
[user for user in n.users if user.op != "output"])
deps[n] = len([user for user in n.users if user.op != "output"])

# propagate param node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
]) or n.op == "get_attr":
if len(n.all_input_nodes) == len(
[node for node in n.all_input_nodes if node.name in self.only_param_ops]) or n.op == "get_attr":
self.only_param_ops.append(n.name)
param_op_deps[n] = len(
[user for user in n.users if user.op != "output"])
param_op_deps[n] = len([user for user in n.users if user.op != "output"])

# record last activation node
if _is_act(n._meta_data):
Expand All @@ -483,8 +468,7 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
if p in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[p].r_id
self.param_region_map[p].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[p], cur_reg))
self.shared_region_pairs.append((self.param_region_map[p], cur_reg))
else:
self.param_region_map[p] = cur_reg

Expand All @@ -503,8 +487,7 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
if attr_itr in self.param_region_map:
cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
self.shared_region_pairs.append(
(self.param_region_map[attr_itr], cur_reg))
self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg))
else:
self.param_region_map[attr_itr] = cur_reg

Expand Down
Loading