Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,41 @@ def _init_weight(self, _, arr):


class Xavier(Initializer):
"""Initialize the weight with Xavier initialization scheme."""
"""Initialize the weight with Xavier or similar initialization scheme.

Parameters
----------
rnd_type: str, optional
Use ```gaussian``` or ```uniform``` to init

factor_type: str, optional
Use ```avg```, ```in```, or ```out``` to init

magnitude: float, optional
scale of random number range
"""
def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3):
self.rnd_type = rnd_type
self.factor_type = factor_type
self.magnitude = magnitude


def _init_weight(self, _, arr):
shape = arr.shape
fan_in, fan_out = np.prod(shape[1:]), shape[0]
scale = np.sqrt(3. / (fan_in + fan_out))
random.uniform(-scale, scale, out=arr)

factor = 1
if self.factor_type == "avg":
factor = (fan_in + fan_out) / 2
elif self.factor_type == "in":
factor = fan_in
elif self.factor_type == "out":
factor = fan_out
else:
raise ValueError("Incorrect factor type")
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == "uniform":
random.uniform(-scale, scale, out=arr)
elif self.rnd_type == "gaussian":
random.normal(0, scale, out=arr)
else:
raise ValueError("Unknown random type")