Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions python/tvm/contrib/pickle_memoize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ class Cache(object):
----------
key: str
The file key to the function
save_at_exit: bool
Whether save the cache to file when the program exits
"""
cache_by_key = {}
def __init__(self, key):
def __init__(self, key, save_at_exit):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
Expand All @@ -49,6 +51,7 @@ def __init__(self, key):
else:
self.cache = {}
self.dirty = False
self.save_at_exit = save_at_exit

def save(self):
if self.dirty:
Expand All @@ -60,16 +63,19 @@ def save(self):
def _atexit():
"""Save handler."""
for value in Cache.cache_by_key.values():
value.save()
if value.save_at_exit:
value.save()


def memoize(key):
def memoize(key, save_at_exit=False):
"""Memoize the result of function and reuse multiple times.

Parameters
----------
key: str
The unique key to the file
save_at_exit: bool
Whether save the cache to file when the program exits

Returns
-------
Expand All @@ -81,9 +87,9 @@ def _register(f):
allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey)
Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__)
cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
cargs = (len(cargs),) + cargs

def _memoized_f(func, *args, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions topi/python/topi/nn/winograd_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from operator import mul
from functools import reduce
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from ..util import const_matrix


Expand Down Expand Up @@ -131,6 +132,8 @@ def _interpolation_points(degree):

return np.array(in_pts[degree-1], dtype=np.float64)


@memoize("topi.nn.winograd_matrices", save_at_exit=False)
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
"""Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
"""
Expand Down