-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathMultiplicativeFilter.lua
More file actions
65 lines (56 loc) · 2.12 KB
/
MultiplicativeFilter.lua
File metadata and controls
65 lines (56 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
-- nn.MultiplicativeFilter: a Torch nn module that elementwise-multiplies its
-- input by a fixed (untrained) random filter, which can be re-randomized or
-- switched off at runtime. Useful for dropout-style multiplicative perturbations.
local MultiplicativeFilter, parent = torch.class('nn.MultiplicativeFilter', 'nn.Module')
-- Constructor.
-- input_size: number of elements per (non-batch) input vector.
-- forbid_randomize: if truthy, randomize() should leave the filter unchanged.
function MultiplicativeFilter:__init(input_size, forbid_randomize)
parent.__init(self)
--self.gradInput:resize(input_size)
--self.output:resize(input_size)
-- the filter cannot be named 'weight' or 'bias', or the neural network modules will attempt to train it
self.bias_filter = torch.Tensor(1, input_size) -- Cmul only depends upon the number of arguments being the same, so no minibatches is the same as one minibatch
self.forbid_randomize = forbid_randomize
-- 1-element tensor used as a boolean flag: 1 = apply the filter, 0 = identity
self.filter_active = torch.Tensor(1):fill(1)
-- draw the initial filter values
self:randomize()
end
-- Re-draw the multiplicative filter values, unless randomization is disabled.
-- With 'dropout', each entry becomes 0 or 1 (entries with rand value < 0.1 are
-- zeroed, so roughly 90% survive); with 'continuous_uniform', entries are
-- uniform in [0.95, 1.05), i.e. a mild multiplicative perturbation around 1.
function MultiplicativeFilter:randomize()
  local perturbation_type = 'dropout' -- alternative: 'continuous_uniform'
  -- BUG FIX: the original tested the bare global 'forbid_randomize' (always
  -- nil) instead of the flag stored on self in __init, so the constructor's
  -- forbid_randomize argument could never actually suppress randomization.
  if not self.forbid_randomize then
    if perturbation_type == 'continuous_uniform' then
      -- uniform in [0.95, 1.05): rand in [0,1) scaled by 0.1, shifted by 0.95
      self.bias_filter:copy(torch.rand(self.bias_filter:size(2))):mul(0.1):add(0.95)
    elseif perturbation_type == 'dropout' then
      self.bias_filter:copy(torch.rand(self.bias_filter:size(2)))
      -- threshold at 0.1: sign(x - 0.1) maps to -1/+1, then {-1,+1} -> {0,1}
      self.bias_filter:add(-0.1):sign():add(1):mul(0.5)
    end
    --print(self.bias_filter)
  end
end
-- Enable the filter: updateOutput/updateGradInput will multiply by bias_filter.
function MultiplicativeFilter:activate()
  self.filter_active:fill(1)
end
-- Disable the filter: the module becomes an identity map until activate() is called.
function MultiplicativeFilter:inactivate()
  self.filter_active:fill(0)
end
-- Forward pass: output = input .* bias_filter when the filter is active,
-- otherwise output = input (identity). For 2D (minibatch) input, the stored
-- 1 x input_size filter is expanded along the batch dimension.
function MultiplicativeFilter:updateOutput(input)
  self.output:resizeAs(input)
  if self.filter_active[1] ~= 1 then
    -- filter switched off: pass the input through unchanged
    self.output:copy(input)
    return self.output
  end
  -- expandAs result is always a valid (truthy) tensor, so the and/or idiom is safe here
  local filter = (input:dim() == 2) and torch.expandAs(self.bias_filter, input)
                                     or self.bias_filter
  self.output:cmul(input, filter)
  return self.output
end
-- Backward pass: since output = input .* filter, the gradient w.r.t. the
-- input is gradOutput .* filter; when the filter is inactive the module is an
-- identity, so the gradient passes through unchanged. bias_filter itself is
-- deliberately never trained (no accGradParameters).
function MultiplicativeFilter:updateGradInput(input, gradOutput)
  self.gradInput:resizeAs(gradOutput)
  if self.filter_active[1] ~= 1 then
    -- identity module: gradient flows through untouched
    self.gradInput:copy(gradOutput)
    return self.gradInput
  end
  -- expandAs result is always a valid (truthy) tensor, so the and/or idiom is safe here
  local filter = (input:dim() == 2) and torch.expandAs(self.bias_filter, input)
                                     or self.bias_filter
  self.gradInput:cmul(gradOutput, filter)
  return self.gradInput
end