Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
4ae4086
Implement basic analyses
slyubomirsky Oct 2, 2023
2ffec3a
Fix typo
slyubomirsky Oct 2, 2023
3faf914
Add tests for analyses
slyubomirsky Oct 3, 2023
3cc3c9b
Include in-place analysis
slyubomirsky Oct 4, 2023
441ead2
Return the lists instead
slyubomirsky Oct 4, 2023
71d32a1
Update python binding
slyubomirsky Oct 4, 2023
5a3d3f2
No need to assume *pure* functions capture all values ever passed to …
slyubomirsky Oct 5, 2023
632c773
Improve handling of tuples in mystery call case
slyubomirsky Oct 5, 2023
d97ab92
Corrections to inplace checking
slyubomirsky Oct 5, 2023
3f099df
Add test case for mystery value
slyubomirsky Oct 6, 2023
2018980
typo
slyubomirsky Oct 6, 2023
cfe03d1
Add inplace test case, correct minor issues
slyubomirsky Oct 6, 2023
945c206
Consider also using larger tensors to store smaller ones
slyubomirsky Oct 6, 2023
0e6027a
Check call args against any possible target sinfo, also check tensor …
slyubomirsky Oct 6, 2023
693bc82
Handle output vars and tuple get item
slyubomirsky Oct 23, 2023
bcc8513
Add legalization for in-place functions
slyubomirsky Nov 7, 2023
eb7a004
No need to update the NoAlias attribute, actually
slyubomirsky Nov 7, 2023
2e9315d
Fix TIR transformation, add tests for inline transformation
slyubomirsky Nov 8, 2023
5616789
Only find candidates from supported ops and list _all_ feasible argum…
slyubomirsky Nov 8, 2023
dd3c882
Implement basic transformation pass
slyubomirsky Nov 8, 2023
b8f7359
Use a module pass so wider changes are visible, reorganize
slyubomirsky Nov 8, 2023
d8a2f4d
Have an end-to-end test case for the in-place transformation
slyubomirsky Nov 10, 2023
87e2a41
Rebase fixes and use GetBoundValue instead of reimplementing it
slyubomirsky Nov 13, 2023
6e7d447
Let's just use 'inplace' everywhere
slyubomirsky Nov 13, 2023
03fb2a5
Reorganize code and add more documentation
slyubomirsky Nov 14, 2023
bbd1d53
Include proper bounds check
slyubomirsky Nov 15, 2023
47639e2
Trailing whitespace
slyubomirsky Nov 15, 2023
ba7fb3d
Need a trailing newline
slyubomirsky Nov 15, 2023
ed7595a
Remove unused imports
slyubomirsky Nov 15, 2023
77eab79
Add docstrings for exposed inner functions
slyubomirsky Nov 15, 2023
71e4d71
Reformat docstrings to appease the linter
slyubomirsky Nov 16, 2023
17724be
C++ stylistic changes
slyubomirsky Nov 16, 2023
a329915
Treat args as mystery values by default, do not allow overwriting
slyubomirsky Nov 21, 2023
29bcdfb
Formatting
slyubomirsky Nov 22, 2023
cf396a1
Clarify pass description
slyubomirsky Jan 14, 2024
b6c7c36
Add check to ensure that testing functions are used only in a testing…
slyubomirsky Jan 14, 2024
aad9aab
Improve size match check readability per review suggestions
slyubomirsky Jan 14, 2024
1470d24
Improve the size match check per review suggestions (use PrimExprs)
Lunderberg Jan 14, 2024
af5fd05
Treat non-dataflow vars as living past the end of the block in all cases
slyubomirsky Jan 14, 2024
23a4d0d
Clarify notion of size in comment
slyubomirsky Jan 14, 2024
c604d3b
Remove commented-out code
slyubomirsky Jan 14, 2024
3060ea1
Assume any op that returns a tuple is returning a fresh one (exceptio…
slyubomirsky Jan 14, 2024
3107ce4
Add full structural equality check in large test case
slyubomirsky Jan 14, 2024
585a748
Fix parser roundtripping bug with call_tir_inplace
slyubomirsky Jan 14, 2024
c5d2871
Refactor tests to ensure maps are nonempty
slyubomirsky Jan 14, 2024
9dc2e52
Use .empty() where it's more reasonable
slyubomirsky Jan 14, 2024
b9d01b7
linting changes
slyubomirsky Jan 16, 2024
8dd70df
Flipped the check by accident
slyubomirsky Jan 16, 2024
285c7d6
Remove debug print
slyubomirsky Jan 16, 2024
81b1fcb
Factor out data structure for representing matches and match opportun…
slyubomirsky Jan 16, 2024
cc48b25
Style fix
slyubomirsky Jan 16, 2024
a3683e7
Use the analyzer to handle dynamic cases too
slyubomirsky Jan 16, 2024
463893f
Whitespace
slyubomirsky Jan 16, 2024
559f197
Use BlockBuilder APIs more to avoid re-normalizing
slyubomirsky Jan 17, 2024
d9a7973
Check for expired vars at start of loop so that the use of continue d…
slyubomirsky Jan 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,16 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);

/*!
* \brief Pass that changes calls to operators that can be done in-place
* (generally, these are elementwise operations) in dataflow blocks into in-place implementations.
* Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place
* PrimFunc implementations of those operators (which are based on the legalizations of those
* operators).
* \return The pass.
*/
TVM_DLL Pass DataflowUseInplaceCalls();

