This repository was archived by the owner on Aug 5, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path Threshold.lua
More file actions
65 lines (59 loc) · 1.86 KB
/
Threshold.lua
File metadata and controls
65 lines (59 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
-- mklnn.Threshold: an nn.Module subclass whose forward/backward passes are
-- dispatched through mklnn.wrapper to 'Threshold_*' kernels (presumably
-- native MKL/MKL-DNN code, given that tensors are passed via :cdata()).
-- `parent` is the nn.Module class table, used for super-constructor calls.
local Threshold, parent = torch.class('mklnn.Threshold','nn.Module')
-- Cache mklnn helpers as locals; the methods in this file call them.
local wrapper = mklnn.wrapper
local getType = mklnn.getType
--- Construct a Threshold module.
-- @param th    threshold (number, optional; defaults to 1e-6)
-- @param v     replacement value (number, optional; defaults to 0)
-- @param ip    in-place flag (boolean, optional; defaults to false)
-- Raises an error if th or v is given but is not a number, if ip is given
-- but is not a boolean, or if validateParameters rejects the combination.
function Threshold:__init(th, v, ip)
   parent.__init(self)

   -- Fill in defaults first; the type checks below only run afterwards.
   self.threshold = th or 1e-6
   self.val = v or 0

   -- An argument is acceptable when it is absent/false or a real number.
   local thresholdOk = (not th) or type(th) == 'number'
   local valueOk = (not v) or type(v) == 'number'
   if not (thresholdOk and valueOk) then
      error('nn.Threshold(threshold, value)')
   end

   -- In-place processing is opt-in.
   self.inplace = ip or false
   -- nil means "not given"; anything else must be true/false.
   if ip ~= nil and type(ip) ~= 'boolean' then
      error('in-place flag must be boolean')
   end

   self:validateParameters()
end
--- Forward pass: fills self.output via the 'Threshold_updateOutput' kernel.
-- On the first call a LongTensor(11) is allocated in self.dnnPrimitives and
-- mkldnnInitOk is 0; on later calls the existing buffer is reused and
-- mkldnnInitOk is 1 (presumably telling the kernel that its primitives were
-- already created — confirm against the C side).
-- Both self.output and self.gradInput are converted with :mkl() here.
-- @param input input tensor
-- @return self.output
function Threshold:updateOutput(input)
   local alreadyAllocated = self.dnnPrimitives
   if alreadyAllocated then
      self.mkldnnInitOk = 1
   else
      self.mkldnnInitOk = 0
      self.dnnPrimitives = torch.LongTensor(11)
   end

   self.gradInput = self.gradInput:mkl()
   self.output = self.output:mkl()
   self:validateParameters()

   local kernelName = 'Threshold_updateOutput'
   wrapper(getType(input), kernelName,
      input:cdata(), self.output:cdata(),
      self.threshold, self.val, self.inplace,
      self.dnnPrimitives:cdata(), self.mkldnnInitOk)

   return self.output
end
--- Backward pass: fills self.gradInput via the 'Threshold_updateGradInput'
-- kernel.
-- Requires a prior updateOutput call. That call allocates
-- self.dnnPrimitives, sets self.mkldnnInitOk and converts self.gradInput
-- with :mkl(). Calling this method first now raises a descriptive error
-- instead of failing on a nil index.
-- @param input      input tensor (presumably the one given to updateOutput)
-- @param gradOutput gradient with respect to the module output
-- @return self.gradInput
function Threshold:updateGradInput(input, gradOutput)
   self:validateParameters()
   -- Guard after parameter validation so parameter errors keep precedence.
   if self.dnnPrimitives == nil then
      error('mklnn.Threshold: updateGradInput called before updateOutput', 2)
   end
   wrapper(getType(input), 'Threshold_updateGradInput',
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.threshold,
      self.inplace,
      self.dnnPrimitives:cdata(),
      self.mkldnnInitOk
   )
   return self.gradInput
end
--- Normalise and check the module's configuration.
-- A missing `inplace` field is treated as false, for compatibility with
-- objects created before the in-place option existed. In in-place mode,
-- the replacement value must not exceed the threshold; otherwise an error
-- is raised.
function Threshold:validateParameters()
   self.inplace = self.inplace or false -- backwards compatibility pre inplace
   if not self.inplace then
      return
   end
   if self.val <= self.threshold then
      return
   end
   local message = 'in-place processing requires value (' .. self.val ..
      ') not exceed threshold (' .. self.threshold .. ')'
   error(message)
end