From 4ad740e55ffc6c246c2701ff4dc5971ae03b4e17 Mon Sep 17 00:00:00 2001 From: Samin Yasar Date: Thu, 27 Feb 2020 21:28:59 +0600 Subject: [PATCH] Linear was undefined added nn namspace --- net/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/net/models.py b/net/models.py index 01dbc8b..c2a74c6 100644 --- a/net/models.py +++ b/net/models.py @@ -22,7 +22,7 @@ def forward(self, x): class LeNet_5(PruningModule): def __init__(self, mask=False): super(LeNet_5, self).__init__() - linear = MaskedLinear if mask else Linear + linear = MaskedLinear if mask else nn.Linear self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5)) self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5)) self.conv3 = nn.Conv2d(16, 120, kernel_size=(5,5))