This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[RFC] Need to export all HybridBlocks in a Gluon model #19535

@samskalicky

Description

A Gluon model can have any hierarchy of Blocks in which the top-level Block is not a HybridBlock but lower-level blocks are. These HybridBlocks can be dispersed throughout the model architecture. Users can optimize the parts of their model that support it by calling hybridize on the top-level block, which triggers a recursive call through all child blocks. But there is currently no way to export all HybridBlocks without manually calling export on each one. Further, there is no way to reload those exported symbol files without changing the model design: the HybridBlocks have to be swapped for SymbolBlocks, and imports then has to be called on each one to reload it.

I would like to ask for suggestions from the community and have an active discussion on the different ways to address this problem. To kick things off, here is a proposal to start the discussion around:

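import mxnet as mx

def save_cached_graphs(block):
    """Recursively export every HybridBlock found under `block`."""
    def _save_cached_graphs(blk, index):
        # export() writes <prefix>-symbol.json and <prefix>-0000.params
        if isinstance(blk, mx.gluon.nn.HybridBlock):
            blk.export(blk.name + str(index[0]))
        # the counter advances for every child, hybrid or not
        for child in blk._children.values():
            index[0] += 1
            _save_cached_graphs(child, index)
    # save top-level block
    index = [0]
    _save_cached_graphs(block, index)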

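def load_cached_graphs(block):
    """Recursively reload the symbol file of every HybridBlock under `block`."""
    def _load_cached_graphs(blk, index):
        if isinstance(blk, mx.gluon.nn.HybridBlock):
            sym = mx.symbol.load(blk.name + str(index[0]) + '-symbol.json')
            # placeholder assignment, not the final format of _cached_graph
            blk._cached_graph = sym
        # must visit children in the same order as save_cached_graphs
        for child in blk._children.values():
            index[0] += 1
            _load_cached_graphs(child, index)
    # load top-level block
    index = [0]
    _load_cached_graphs(block, index)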

With these two functions, we can recursively export each hybrid block and then reload the symbols. Obviously the code is not complete or even functional (_cached_graph is actually a tuple of symbols and sym.var inputs), but it should serve as a point of reference.
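
Both functions rely on visiting blocks in exactly the same order, so that the counter produces the same file prefix on save and on load. The following is a minimal sketch of that shared traversal, written without MXNet so it can be run on its own. `FakeBlock`, `FakeHybridBlock`, and `iter_hybrid_prefixes` are hypothetical names introduced for illustration; the only things assumed from the proposal are the `name` attribute and the `_children` mapping.

```python
from collections import OrderedDict


class FakeBlock:
    """Hypothetical stand-in for a Gluon Block (not the real class)."""

    def __init__(self, name, children=()):
        self.name = name
        # Mirrors the `_children` mapping that the proposal iterates over.
        self._children = OrderedDict((c.name, c) for c in children)


class FakeHybridBlock(FakeBlock):
    """Hypothetical stand-in for a Gluon HybridBlock."""


def iter_hybrid_prefixes(block, hybrid_type=FakeHybridBlock):
    """Yield (file_prefix, block) for each hybrid block, in pre-order.

    The counter advances once per child visited, hybrid or not, as in the
    proposal, so the prefixes depend only on the model structure.
    """
    index = [0]

    def _walk(blk):
        if isinstance(blk, hybrid_type):
            yield blk.name + str(index[0]), blk
        for child in blk._children.values():
            index[0] += 1
            yield from _walk(child)

    yield from _walk(block)


# A non-hybrid top-level block with hybrid blocks at different depths.
model = FakeBlock("net", [
    FakeHybridBlock("conv"),
    FakeBlock("mid", [FakeHybridBlock("dense")]),
    FakeHybridBlock("out"),
])
prefixes = [prefix for prefix, _ in iter_hybrid_prefixes(model)]
```

With a single generator like this, save and load could both consume the same sequence of prefixes instead of duplicating the counter logic. One caveat of concatenating name and counter directly: a block named `dense1` at index 2 and a block named `dense` at index 12 would both map to `dense12`, so a separator between the two parts may be worth considering.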

Items on v1.x

  • Since a Block's children are stored in a dictionary, need to save/restore their unique names
  • Since parameters are mapped to their block's name, need to synchronize names after reloading model architecture to match the saved params
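
To illustrate the second item, here is a rough sketch of how saved parameter names might be synchronized with a re-created model on v1.x. It assumes block prefixes of the form `dense0_` and that both prefix lists were collected in the same traversal order. `remap_param_names` is a hypothetical helper, not an MXNet API, and the parameter values are plain strings standing in for arrays.

```python
def remap_param_names(saved_params, saved_prefixes, new_prefixes):
    """Rename parameter keys from saved block prefixes to reloaded ones.

    Position i in both prefix lists is assumed to refer to the same block.
    """
    if len(saved_prefixes) != len(new_prefixes):
        raise ValueError("architectures differ in number of blocks")
    # Try longer prefixes first so a nested prefix wins over its parent.
    pairs = sorted(zip(saved_prefixes, new_prefixes),
                   key=lambda pair: len(pair[0]), reverse=True)
    remapped = {}
    for name, value in saved_params.items():
        for old, new in pairs:
            if name.startswith(old):
                remapped[new + name[len(old):]] = value
                break
        else:
            raise KeyError("no saved block prefix matches " + name)
    return remapped


# The saved model used dense0_/dense1_; the re-created one got dense2_/dense3_.
saved = {"dense0_weight": "W0", "dense0_bias": "b0", "dense1_weight": "W1"}
renamed = remap_param_names(saved, ["dense0_", "dense1_"], ["dense2_", "dense3_"])
```

The sketch fails loudly when a parameter matches no known block, since a silent mismatch would load weights into the wrong place or drop them.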

Items on master

  • Since parameters have a UUID, need to save/restore mapping of a Block's params to UUID
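
One possible shape for that mapping, sketched in plain Python: store a table from a stable, structure-derived parameter name to its UUID at save time, then translate saved UUIDs to the UUIDs of a freshly constructed model at load time. The structural names such as `0.weight`, and the helpers `save_uuid_map` and `translate_to_new_uuids`, are assumptions made for illustration and are not taken from the MXNet code base.

```python
import json
import uuid


def save_uuid_map(param_uuids):
    """Serialize {structural_name: uuid} to a JSON string."""
    return json.dumps({name: str(u) for name, u in param_uuids.items()},
                      sort_keys=True)


def translate_to_new_uuids(saved_json, new_param_uuids):
    """Map each saved UUID to the UUID of the same-named new parameter."""
    saved = json.loads(saved_json)
    mismatched = set(saved) ^ set(new_param_uuids)
    if mismatched:
        raise KeyError("parameters differ: " + ", ".join(sorted(mismatched)))
    return {saved[name]: str(new_param_uuids[name]) for name in saved}


# UUIDs of the model that was saved, keyed by an assumed structural name.
old_ids = {"0.weight": uuid.uuid4(), "0.bias": uuid.uuid4()}
blob = save_uuid_map(old_ids)

# A re-created model receives fresh UUIDs for the same parameters.
new_ids = {"0.weight": uuid.uuid4(), "0.bias": uuid.uuid4()}
old_to_new = translate_to_new_uuids(blob, new_ids)
```

The resulting `old_to_new` table is what a loader would need in order to rename UUID-keyed entries in a saved symbol or parameter file so that they line up with the new model.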

General Approach

  1. Recursively create a dictionary structure of blocks mimicking the model architecture
  2. Be able to uniquely identify the block in the model
  3. For HybridBlocks, save/restore the cached graph (symbol + inputs) and in/out formats
  4. Save the model architecture dictionary & parameters
  5. Restore the model architecture with unique identifiers to synchronize with parameter naming/IDs
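
Steps 1, 2, and 4 can be sketched without MXNet as follows. The classes `FakeBlock` and `FakeHybridBlock` are hypothetical stand-ins for the Gluon classes, and the dotted path used as a unique identifier is an assumption for illustration, not a decided format. Saving the cached graph and in/out formats (step 3) is left out because it depends on MXNet internals.

```python
import json
from collections import OrderedDict


class FakeBlock:
    """Hypothetical stand-in for a Gluon Block (not the real class)."""

    def __init__(self, **children):
        # Children are keyed by the name they were registered under.
        self._children = OrderedDict(children)


class FakeHybridBlock(FakeBlock):
    """Hypothetical stand-in for a Gluon HybridBlock."""


def describe(block, path="root"):
    """Return a nested dict that mirrors the block hierarchy.

    Each entry carries a dotted path that identifies the block uniquely,
    because child keys are unique within their parent.
    """
    return {
        "path": path,
        "type": type(block).__name__,
        "hybrid": isinstance(block, FakeHybridBlock),
        "children": [describe(child, path + "." + key)
                     for key, child in block._children.items()],
    }


def hybrid_paths(desc):
    """List the paths of all hybrid blocks in a description, in pre-order."""
    paths = [desc["path"]] if desc["hybrid"] else []
    for child in desc["children"]:
        paths.extend(hybrid_paths(child))
    return paths


model = FakeBlock(
    conv=FakeHybridBlock(),
    mid=FakeBlock(dense=FakeHybridBlock()),
    out=FakeHybridBlock(),
)
description = describe(model)
# The description survives a JSON round trip, so it can be stored on disk.
restored = json.loads(json.dumps(description))
```

A path built from registration keys does not depend on a global counter, so it stays stable when unrelated blocks are added elsewhere in the model. Each hybrid path could then serve as the file prefix for that block's exported symbol.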
