196 changes: 137 additions & 59 deletions SparseLinear.lua
@@ -1,19 +1,25 @@
local THNN = require 'nn.THNN'
local SparseLinear, parent = torch.class('nn.SparseLinear', 'nn.Module')

local NO_LAST_INPUT = 0
local ONE_LAST_INPUT = 1
local ACC_MULTIPLE_TIMES = 2

function SparseLinear:__init(inputSize, outputSize, doGradInput)
   parent.__init(self)

   self.weightDecay = 0
   self.doGradInput = doGradInput or false
   self.weight = torch.Tensor(outputSize, inputSize):zero()
   self.bias = torch.Tensor(outputSize):zero()
   self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
   self.gradBias = torch.Tensor(outputSize):zero()

   assert(type(self.doGradInput) == type(true))

   self.lastInput = nil
   self.sparseUpdate = NO_LAST_INPUT
   self.formatted_input = nil

   -- state
   self.gradInput:resize(inputSize)
@@ -33,78 +39,148 @@ function SparseLinear:reset(stdv)
end
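
-- Illustrative sketch (not part of the patch; sizes are made up, and the nn
-- package is assumed loaded): the third constructor argument is the new
-- doGradInput flag. With the default of false, updateGradInput skips the
-- dense gradInput computation entirely.
local function exampleInit()
   local withGrad = nn.SparseLinear(1000, 5, true)   -- computes gradInput
   local noGrad = nn.SparseLinear(1000, 5)           -- doGradInput defaults to false
   return withGrad, noGrad
end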

function SparseLinear:reshapeInput(input)
   if type(input) == 'table' then
      return input, true, false
   else
      if input:dim() == 2 then
         return {input}, false, false
      else
         return input, true, true
      end
   end
end
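
-- Sketch of the dispatch above (illustrative; shapes are made up): a table
-- of nnz x 2 {featureIndex, featureValue} tensors is the new batch format,
-- a single 2D tensor gets wrapped as a one-sample batch, and a 3D tensor
-- falls through to the legacy path.
local function exampleReshape(sl)
   local sample = torch.Tensor{{3, 1.0}, {8, 0.5}}
   local batch, batchMode, legacyMode = sl:reshapeInput({sample, sample})
   -- batch is the table itself, batchMode == true, legacyMode == false
   local wrapped = sl:reshapeInput(sample)   -- returns {sample}, false, false
   return batch, wrapped
end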

function SparseLinear:updateOutput(input)
   local input, batchMode, legacyMode = self:reshapeInput(input)
   self.legacyMode = legacyMode

   if legacyMode then
      input.THNN.SparseLinear_legacyUpdateOutput(
         input:cdata(),
         self.output:cdata(),
         self.weight:cdata(),
         self.bias:cdata()
      )
   else
      local nbatches = #input
      if nbatches == 0 then
         self.output:copy(self.bias)
         return self.output
      end

      local size = 0
      local marker = 1
      self.formatted_input = self.formatted_input or input[1].new()

      for i,v in ipairs(input) do size = size + input[i]:size(1) end
      self.formatted_input:resize(size, 3)
      for i,v in ipairs(input) do
         local buf = self.formatted_input:narrow(1, marker, input[i]:size(1))
         buf:narrow(2,2,2):copy(input[i])
         buf:select(2,1):fill(i)
         marker = marker + input[i]:size(1)
      end

      self.output:resize(nbatches, self.weight:size(1))
      input[1].THNN.SparseLinear_updateOutput(
         self.formatted_input:cdata(),
         self.output:cdata(),
         self.weight:cdata(),
         self.bias:cdata()
      )

      -- fix output size for batchSize = 1
      if not batchMode then
         self.output = self.output[1]
      end
   end

   return self.output
end
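
-- Illustrative forward pass (not part of the patch; assumes nn is loaded):
-- each sample is an nnz x 2 tensor of {featureIndex, featureValue} rows,
-- and the loop above flattens the batch into one size x 3 buffer of
-- {batchIndex, featureIndex, featureValue} rows for a single THNN call.
local function exampleForward()
   local sl = nn.SparseLinear(1000, 5)
   local s1 = torch.Tensor{{1, 0.5}, {42, 2.0}}   -- two active features
   local s2 = torch.Tensor{{7, 1.0}}              -- one active feature
   return sl:forward({s1, s2})                    -- 2 x 5 dense output
end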

function SparseLinear:accGradParameters(input, gradOutput, scale)
   local input, batchMode, legacyMode = self:reshapeInput(input)
   self.legacyMode = legacyMode

   if legacyMode then
      self.lastInput = self.lastInput or input.new()
      if self.sparseUpdate == NO_LAST_INPUT then
         self.lastInput:resizeAs(input):copy(input)
         self.sparseUpdate = ONE_LAST_INPUT
      elseif self.sparseUpdate == ONE_LAST_INPUT then
         self.sparseUpdate = ACC_MULTIPLE_TIMES
      end

      input.THNN.SparseLinear_legacyAccGradParameters(
         input:cdata(),
         gradOutput:cdata(),
         self.gradWeight:cdata(),
         self.gradBias:cdata(),
         self.weight:cdata(),
         self.bias:cdata(),
         self.weightDecay or 0,
         scale or 1
      )
   else
      if not batchMode then
         gradOutput:resize(1, gradOutput:size(1))
      end

      input[1].THNN.SparseLinear_accGradParameters(
         self.formatted_input:cdata(),
         gradOutput:cdata(),
         self.gradWeight:cdata(),
         self.gradBias:cdata(),
         self.weight:cdata(),
         self.bias:cdata(),
         self.weightDecay or 0,
         scale or 1
      )
   end
end
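
-- Illustrative backward pass (not part of the patch): the non-legacy branch
-- reads self.formatted_input, which is built during updateOutput, so forward
-- must have run on the same input before gradients are accumulated.
local function exampleBackward(sl, batch)
   local out = sl:forward(batch)            -- also builds formatted_input
   local gradOut = out:clone():fill(1)      -- dummy gradient w.r.t. the output
   sl:backward(batch, gradOut)              -- accumulates gradWeight and gradBias
end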

function SparseLinear:updateGradInput(input, gradOutput)
   if self.legacyMode then
      if type(self.gradInput) ~= type(gradOutput) then self.gradInput = gradOutput.new() end
      self.gradInput:resizeAs(input)
   else
      self.gradInput = {}
   end
   if self.doGradInput then
      -- GradInput should be dense anyway
      local gi
      local batchMode = true
      if gradOutput:dim() == 1 then
         gi = self.weight:t()*gradOutput
         batchMode = false
      elseif gradOutput:dim() == 2 then
         gi = gradOutput*self.weight
      end
      local ini = self.weight:size(2)

      if self.legacyMode then
         local batches = self.gradInput:size(1)
         self.gradInput:resize(batches, ini, 2)
         self.gradInput:select(3,1):copy(torch.repeatTensor(torch.range(1, ini), batches, 1))
         self.gradInput:select(3,2):copy(gi)
      else
         if not batchMode then gi:resize(1, ini) end
         for i = 1,gi:size(1) do
            self.gradInput[i] = gradOutput.new(ini, 2)
            self.gradInput[i]:select(2, 2):copy(gi[i])
            self.gradInput[i]:select(2, 1):range(1, ini)
         end
      end
   end
   return self.gradInput
end
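
-- Illustrative check (not part of the patch): with doGradInput == true and
-- table input, gradInput is a table holding one ini x 2 tensor per sample,
-- whose first column is the feature index 1..ini and whose second column is
-- the dense gradient for that feature.
local function exampleGradInput()
   local sl = nn.SparseLinear(10, 3, true)
   local batch = {torch.Tensor{{2, 1.0}}}
   local out = sl:forward(batch)
   local gi = sl:backward(batch, out:clone():fill(1))
   return gi[1]   -- a 10 x 2 tensor of {index, gradient} rows
end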

-- These functions do sparse updates / zeros. However, if we accumulated
-- gradients multiple times, we can't depend on the last input to do sparse
-- updates.
function SparseLinear:updateParameters(learningRate)
   if self.lastInput and self.legacyMode and self.sparseUpdate == ONE_LAST_INPUT then
      self.lastInput.THNN.SparseLinear_updateParameters(
         self.weight:cdata(),
         self.bias:cdata(),
@@ -116,22 +192,24 @@ function SparseLinear:updateParameters(learningRate)
   else
      parent.updateParameters(self, learningRate)
   end
   self.sparseUpdate = NO_LAST_INPUT
end

function SparseLinear:zeroGradParameters()
   if self.lastInput and self.legacyMode and self.sparseUpdate == ONE_LAST_INPUT then
      self.lastInput.THNN.SparseLinear_zeroGradParameters(
         self.gradWeight:cdata(),
         self.gradBias:cdata(),
         self.lastInput:cdata()
      )
   else
      parent.zeroGradParameters(self)
   end
   self.sparseUpdate = NO_LAST_INPUT
end
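
-- Illustrative training step (not part of the patch): the sparse fast paths
-- in updateParameters and zeroGradParameters above only trigger in legacy
-- mode after exactly one accumulation (sparseUpdate == ONE_LAST_INPUT); in
-- every other case both fall back to the dense parent implementations.
local function exampleStep(sl, batch, gradOut, lr)
   sl:zeroGradParameters()
   sl:forward(batch)
   sl:backward(batch, gradOut)
   sl:updateParameters(lr)
end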

function SparseLinear:clearState()
   if self.lastInput then self.lastInput:set() end
   return parent.clearState(self)
end