From cb25563ca19c4027a1a730ee8673e7bfff503947 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 1 Feb 2020 10:24:26 +0800 Subject: [PATCH 1/2] std should be set to 1 if its value vanishes --- source/train/DescrptSeA.py | 5 ++++- source/train/DescrptSeR.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/source/train/DescrptSeA.py b/source/train/DescrptSeA.py index 527dff2cce..fb914ea378 100644 --- a/source/train/DescrptSeA.py +++ b/source/train/DescrptSeA.py @@ -289,7 +289,10 @@ def _compute_dstats_sys_smth (self, def _compute_std (self,sumv2, sumv, sumn) : - return np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) + val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) + if val == 0: + val = 1 + return val def _filter(self, diff --git a/source/train/DescrptSeR.py b/source/train/DescrptSeR.py index aaa91ca220..5c60916d86 100644 --- a/source/train/DescrptSeR.py +++ b/source/train/DescrptSeR.py @@ -246,7 +246,11 @@ def _compute_dstats_sys_se_r (self, def _compute_std (self,sumv2, sumv, sumn) : - return np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) + val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) + if val == 0: + val = 1 + return val + def _filter_r(self, inputs, From b29d6d3c96ac7bea8cd99bcce70f922006fb1ec9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 1 Feb 2020 10:31:00 +0800 Subject: [PATCH 2/2] protect the descrpt to 1e-2 rather than set to 1 --- source/train/DescrptSeA.py | 4 ++-- source/train/DescrptSeR.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/source/train/DescrptSeA.py b/source/train/DescrptSeA.py index fb914ea378..dfe7329445 100644 --- a/source/train/DescrptSeA.py +++ b/source/train/DescrptSeA.py @@ -290,8 +290,8 @@ def _compute_dstats_sys_smth (self, def _compute_std (self,sumv2, sumv, sumn) : val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) - if val == 0: - val = 1 + if np.abs(val) < 1e-2: + val = 1e-2 return val diff --git a/source/train/DescrptSeR.py b/source/train/DescrptSeR.py index 5c60916d86..f7f8147a6c 100644 --- a/source/train/DescrptSeR.py +++ b/source/train/DescrptSeR.py @@ -247,8 +247,8 @@ def _compute_dstats_sys_se_r (self, def _compute_std (self,sumv2, sumv, sumn) : val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) - if val == 0: - val = 1 + if np.abs(val) < 1e-2: + val = 1e-2 return val