/*!
* \brief Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
* only, and will automatically cast fp32 to fp16 for certain ops.
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@
from .exec_builder import ExecBuilder

# Operator
from .op.base import call_tir, call_pure_packed, call_dps_packed, call_tir_with_grad
from .op.base import (
call_tir,
call_tir_inplace,
call_pure_packed,
call_dps_packed,
call_tir_with_grad,
)

# BlockBuilder
from .block_builder import BlockBuilder
Expand Down
98 changes: 97 additions & 1 deletion python/tvm/relax/testing/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ
"""Relax transformation passes for testing"""

import logging
import os
from typing import Dict, List, Set, Tuple
import tvm
from tvm import ir, relax
from tvm.ir import transform
from tvm.ir.module import IRModule
from tvm.ir.transform import PassContext
from tvm.relax import PyExprMutator
from tvm.relax.expr import Call
from tvm.relax.expr import Call, DataflowBlock, Var
from tvm.relay.backend.te_compiler import select_implementation
from tvm.runtime.object import Object
from tvm.target import Target


Expand Down Expand Up @@ -128,3 +132,95 @@ def transform(self):
def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass:
packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator")
return packed_func()


def dataflow_liveness_analysis(block: DataflowBlock) -> Dict[Var, Tuple[int, int]]:
"""
Inner function for the dataflow inplace transformation exposed for testing.
"""
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.warning("The function dataflow_liveness_analysis is exposed for testing only.")

live_ranges = tvm.get_global_func("relax.testing.transform.DataflowLivenessAnalysis")(
block
) # type: ignore
ret = {}
for var, live_range in live_ranges.items():
ret[var] = tuple(live_range)
return ret # type: ignore


def dataflow_alias_analysis(
block: DataflowBlock, inputs: List[Var]
) -> Tuple[Dict[Var, Set[int]], Dict[int, List[Set[int]]]]:
"""
Inner function for the dataflow inplace transformation exposed for testing.
"""
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.warning("The function dataflow_alias_analysis is exposed for testing only.")

alias_sets, tuple_map = tvm.get_global_func("relax.testing.transform.DataflowAliasAnalysis")(
block,
inputs,
) # type: ignore
res_alias_sets = {}
res_tuple_map = {}
for var, alias_set in alias_sets.items():
res_alias_sets[var] = set(alias_set)
for idx, elem_alias_sets in tuple_map.items():
res_tuple_map[idx] = [set(alias_set) for alias_set in elem_alias_sets]
return res_alias_sets, res_tuple_map # type: ignore


@tvm._ffi.register_object("relax.transform.InplaceOpportunity")
class InplaceOpportunity(Object):
"""
Represents an opportunity to make a binding in-place. Exposed only for testing;
the constructor is not exposed.

Parameters:
-----------
binding_idx: int
Index of the binding within its block

arg_idxs: List[int]
Indices of arguments that are eligible to be used as in-place targets.
"""

def __init__(self, _binding_idx, _arg_idxs):
raise NotImplementedError("Constructor for InplaceOpportunity not exposed!")


def dataflow_inplace_analysis(
block: DataflowBlock, inputs: List[Var], mod: IRModule
) -> Tuple[List[Tuple[int, Set[int]]], List[Tuple[int, Set[int]]]]:
"""
Inner function for the dataflow inplace transformation exposed for testing.
"""
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.warning("The function dataflow_inplace_analysis is exposed for testing only.")
index_lists = tvm.get_global_func("relax.testing.transform.DataflowInplaceAnalysis")(
block, inputs, mod
) # type: ignore

def convert(opp_list):
return list(map(lambda opp: (int(opp.binding_idx), set(map(int, opp.arg_idxs))), opp_list))

return (convert(index_lists[0]), convert(index_lists[1])) # type: ignore


def dataflow_single_inplace_call(
mod: IRModule, call: Call, inplace_indices: List[int]
) -> Tuple[Call, IRModule]:
"""
Inner function for the dataflow inplace transformation exposed for testing.
"""
if "PYTEST_CURRENT_TEST" not in os.environ:
logging.warning("The function dataflow_single_inplace_call is exposed for testing only.")

ret = tvm.get_global_func("relax.testing.transform.SingleInplaceCall")(
mod,
call,
inplace_indices,
) # type: ignore
return (ret[0], ret[1]) # type: ignore
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ConvertLayout,
ConvertToDataflow,
DataflowBlockPass,
DataflowUseInplaceCalls,
DeadCodeElimination,
DecomposeOpsForInference,
DecomposeOpsForTraining,
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ def RemovePurityChecking() -> tvm.ir.transform.Pass:
return _ffi_api.RemovePurityChecking() # type: ignore


def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass:
"""
Pass that changes calls to operators that can be done in-place
(generally, these are elementwise operations) into in-place implementations.
Supported operators will be replaced by calls to `call_tir_inplace` that invoke
in-place PrimFunc implementations of those operators (which are based on the legalizations of
those operators).

Returns
-------
ret: tvm.ir.transform.Pass
The pass
"""
return _ffi_api.DataflowUseInplaceCalls()


def LambdaLift() -> tvm.ir.transform.Pass:
"""A pass that lifts local functions into global.

Expand Down
Loading