Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion revtorch/revtorch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
#import torch.autograd.function as func
import sys
import random

class ReversibleBlock(nn.Module):
'''
Expand All @@ -15,11 +17,20 @@ class ReversibleBlock(nn.Module):
g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
'''

def __init__(self, f_block, g_block, split_along_dim=1):
def __init__(self, f_block, g_block, split_along_dim=1, fix_random_seed=False):
    """Initialize a reversible block from two coupled subnetworks.

    Args:
        f_block (nn.Module): subnetwork F; output shape must equal its input shape.
        g_block (nn.Module): subnetwork G; output shape must equal its input shape.
        split_along_dim (int): dimension along which the input is chunked into
            the two halves (x1, x2) of the reversible coupling.
        fix_random_seed (bool): if True, record the RNG seed used for each
            subnetwork call in the forward pass and replay it during backward
            recomputation, so stochastic layers (e.g. dropout) produce the
            same activations in both passes.
    """
    super(ReversibleBlock, self).__init__()
    self.f_block = f_block
    self.g_block = g_block
    self.split_along_dim = split_along_dim
    self.fix_random_seed = fix_random_seed
    # Seeds recorded per namespace ('f' / 'g') by set_seed(..., new=True)
    # in forward() and replayed by set_seed(...) in backward_pass().
    self.random_seeds = {}

def set_seed(self, namespace, new=False):
    """Set the global torch RNG seed recorded under *namespace*.

    No-op unless the block was constructed with ``fix_random_seed=True``.

    Args:
        namespace (str): key identifying which subnetwork call the seed
            belongs to ('f' or 'g').
        new (bool): if True, draw a fresh seed, remember it under
            *namespace*, then apply it; if False, replay the seed
            previously recorded for *namespace*.

    Raises:
        KeyError: if ``new`` is False and no seed was recorded for
            *namespace* (replay before any forward pass).
    """
    if not self.fix_random_seed:
        return
    if new:
        # Draw from Python's RNG (not torch's) so recording a seed does
        # not itself perturb torch's global RNG state.
        self.random_seeds[namespace] = random.randint(0, sys.maxsize)
    torch.manual_seed(self.random_seeds[namespace])

def forward(self, x):
"""
Expand All @@ -30,7 +41,9 @@ def forward(self, x):
x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
y1, y2 = None, None
with torch.no_grad():
self.set_seed('f', new=True)
y1 = x1 + self.f_block(x2)
self.set_seed('g', new=True)
y2 = x2 + self.g_block(y1)

return torch.cat([y1, y2], dim=self.split_along_dim)
Expand Down Expand Up @@ -64,6 +77,7 @@ def backward_pass(self, y, dy):

# Ensures that PyTorch tracks the operations in a DAG
with torch.enable_grad():
self.set_seed('g')
gy1 = self.g_block(y1)

# Use autograd framework to differentiate the calculation. The
Expand All @@ -83,6 +97,7 @@ def backward_pass(self, y, dy):

with torch.enable_grad():
x2.requires_grad = True
self.set_seed('f')
fx2 = self.f_block(x2)

# Use autograd framework to differentiate the calculation. The
Expand Down