Skip to content
26 changes: 26 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,32 @@ def gradient(expr, mod=None, mode='higher_order'):
return _ffi_api.gradient(expr, mod)
raise Exception('unknown mode')

def Defunctionalization(func, mod):
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some description.

Performs defunctionalization on func,
transforming func from a higher-order program to a first-order program.

At each call site, the function is cloned and type parameters are substituted in.
Function arguments are encoded as datatypes
and additional apply functions are used for application.

Parameters
----------
func : tvm.relay.Function
The input function, which should not be polymorphic or be higher-order.
This is because all types must be known and we can't encode function arguments
to the program itself.

mod : tvm.IRModule
The IRModule containing function and type definitions,
which is also mutated during this pass.

Returns
-------
expr : tvm.relay.Function
The output function.
"""
return _ffi_api.Defunctionalization(func, mod)

def to_cps(func, mod=None):
"""
Expand Down
Loading