-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathSafeEntropy.lua
More file actions
29 lines (26 loc) · 1003 Bytes
/
SafeEntropy.lua
File metadata and controls
29 lines (26 loc) · 1003 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
--- nn.SafeEntropy: element-wise "safe" entropy layer for Torch7's nn.
-- Computes -x * log(x + offset) per element, which stays finite at x == 0.
-- Registered as a subclass of nn.Module via torch.class (Torch7 convention);
-- `parent` is the nn.Module metatable used for the superclass constructor call.
local SafeEntropy, parent = torch.class('nn.SafeEntropy','nn.Module')
-- taking the log of 0 results in infinity; instead compute -x*log(x + offset), where x >= 0.
--- Constructor.
-- @param offset small positive constant added before taking the log so that
--        log(0) is never evaluated; defaults to 1e-8 when omitted.
function SafeEntropy:__init(offset)
   parent.__init(self)
   -- Scratch buffers, allocated once and reused across forward/backward
   -- passes to avoid per-call allocations.
   self.offsetLog = torch.Tensor()
   self.divisor   = torch.Tensor()
   self.ratio     = torch.Tensor()
   self.offset = offset or 1e-8
end
--- Forward pass: output = -input .* log(input + offset), element-wise.
-- Assumes all elements of `input` are >= 0 (e.g. probabilities), so the
-- log argument is strictly positive.
-- @param input tensor of non-negative values
-- @return self.output tensor of the same shape as input
function SafeEntropy:updateOutput(input)
   -- offsetLog <- log(input + offset); cached for reuse in updateGradInput.
   self.offsetLog:resizeAs(input):copy(input):add(self.offset):log()
   -- output <- -(input .* offsetLog), built in place on the output buffer.
   self.output:resizeAs(input):copy(input):cmul(self.offsetLog):mul(-1)
   return self.output
end
--- Backward pass.
-- d/dx [ -x*log(x + eps) ] = -( log(x + eps) + x/(x + eps) ), so
-- gradInput = -( offsetLog + input ./ (input + offset) ) .* gradOutput.
-- NOTE: reuses self.offsetLog computed in updateOutput, so updateOutput
-- must have been called with the same input first (standard nn contract).
-- @param input the tensor passed to the preceding updateOutput call
-- @param gradOutput gradient w.r.t. this module's output, same shape as input
-- @return self.gradInput gradient w.r.t. input
function SafeEntropy:updateGradInput(input, gradOutput)
   -- divisor <- input + offset
   self.divisor:resizeAs(input):copy(input):add(self.offset)
   -- ratio <- input ./ divisor   (result-tensor form of cdiv)
   self.ratio:resizeAs(input):cdiv(input, self.divisor)
   -- gradInput <- offsetLog + ratio   (result-tensor form of add)
   self.gradInput:resizeAs(input):add(self.offsetLog, self.ratio)
   self.gradInput:cmul(gradOutput)
   self.gradInput:mul(-1)
   return self.gradInput
end