Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from sys import setrecursionlimit
from ..api import register_func
from . import call_graph
from . import base
from . import ty
from . import expr
Expand Down Expand Up @@ -141,3 +142,6 @@

# Feature
Feature = feature.Feature

# CallGraph
CallGraph = call_graph.CallGraph
144 changes: 144 additions & 0 deletions python/tvm/relay/call_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""

from tvm.ir import IRModule
from .base import Object
from .expr import GlobalVar
from . import _analysis


class CallGraph(Object):
"""Class to represent a call graph."""

def __init__(self, module):
"""Construct a call graph.

Parameters
----------
module : tvm.ir.IRModule
The IR module used to create a call graph

Returns
-------
call_graph: CallGraph
A constructed call graph.
"""
self.__init_handle_by_constructor__(_analysis.CallGraph, module)

@property
def module(self):
"""Return the contained Relay IR module.

Parameters
----------
None

Returns
-------
ret : tvm.ir.IRModule
The contained IRModule
"""
return _analysis.GetModule(self)

def ref_count(self, var):
"""Return the number of references to the global var

Parameters
----------
var : Union[String, tvm.relay.GlobalVar]

Returns
-------
ret : int
The number reference to the global var
"""
var = self._get_global_var(var)
return _analysis.GetRefCountGlobalVar(self, var)

def global_call_count(self, var):
"""Return the number of global function calls from a given global var.

Parameters
----------
var : Union[String, tvm.relay.GlobalVar]

Returns
-------
ret : int
The number of global function calls from the given var.
"""
var = self._get_global_var(var)
return _analysis.GetGlobalVarCallCount(self, var)

def is_recursive(self, var):
"""Return if the function corresponding to a var is a recursive
function.

Parameters
----------
var : Union[String, tvm.relay.GlobalVar]

Returns
-------
ret : Boolean
If the function corresponding to var is recurisve.
"""
var = self._get_global_var(var)
return _analysis.IsRecursive(self, var)

def _get_global_var(self, var):
"""Return the global var using a given name or GlobalVar.

Parameters
----------
var : Union[String, tvm.relay.GlobalVar]

Returns
-------
ret : tvm.relay.GlobalVar
The global var.
"""
if isinstance(var, str):
mod = self.module
var = mod.get_global_var(var)

if isinstance(var, GlobalVar):
return var
else:
raise TypeError("var should be either a string or GlobalVar")

def print_var(self, var):
"""Print a call graph of a global function by name or by variable.

Parameters
----------
var: Union[String, tvm.relay.GlobalVar]
The name or global variable.

Returns
-------
ret : String
The call graph represented in string.
"""
var = self._get_global_var(var)
return _analysis.PrintCallGraphGlobalVar(self, var)

def __str__(self):
"""Print the call graph in the topological order."""
return _analysis.PrintCallGraph(self)
Loading