From 55b01c74949f7feaecbea5b090d8641ae48405cc Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 17 Nov 2015 15:05:49 -0700 Subject: [PATCH] Update Xavier --- python/mxnet/initializer.py | 38 +++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index de38867b3981..46e179ea7b96 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -91,11 +91,41 @@ def _init_weight(self, _, arr): class Xavier(Initializer): - """Initialize the weight with Xavier initialization scheme.""" + """Initialize the weight with Xavier or similar initialization scheme. + + Parameters + ---------- + rnd_type: str, optional + Use ```gaussian``` or ```uniform``` to init + + factor_type: str, optional + Use ```avg```, ```in```, or ```out``` to init + + magnitude: float, optional + scale of random number range + """ + def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3): + self.rnd_type = rnd_type + self.factor_type = factor_type + self.magnitude = magnitude + def _init_weight(self, _, arr): shape = arr.shape fan_in, fan_out = np.prod(shape[1:]), shape[0] - scale = np.sqrt(3. / (fan_in + fan_out)) - random.uniform(-scale, scale, out=arr) - + factor = 1 + if self.factor_type == "avg": + factor = (fan_in + fan_out) / 2 + elif self.factor_type == "in": + factor = fan_in + elif self.factor_type == "out": + factor = fan_out + else: + raise ValueError("Incorrect factor type") + scale = np.sqrt(self.magnitude / factor) + if self.rnd_type == "uniform": + random.uniform(-scale, scale, out=arr) + elif self.rnd_type == "gaussian": + random.normal(0, scale, out=arr) + else: + raise ValueError("Unknown random type")