
[booster] implement environment table #3051

@FrankLeeeee

Overview

After the initial setup of the booster module, the next step is to implement its fundamental components. This issue tracks the EnvironmentTable class.

To track the development progress, see:

proposal: #3046
project kanban: API Refactoring

Goal

The EnvironmentTable is a centralized manager that controls the process groups and provides utility functions to access meta information such as the rank and world size.
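To make the intended semantics of the rank and world-size utilities concrete, here is a minimal, framework-free sketch. It assumes a process group can be modelled as the tuple of global ranks it contains and that members are ordered by global rank; the real class would wrap `torch.distributed.ProcessGroup` objects instead. The free functions below are hypothetical stand-ins for the methods of the same names, not the final API.

```python
from typing import Optional, Sequence

# Hypothetical stand-in: a "process group" is just the global ranks it contains.
# The real EnvironmentTable would hold torch.distributed.ProcessGroup objects instead.


def get_rank(global_rank: int, group: Optional[Sequence[int]] = None) -> int:
    """Rank of this process inside `group`; the global rank if no group is given."""
    if group is None:
        return global_rank
    members = sorted(set(group))  # assumption: members are ordered by global rank
    if global_rank not in members:
        raise ValueError(f"rank {global_rank} is not a member of group {tuple(members)}")
    return members.index(global_rank)


def get_world_size(world_size: int, group: Optional[Sequence[int]] = None) -> int:
    """Number of processes in `group`; the global world size if no group is given."""
    return world_size if group is None else len(set(group))


def is_master(global_rank: int, group: Optional[Sequence[int]] = None) -> bool:
    """A process is the master of a group if its rank inside that group is 0."""
    return get_rank(global_rank, group) == 0


if __name__ == "__main__":
    group = (4, 5, 6, 7)             # one 4-process group in an 8-process job
    print(get_rank(5, group))        # 1
    print(get_world_size(8, group))  # 4
    print(is_master(4, group))       # True
    print(is_master(5))              # False (compared against the default group)
```

Passing `None` falls back to the default (global) group, which mirrors the `process_group=None` default in the method signatures of the sample definition.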

Some questions remain open for discussion:

  1. Q: What is an easy way to create a new process group/device mesh?
    A: We will create a device mesh based on the information from the parallelism plugin.

  2. Q: How should duplicated process groups (i.e. process groups containing the same set of processes) be managed? Should we allow the creation of duplicated process groups? If we do, how can we distinguish them during process group retrieval?
    A: Duplicated process groups are not allowed in our environment table.

  3. Q: Who should be in charge of process group initialization, colossalai.launch or EnvironmentTable?
    A: colossalai.launch

  4. Q: How should process groups and device meshes be managed? As a device mesh contains process groups as well, is there a unified way to do this?
    A: DeviceMesh will take charge of process group management. We may also keep a process_group_pool for more flexible usage.

  5. Q: How should a process group be retrieved when needed? What will the key be, and how can we make the key meaningful so that developers and users can retrieve groups easily?
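Answers 1 and 4 say a device mesh will be built from the parallelism plugin's information and will own the process groups. The sketch below shows one common way to derive those groups: lay the ranks out row-major over the requested mesh shape and, for each axis, collect the ranks that differ only along that axis. The function name `mesh_axis_groups` and the row-major layout are assumptions for illustration; ColossalAI's `DeviceMesh` may arrange ranks differently.

```python
from itertools import product
from math import prod
from typing import List, Sequence, Tuple


def mesh_axis_groups(world_size: int, mesh_shape: Sequence[int]) -> List[List[Tuple[int, ...]]]:
    """For each mesh axis, return the rank groups that vary only along that axis.

    Hypothetical helper: ranks are assumed to be laid out row-major over `mesh_shape`.
    """
    if prod(mesh_shape) != world_size:
        raise ValueError(f"mesh shape {tuple(mesh_shape)} does not cover {world_size} ranks")

    # strides[d] = how far the global rank moves for one step along axis d (row-major)
    strides = [prod(mesh_shape[d + 1:]) for d in range(len(mesh_shape))]

    groups_per_axis = []
    for axis, size in enumerate(mesh_shape):
        # Enumerate every coordinate with `axis` pinned to 0, then walk along `axis`.
        other_axes = [range(1) if d == axis else range(n) for d, n in enumerate(mesh_shape)]
        groups = []
        for coord in product(*other_axes):
            base = sum(c * s for c, s in zip(coord, strides))
            groups.append(tuple(base + i * strides[axis] for i in range(size)))
        groups_per_axis.append(groups)
    return groups_per_axis


if __name__ == "__main__":
    # e.g. a plugin requesting a 2 x 4 layout on 8 processes
    axis0, axis1 = mesh_axis_groups(8, (2, 4))
    print(axis0)  # [(0, 4), (1, 5), (2, 6), (3, 7)]
    print(axis1)  # [(0, 1, 2, 3), (4, 5, 6, 7)]
```

In a real implementation each tuple would be handed to `torch.distributed.new_group(ranks=...)`. Note that PyTorch requires every process in the job to enter `new_group`, even processes that will not be members of the new group, so all groups should be created on all ranks in the same order.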

A sample definition of the EnvironmentTable is given below and it is subject to possible changes during implementation.

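```python
from __future__ import annotations  # annotations below are not evaluated at runtime

from typing import TYPE_CHECKING, Dict, Optional

import torch

if TYPE_CHECKING:
    import colossalai


class EnvironmentTable:

    def __init__(self, *args, **kwargs):
        # constructor arguments are still to be decided
        self.rank: int
        self.world_size: int
        self.default_process_group: torch.distributed.ProcessGroup
        self.process_group_pool: Dict
        self.device_mesh_pool: Dict

    def is_master(self, process_group: Optional[torch.distributed.ProcessGroup] = None) -> bool: ...

    def get_process_group(self, *args, **kwargs) -> torch.distributed.ProcessGroup: ...

    def get_rank(self, process_group: Optional[torch.distributed.ProcessGroup] = None) -> int: ...

    def get_world_size(self, process_group: Optional[torch.distributed.ProcessGroup] = None) -> int: ...

    def get_device_mesh(self, *args, **kwargs) -> colossalai.device.DeviceMesh: ...

    def create_device_mesh(self, *args, **kwargs) -> colossalai.device.DeviceMesh: ...

    def create_process_group(self, *args, **kwargs) -> torch.distributed.ProcessGroup: ...
```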
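Questions 2 and 5 ask how to forbid duplicated groups and which key to use for retrieval. One possible scheme, sketched below under stated assumptions, keys each group by its sorted tuple of member ranks, so the same set of processes always maps to the same key regardless of the order it was given in, and optionally accepts a human-readable alias. `ProcessGroupPool`, `create`, `get` and the `factory` parameter are all hypothetical; the factory stands in for a call such as `torch.distributed.new_group` so the sketch runs without a distributed job.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

RankKey = Tuple[int, ...]


class ProcessGroupPool:
    """Hypothetical pool that rejects duplicated groups and supports two lookup keys."""

    def __init__(self, factory: Callable[[RankKey], Any]):
        # e.g. factory=lambda ranks: torch.distributed.new_group(ranks=list(ranks))
        self._factory = factory
        self._groups: Dict[RankKey, Any] = {}
        self._aliases: Dict[str, RankKey] = {}

    @staticmethod
    def _key(ranks: Iterable[int]) -> RankKey:
        # Canonical key: the same set of ranks always yields the same tuple.
        return tuple(sorted(set(ranks)))

    def create(self, ranks: Iterable[int], name: Optional[str] = None) -> Any:
        key = self._key(ranks)
        if key in self._groups:
            raise ValueError(f"a process group with ranks {key} already exists")
        if name is not None and name in self._aliases:
            raise ValueError(f"the name {name!r} is already in use")
        group = self._factory(key)
        self._groups[key] = group
        if name is not None:
            self._aliases[name] = key
        return group

    def get(self, ranks_or_name: Union[str, Iterable[int]]) -> Any:
        # Look up by alias if a string is given, otherwise by the member ranks.
        if isinstance(ranks_or_name, str):
            return self._groups[self._aliases[ranks_or_name]]
        return self._groups[self._key(ranks_or_name)]


if __name__ == "__main__":
    pool = ProcessGroupPool(factory=lambda ranks: f"group{ranks}")  # stub, no torch needed
    pool.create([3, 2, 1, 0], name="group_a")
    print(pool.get("group_a"))       # group(0, 1, 2, 3)
    print(pool.get([0, 1, 2, 3]))    # group(0, 1, 2, 3)
    try:
        pool.create([0, 1, 2, 3])    # same ranks in a different order: rejected
    except ValueError as err:
        print(err)
```

Keying by ranks makes duplicate detection trivial and deterministic across processes; the optional alias addresses the "meaningful key" concern without allowing two names to create two copies of the same group.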

Metadata

Labels: API (related to API changes), enhancement (New feature or request)

Project status: ✅ Done
