diff --git a/net/models.py b/net/models.py index 01dbc8b..c2a74c6 100644 --- a/net/models.py +++ b/net/models.py @@ -22,7 +22,7 @@ def forward(self, x): class LeNet_5(PruningModule): def __init__(self, mask=False): super(LeNet_5, self).__init__() - linear = MaskedLinear if mask else Linear + linear = MaskedLinear if mask else nn.Linear self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5)) self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5)) self.conv3 = nn.Conv2d(16, 120, kernel_size=(5,5))