85 changes: 55 additions & 30 deletions DeepQuant/QuantManipulation/DequantModifier.py
@@ -5,6 +5,7 @@
# Federico Brancasi <fbrancasi@ethz.ch>

import torch.fx as fx
import torch

from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
@@ -31,7 +32,10 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
newLinArgs = []

for arg in oldArgs:
if arg.op == "call_module" and "dequant" in arg.target.lower():
# FCONTI: there is no bias; propagate the None to newLinArgs
if arg is None:
newLinArgs.append(arg)
elif arg.op == "call_module" and "dequant" in arg.target.lower():
if "bias_dequant" in arg.target.lower():
biasDequantNode = arg
elif "weight_dequant" in arg.target.lower():
@@ -47,26 +51,48 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
node.args = tuple(newLinArgs)

if biasDequantNode is None:
# FBRANCASI: This would be unusual: the linear is missing its bias or its bias_dequant
# FCONTI: this happens if a linear layer has no bias
if debug:
print(f"Skipping {node.target}: no biasDequantNode found.")
continue

biasQuantNode = biasDequantNode.args[0]
if (
biasQuantNode.op == "call_module"
and "bias_quant" in biasQuantNode.target.lower()
):
newBqArgs = list(biasQuantNode.args)
for i, bqArg in enumerate(newBqArgs):
if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
newBqArgs[i] = bqArg.args[0]
biasQuantNode.args = tuple(newBqArgs)
print(f"Skipping bias for {node.target}: no biasDequantNode found.")
biasQuantNode = None
else:
if debug:
print(
"Warning: Did not find a typical 'bias_quant' node shape in the graph."
)
biasQuantNode = biasDequantNode.args[0]
if (
biasQuantNode.op == "call_module"
and "bias_quant" in biasQuantNode.target.lower()
):
newBqArgs = list(biasQuantNode.args)
for i, bqArg in enumerate(newBqArgs):
if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
newBqArgs[i] = bqArg.args[0]
biasQuantNode.args = tuple(newBqArgs)
else:
if debug:
print(
"Warning: Did not find a typical 'bias_quant' node shape in the graph."
)

# FCONTI: if there is a bias node, use it for scale/zeropoint/bitwidth;
# otherwise, derive them from the weight and input dequant parameters
if biasDequantNode is not None:
oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
dequantScale = oldBiasDequantMod.scale
dequantZeroPoint = oldBiasDequantMod.zeroPoint
dequantBitWidth = oldBiasDequantMod.bitWidth
oldDequantMod = oldBiasDequantMod
else:
oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target)
oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target)
dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale
# FCONTI: technically it should be:
# dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights)
# how to appropriately compute sum(weights)?
# for now we restrict ourselves to oIDM.zP = 0, so dZP = 0
if debug and oldInputDequantMod.zeroPoint != 0.0:
print(f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!")
dequantZeroPoint = 0.0
dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one?
oldDequantMod = oldWeightDequantMod

for dnode in (inputDequantNode, weightDequantNode):
if dnode is not None:
@@ -76,19 +102,17 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
delattr(fxModel, dnode.target)
graph.erase_node(dnode)

oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)

newDequantModName = (
node.target.replace(".wrappedInnerForwardImpl", "") + "_unified_dequant"
)
# JUNGVI: Torch module names cannot contain "."
newDequantModName = newDequantModName.replace(".", "_")

unifiedDequantMod = Dequant(
originalModule=oldBiasDequantMod.originalModule,
scale=oldBiasDequantMod.scale,
zeroPoint=oldBiasDequantMod.zeroPoint,
bitWidth=oldBiasDequantMod.bitWidth,
originalModule=oldDequantMod.originalModule,
scale=dequantScale,
zeroPoint=dequantZeroPoint,
bitWidth=dequantBitWidth,
)

fxModel.add_module(newDequantModName, unifiedDequantMod)
@@ -105,11 +129,12 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
newArgs[i] = newDequantNode
usr.args = tuple(newArgs)

for usr in list(biasDequantNode.users.keys()):
biasDequantNode.users[usr] = None
if hasattr(fxModel, biasDequantNode.target):
delattr(fxModel, biasDequantNode.target)
graph.erase_node(biasDequantNode)
if biasDequantNode is not None:
for usr in list(biasDequantNode.users.keys()):
biasDequantNode.users[usr] = None
if hasattr(fxModel, biasDequantNode.target):
delattr(fxModel, biasDequantNode.target)
graph.erase_node(biasDequantNode)

if debug:
print(cc.success(f"Modification done for {node.target}"))
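
A note on the scale/zero-point folding in the no-bias branch: the FCONTI comment restricts the pass to inputs with zero-point 0, in which case the integer accumulator can be dequantized with scale = weight_scale * input_scale and zero-point 0. Below is a minimal numeric sketch of that identity, assuming per-tensor affine quantization and a weight zero-point of 0; the scales and tensors are made up for illustration and are not taken from the repository.

```python
import torch

# Made-up per-tensor parameters: weight zero-point 0 and input zero-point 0,
# the only case the pass currently supports.
sW, sX = 0.02, 0.1
wQ = torch.tensor([[3.0, -5.0, 7.0], [1.0, 4.0, -2.0]])  # integer-valued quantized weights
xQ = torch.tensor([6.0, -3.0, 9.0])                       # integer-valued quantized input

# Integer-domain accumulation, as the fused quantized linear would produce it.
acc = wQ @ xQ

# Unified dequant: scale = weight_scale * input_scale, zero-point = 0.
yUnified = sW * sX * acc

# Reference: dequantize weight and input separately, then apply the float linear.
yRef = (sW * wQ) @ (sX * xQ)

assert torch.allclose(yUnified, yRef)
```

With a non-zero input zero-point an additional correction involving sum(weights) would be needed, which is exactly the open question raised in the FCONTI comment and why the pass warns and falls back to a zero-point of 0.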
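For context on the node-erasure pattern used when dropping the input/weight/bias dequant nodes, here is a stand-alone toy example of repointing a node's users and erasing it from an fx graph. It is a simplified illustration, not the repository's code: it uses `replace_all_uses_with` instead of the manual argument rewiring in the diff, and the `Toy` module (with an `Identity` standing in for a Dequant wrapper) is hypothetical.

```python
import torch
import torch.fx as fx


class Toy(torch.nn.Module):
    """Hypothetical model: an Identity stands in for a Dequant wrapper."""

    def __init__(self) -> None:
        super().__init__()
        self.input_dequant = torch.nn.Identity()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.input_dequant(x))


gm = fx.symbolic_trace(Toy())

for node in list(gm.graph.nodes):
    if node.op == "call_module" and "dequant" in node.target:
        # Repoint every user of the dequant node to its single input, then
        # drop the node from the graph and the submodule from the module tree.
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)
        delattr(gm, node.target)

gm.recompile()
print(gm.code)  # the generated forward now calls self.linear directly
```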