Motivation
The graph runtime can now support heterogeneous execution through annotation with device ids. It is important to have a compiler pass that enables annotation from the frontend, so that users have the flexibility to annotate each operator with "the best" device. This RFC proposes adding annotation to Relay as a standalone pass. A prototype implementation is here.
Action items
The main design items are as follows:
- Each operator is attached with a `fallback` attribute to indicate whether it will fall back.
- Each `CallNode` is attached with a `device_id` attribute to indicate which device it should be annotated to (the default is 0, meaning no annotation is required).
- Annotation could be performed in various ways. Currently, users can optionally provide a `Dict[op_name, device]` map to `build`, or enable fallback by adding `set_fallback` to an operator. More sophisticated annotation schemes (e.g. ones with cost functions that take device communication and data transfer overhead into account) could be explored in the future.
- Copy ops are needed to copy data across different devices, and these ops can be treated specially during compilation. For example, we don't need to provide them with `fcompute` and `fschedule`. These ops could also be omitted during lowering, since the real data copy will be performed at runtime.
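
The design items above can be illustrated with a small, self-contained sketch. It does not use TVM: `Call`, `annotate`, `insert_copies`, the op name `"copy"`, and the integer device ids (1 and 2) are all hypothetical stand-ins chosen for illustration. The sketch assumes an operator goes to the fallback device when it is marked as fallback or is absent from the map, and that a copy op is placed wherever an argument lives on a different device than its consumer.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class Call:
    """Toy stand-in for a Relay CallNode (hypothetical, not the TVM class)."""
    op_name: str
    args: List["Call"] = field(default_factory=list)
    device_id: int = 0  # 0 means "no annotation required", as in the RFC


def annotate(node: Call, op_name_dev_map: Dict[str, int],
             fallback_ops: Set[str], fallback_device: int) -> Call:
    """Assign a device_id to every call in the expression tree."""
    for arg in node.args:
        annotate(arg, op_name_dev_map, fallback_ops, fallback_device)
    if node.op_name in fallback_ops or node.op_name not in op_name_dev_map:
        node.device_id = fallback_device
    else:
        node.device_id = op_name_dev_map[node.op_name]
    return node


def insert_copies(node: Call) -> Call:
    """Wrap each argument that lives on another device in a copy op."""
    new_args = []
    for arg in node.args:
        arg = insert_copies(arg)
        if arg.device_id != node.device_id:
            # The copy produces its result on the consumer's device.
            arg = Call("copy", [arg], device_id=node.device_id)
        new_args.append(arg)
    node.args = new_args
    return node


# dense(relu(conv2d(input))) with assumed device ids 1 and 2.
root = Call("dense", [Call("relu", [Call("conv2d", [Call("input")])])])
annotate(root, {"conv2d": 2, "dense": 2}, fallback_ops={"dense"}, fallback_device=1)
insert_copies(root)
```

Here `dense` is listed in the map but is also marked as fallback, so it ends up on the fallback device; `relu` and `input` are not in the map and fall back as well. Only `conv2d` runs on device 2, which is why a copy appears on each side of it.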
Proposed APIs
- The `build` API is as follows:
  `def build(func, target=None, target_host=None, params=None, op_name_device=None, fallback_device=None):`
  where heterogeneous compilation is enabled when `target` is a dict mapping devices to targets.
- The annotation API is as follows:
  `def annotate_ops(expr, op_name_dev_map, fallback_device):`
  During annotation, the `device_id` of a `CallNode` is set to `fallback_device` if its operator is registered with `fallback`, or if the map does not explicitly specify where it should be allocated.
- The return value of the `Plan` API in `graph_plan_memory` needs to change slightly. In addition to returning a list of `storage_id` values, it also has to return the corresponding `device_id` for each of them.
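
To make the last item concrete, the following is a rough sketch of a memory planner that returns device ids alongside storage ids. It is not the `graph_plan_memory` implementation: the function `plan`, its input format, and the reuse policy (a freed buffer is only reused by a later node on the same device) are assumptions made for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Node = Tuple[str, int, List[str]]  # (name, device_id, input names)


def plan(nodes: List[Node]) -> Tuple[List[int], List[int]]:
    """Return parallel lists (storage_ids, device_ids) for topologically sorted nodes."""
    # Index of the last node that reads each value.
    last_use: Dict[str, int] = {}
    for idx, (_, _, inputs) in enumerate(nodes):
        for name in inputs:
            last_use[name] = idx

    free = defaultdict(list)  # device_id -> storage ids available for reuse
    assigned: Dict[str, Tuple[int, int]] = {}
    storage_ids: List[int] = []
    device_ids: List[int] = []
    next_id = 0

    for idx, (name, dev, inputs) in enumerate(nodes):
        # Reuse a free buffer only if it lives on the same device.
        if free[dev]:
            sid = free[dev].pop()
        else:
            sid = next_id
            next_id += 1
        assigned[name] = (sid, dev)
        storage_ids.append(sid)
        device_ids.append(dev)
        # Release inputs whose last reader is this node.
        for inp in set(inputs):
            if last_use[inp] == idx:
                in_sid, in_dev = assigned[inp]
                free[in_dev].append(in_sid)
    return storage_ids, device_ids


# A chain that crosses devices twice (assumed device ids 1 and 2).
nodes = [
    ("input", 1, []),
    ("copy_a", 2, ["input"]),
    ("conv2d", 2, ["copy_a"]),
    ("copy_b", 1, ["conv2d"]),
    ("relu", 1, ["copy_b"]),
    ("dense", 1, ["relu"]),
]
storage_ids, device_ids = plan(nodes)
```

Returning the two lists in parallel lets the runtime allocate each storage entry on the right device. Keeping a separate free list per device ensures that a storage id is never shared between devices.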