diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index faed04697..909a218f4 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -7,3 +7,7 @@ build:
sphinx:
configuration: docs/conf.py
+
+python:
+ install:
+ - requirements: "requirements.txt"
diff --git a/docs/_build/html/_images/SN848233.png b/docs/_build/html/_images/SN848233.png
deleted file mode 100644
index 4205aa2ea..000000000
Binary files a/docs/_build/html/_images/SN848233.png and /dev/null differ
diff --git a/docs/_build/html/_images/active_learning_loop.png b/docs/_build/html/_images/active_learning_loop.png
deleted file mode 100644
index 9dcd01312..000000000
Binary files a/docs/_build/html/_images/active_learning_loop.png and /dev/null differ
diff --git a/docs/_build/html/_images/canonical.png b/docs/_build/html/_images/canonical.png
deleted file mode 100644
index 549341c43..000000000
Binary files a/docs/_build/html/_images/canonical.png and /dev/null differ
diff --git a/docs/_build/html/_images/diag.png b/docs/_build/html/_images/diag.png
deleted file mode 100644
index a1190a280..000000000
Binary files a/docs/_build/html/_images/diag.png and /dev/null differ
diff --git a/docs/_build/html/_images/time_domain.png b/docs/_build/html/_images/time_domain.png
deleted file mode 100644
index 038975318..000000000
Binary files a/docs/_build/html/_images/time_domain.png and /dev/null differ
diff --git a/docs/conf.py b/docs/conf.py
index 267314568..4a93c4486 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -30,9 +30,10 @@
master_doc = 'index'
extensions = ['sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon']
+ 'sphinx.ext.napoleon',
+ 'sphinx_rtd_theme']
#'sphinx_automodapi.smart_resolver',
#'sphinx_automodapi.automodapi']
diff --git a/docs/index.rst b/docs/index.rst
index ee4b2f193..28a479a24 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -16,7 +16,7 @@ The code has been modify for the task of enabling photometric supernova cosmolog
Getting started
===============
-This code was developed for ``Python3`` and was not tested in Windows.
+This code was developed for ``Python3`` and was not tested in Windows.
We recommend that you work within a `virtual environment `_.
@@ -32,7 +32,7 @@ Navigate to a ``working_directory`` where you will store the new virtual environ
>>> python3.10 -m venv resspect
-.. hint:: Make sure you deactivate any ``conda`` environment you might have running before moving forward.
+.. hint:: Make sure you deactivate any ``conda`` environment you might have running before moving forward.
Once the environment is set up you can activate it:
@@ -40,7 +40,7 @@ Once the environment is set up you can activate it:
>>> source /bin/activate
-You should see a ``(resspect)`` flag in the extreme left of terminal command line.
+You should see a ``(resspect)`` flag in the extreme left of terminal command line.
Next, clone this repository in another chosen location:
@@ -118,10 +118,11 @@ Details of the tools available to evaluate different steps on feature extraction
Alternatively, you can also perform the full light curve fit for the entire sample from the command line.
-If you are only interested in testing your installation you should work with the SNPCC data:
+If you are only interested in testing your installation you should work with the SNPCC data:
.. code-block:: bash
- >>> fit_dataset.py -s SNPCC -dd -o
+
+ >>> fit_dataset.py -s SNPCC -dd -o
Once the data has been processed you can apply the full Active Learning loop according to your needs.
A detail description on how to use this tool is provided in the :ref:`Learning Loop page `.
@@ -151,7 +152,7 @@ Acknowledgements
This work is part of the Recommendation System for Spectroscopic Followup (RESSPECT) project, governed by an inter-collaboration agreement signed between the `Cosmostatistics Initiative (COIN) `_ and the `LSST Dark Energy Science Collaboration (DESC) `_.
-The `COsmostatistics INitiative (COIN) `_ is an international network of researchers whose goal is to foster interdisciplinarity inspired by Astronomy.
+The `COsmostatistics INitiative (COIN) `_ is an international network of researchers whose goal is to foster interdisciplinarity inspired by Astronomy.
COIN received financial support from `CNRS `_ for the development of this project, as part of its MOMENTUM programme over the 2018-2020 period, under the project *Active Learning for Large Scale Sky Surveys*.
diff --git a/docs/plotting.rst b/docs/plotting.rst
index b723a37ab..a26eed992 100644
--- a/docs/plotting.rst
+++ b/docs/plotting.rst
@@ -11,7 +11,7 @@ evolution of the metrics:
- Purity: fraction of correct Ia classifications;
- Figure of merit: efficiency x purity with a penalty factor of 3 for false positives (contamination).
-The class `Canvas `_ enables you do to it using:
+The class ``Canvas()`` enables you to do it using:
.. code-block:: python
:linenos:
diff --git a/docs/pre_processing.rst b/docs/pre_processing.rst
index 8ddf77e5e..5538d6a54 100644
--- a/docs/pre_processing.rst
+++ b/docs/pre_processing.rst
@@ -11,9 +11,9 @@ as input to the learning algorithm.
Before starting any analysis, you need to choose a feature extraction method, all light curves will then be handdled by this method. In the examples below we used the Bazin feature extraction method (`Bazin et al., 2009 `_ ).
-Load 1 light curve:
+Load 1 light curve:
-------------------
-
+
For SNPCC:
^^^^^^^^^^
@@ -34,8 +34,8 @@ You can load this data using:
>>> lc = BazinFeatureExtractor() # create light curve instance
>>> lc.load_snpcc_lc(path_to_lc) # read data
-
-This allows you to visually inspect the content of the light curve:
+
+This allows you to visually inspect the content of the light curve:
.. code-block:: python
:linenos:
@@ -51,9 +51,9 @@ This allows you to visually inspect the content of the light curve:
Fit 1 light curve:
------------
+------------------
-In order to feature extraction in one specific filter, you can do:
+In order to feature extraction in one specific filter, you can do:
.. code-block:: python
:linenos:
@@ -94,11 +94,12 @@ This can be done in flux as well as in magnitude:
>>> lc.plot_fit(save=False, show=True, unit='mag')
+
.. figure:: images/SN729076_mag.png
- :align: center
- :height: 480 px
- :width: 640 px
- :alt: Bazing fit to light curve. This is an example from SNPCC data.
+ :align: center
+ :height: 480 px
+ :width: 640 px
+ :alt: Bazing fit to light curve. This is an example from SNPCC data.
Example of light from SNPCC data.
@@ -112,20 +113,20 @@ Before deploying large batches for pre-processing, you might want to visualize
>>> # define max MJD for this light curve
>>> max_mjd = max(lc.photometry['mjd']) - min(lc.photometry['mjd'])
-
- >>> lc.plot_fit(save=False, show=True, extrapolate=True,
+
+ >>> lc.plot_fit(save=False, show=True, extrapolate=True,
time_flux_pred=[max_mjd+3, max_mjd+5, max_mjd+10])
.. figure:: images/SN729076_flux_extrap.png
- :align: center
- :height: 480 px
- :width: 640 px
- :alt: Bazing fit to light curve. This is an example from SNPCC data.
+ :align: center
+ :height: 480 px
+ :width: 640 px
+ :alt: Bazing fit to light curve. This is an example from SNPCC data.
Example of extrapolated light from SNPCC data.
-
-
+
+
For PLAsTiCC:
^^^^^^^^^^^^^
@@ -138,7 +139,7 @@ Reading only 1 light curve from PLAsTiCC requires an object identifier. This can
>>> path_to_metadata = '~/plasticc_train_metadata.csv'
>>> path_to_lightcurves = '~/plasticc_train_lightcurves.csv.gz'
-
+
# read metadata for the entire sample
>>> metadata = pd.read_csv(path_to_metadata)
@@ -151,7 +152,7 @@ Reading only 1 light curve from PLAsTiCC requires an object identifier. This can
'libid_cadence', 'tflux_u', 'tflux_g', 'tflux_r', 'tflux_i', 'tflux_z',
'tflux_y'],
dtype='object')
-
+
# choose 1 object
>>> snid = metadata['object_id'].values[0]
@@ -179,7 +180,7 @@ For SNPCC:
>>> feature_extractor = 'bazin'
>>> fit_snpcc(path_to_data_dir=path_to_data_dir, features_file=features_file)
-
+
For PLAsTiCC:
^^^^^^^^^^^^^
@@ -189,14 +190,14 @@ For PLAsTiCC:
>>> from resspect import fit_plasticc
- >>> path_photo_file = '~/plasticc_train_lightcurves.csv'
+ >>> path_photo_file = '~/plasticc_train_lightcurves.csv'
>>> path_header_file = '~/plasticc_train_metadata.csv.gz'
- >>> output_file = 'results/PLAsTiCC_Bazin_train.dat'
- >>> feature_extractor = 'bazin'
+ >>> output_file = 'results/PLAsTiCC_Bazin_train.dat'
+ >>> feature_extractor = 'bazin'
>>> sample = 'train'
- >>> fit_plasticc(path_photo_file=path_photo_file,
+ >>> fit_plasticc(path_photo_file=path_photo_file,
path_header_file=path_header_file,
output_file=output_file,
feature_extractor=feature_extractor,
@@ -207,12 +208,11 @@ The same result can be achieved using the command line:
.. code-block:: bash
:linenos:
-
+
# for SNPCC
>>> fit_dataset -s SNPCC -dd -o
# for PLAsTiCC
- >>> fit_dataset -s -p
- -hd -sp -o
+ >>> fit_dataset -s -p
+ -hd -sp -o
-
\ No newline at end of file
diff --git a/docs/reference.rst b/docs/reference.rst
deleted file mode 100644
index 7d346b0b0..000000000
--- a/docs/reference.rst
+++ /dev/null
@@ -1,193 +0,0 @@
-***************
-Reference / API
-***************
-
-.. currentmodule:: resspect
-
-Pre-processing
-==============
-
-Light curve analysis
--------------------------
-
-*Performing feature extraction for 1 light curve*
-
-.. autosummary::
- :toctree: api
-
- LightCurve
- LightCurve.fit_bazin
- LightCurve.fit_bazin_all
- LightCurve.check_queryable
- LightCurve.conv_flux_mag
- LightCurve.evaluate_bazin
- LightCurve.load_plasticc_lc
- LightCurve.load_resspect_lc
- LightCurve.load_snpcc_lc
- LightCurve.plot_bazin_fit
-
-*Fitting an entire data set*
-
-.. autosummary::
- :toctree: api
-
- fit_snpcc_bazin
-
-*Basic light curve analysis tools*
-
-.. autosummary::
-
- bazin
- errfunc
- fit_scipy
- read_fits
-
-Canonical sample
-================
-
-*The Canonical object for holding the entire sample.*
-
-.. autosummary::
- :toctree: api
-
- Canonical
- Canonical.snpcc_get_canonical_info
- Canonical.snpcc_identify_samples
- Canonical.find_neighbors
-
-*Functions to populate the Canonical object*
-
-.. autosummary::
- :toctree: api
-
- build_snpcc_canonical
- plot_snpcc_train_canonical
-
-Build time domain data base
-===========================
-
-.. autosummary::
- :toctree: api
-
- SNPCCPhotometry
- SNPCCPhotometry.get_lim_mjds
- SNPCCPhotometry.create_daily_file
- SNPCCPhotometry.build_one_epoch
-
-.. autosummary::
- :toctree: api
-
- ExpTimeCalc
- ExpTimeCalc.findexptime
- ExpTimeCalc.findmag
- ExpTimeCalc.FWHM
- ExpTimeCalc.SNR
-
-DataBase
-========
-
-*Object upon which the learning process is performed*
-
-.. autosummary::
- :toctree: api
-
- DataBase
- DataBase.build_orig_samples
- DataBase.build_random_training
- DataBase.build_samples
- DataBase.classify
- Database.classify_bootstrap
- DataBase.evaluate_classification
- DataBAse.identify_keywords
- DataBase.load_bazin_features
- DataBase.load_features
- DataBase.load_photometry_features
- DataBase.load_plasticc_mjd
- DataBase.make_query
- Dataase.output_photo_Ia
- DataBase.save_metrics
- DataBase.save_queried_sample
- DataBase.update_samples
-
-
-Classifiers
-===========
-
-.. autosummary::
- :toctree: api
-
- random_forest
-
-
-Query strategies
-================
-
-.. autosummary::
- :toctree: api
-
- random_sampling
- uncertainty_sampling
-
-Metrics
-=======
-
-*Individual metrics*
-
-.. autosummary::
- :toctree: api
-
- accuracy
- efficiency
- purity
- fom
-
-
-*Metrics agregated by category or use*
-
-.. autosummary::
- :toctree: api
-
- get_snpcc_metric
-
-
-Active Learning loop
-====================
-
-*Full light curve*
-
-.. autosummary::
- :toctree: api
-
- learn_loop
-
-*Time domain*
-
-.. autosummary::
- :toctree: api
-
- get_original_training
- time_domain_loop
-
-
-Plotting
-========
-
-.. autosummary::
- :toctree: api
-
- Canvas
- Canvas.load_metrics
- Canvas.set_plot_dimensions
- Canvas.plot_metrics
-
-Scripts
-=======
-
-.. autosummary::
-
- build_canonical
- build_time_domain_SNPCC
- fit_dataset
- make_metrics_plots
- run_loop
- run_time_domain
diff --git a/pyproject.toml b/pyproject.toml
index 5469adacd..9a70a9e8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies = [
"seaborn>=0.12.2",
"xgboost>=1.7.3",
"iminuit>=1.20.0",
+ "sphinx-rtd-theme>=1.3.0"
]
[project.urls]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 000000000..483a4e960
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+sphinx_rtd_theme
diff --git a/resspect/__init__.py b/resspect/__init__.py
index 07f60f016..18fadf94c 100644
--- a/resspect/__init__.py
+++ b/resspect/__init__.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from .bazin import *
from .build_plasticc_canonical import *
from .build_plasticc_metadata import *
from .build_snpcc_canonical import *
@@ -42,7 +41,6 @@
from .time_domain_loop import *
from .batch_functions import *
from .query_budget_strategies import *
-from .bump import *
import importlib.metadata
@@ -51,12 +49,10 @@
__all__ = ['accuracy',
'assign_cosmo',
- 'bazin',
'build_canonical',
'build_plasticc_canonical',
'build_plasticc_metadata',
'build_snpcc_canonical',
- 'bump',
'calculate_SNR',
'Canonical',
'CanonicalPLAsTiCC',
@@ -92,7 +88,7 @@
'protected_exponent',
'protected_sig',
'purity',
- 'random_forest',
+ 'random_forest',
'random_sampling',
'read_features_fullLC_samples',
'read_fits',
diff --git a/resspect/database.py b/resspect/database.py
index 0dd025a49..664f29a7d 100644
--- a/resspect/database.py
+++ b/resspect/database.py
@@ -780,7 +780,7 @@ def build_random_training(self, initial_training: int, nclass=2, screen=False,
def build_samples(self, initial_training='original', nclass=2,
screen=False, Ia_frac=0.5,
queryable=False, save_samples=False, sep_files=False,
- survey='DES', output_fname=' '):
+ survey='DES', output_fname=None):
"""Separate train, test and validation samples.
Populate properties: train_features, train_header, test_features,
@@ -841,7 +841,7 @@ def build_samples(self, initial_training='original', nclass=2,
print(' From which queryable: ',
self.queryable_ids.shape[0], '\n')
- if save_samples:
+ if isinstance(initial_training, int) and output_fname is not None:
full_header = self.metadata_names + self.features_names
wsample = open(output_fname, 'w')
diff --git a/resspect/feature_extractors/bazin.py b/resspect/feature_extractors/bazin.py
index d1216396b..ef5b0ccf9 100644
--- a/resspect/feature_extractors/bazin.py
+++ b/resspect/feature_extractors/bazin.py
@@ -5,9 +5,9 @@
import matplotlib.pylab as plt
import numpy as np
-from resspect import bazin
-from resspect.bazin import fit_scipy
from resspect.feature_extractors.light_curve import LightCurve
+from resspect.utils.bazin_utils import bazin
+from resspect.utils.bazin_utils import fit_scipy
__all__ = ['BazinFeatureExtractor']
diff --git a/resspect/feature_extractors/bump.py b/resspect/feature_extractors/bump.py
index ebb4dfdeb..d71587180 100644
--- a/resspect/feature_extractors/bump.py
+++ b/resspect/feature_extractors/bump.py
@@ -5,9 +5,9 @@
import matplotlib.pylab as plt
import numpy as np
-from resspect import bump
-from resspect.bump import fit_bump
from resspect.feature_extractors.light_curve import LightCurve
+from resspect.utils.bump_utils import bump
+from resspect.utils.bump_utils import fit_bump
class BumpFeatureExtractor(LightCurve):
diff --git a/resspect/learn_loop.py b/resspect/learn_loop.py
index 3c984faab..229fd117a 100644
--- a/resspect/learn_loop.py
+++ b/resspect/learn_loop.py
@@ -32,7 +32,7 @@ def load_features(database_class: DataBase,
survey: str, features_method: str, number_of_classes: int,
training_method: str, is_queryable: bool,
separate_files: bool = False,
- save_samples: bool = False) -> DataBase:
+ initial_training_samples_file: str = None) -> DataBase:
"""
Load features according to feature extraction method
@@ -65,9 +65,9 @@ def load_features(database_class: DataBase,
separate_files: bool (optional)
If True, consider train and test samples separately read
from independent files. Default is False.
- save_samples: bool (optional)
- If True, save training and test samples to file.
- Default is False.
+ initial_training_samples_file: str (optional)
+ File name in which to save the initial training samples.
+ Samples are only saved when a file name is provided. Default is None.
"""
if isinstance(path_to_features, str):
database_class.load_features(
@@ -87,7 +87,7 @@ def load_features(database_class: DataBase,
database_class.build_samples(
initial_training=training_method, nclass=number_of_classes,
queryable=is_queryable, sep_files=separate_files,
- save_samples=save_samples)
+ output_fname=initial_training_samples_file)
return database_class
@@ -328,7 +328,8 @@ def learn_loop(nloops: int, strategy: str, path_to_features: str,
bool = False, sep_files=False, pred_dir: str = None,
queryable: bool = False, metric_label: str = 'snpcc',
save_alt_class: bool = False, SNANA_types: bool = False,
- metadata_fname: str = None, bar: bool = True, **kwargs):
+ metadata_fname: str = None, bar: bool = True,
+ initial_training_samples_file: str = None, **kwargs):
"""
Perform the active learning loop. All results are saved to file.
@@ -408,7 +409,10 @@ def learn_loop(nloops: int, strategy: str, path_to_features: str,
ensuring that at least half are SN Ia
Default is 'original'.
bar: bool (optional)
- If True, display progress bar.
+ If True, display progress bar.
+ initial_training_samples_file: str (optional)
+ File name in which to save the initial training samples.
+ Samples are only saved when a file name is provided. Default is None.
kwargs: extra parameters
All keywords required by the classifier function.
"""
@@ -421,7 +425,7 @@ def learn_loop(nloops: int, strategy: str, path_to_features: str,
logging.info('Loading features')
database_class = load_features(database_class, path_to_features, survey,
features_method, nclass, training, queryable,
- sep_files)
+ sep_files, initial_training_samples_file)
logging.info('Running active learning loop')
diff --git a/resspect/bazin.py b/resspect/utils/bazin_utils.py
similarity index 80%
rename from resspect/bazin.py
rename to resspect/utils/bazin_utils.py
index 2905d0863..6bcb7c4e1 100644
--- a/resspect/bazin.py
+++ b/resspect/utils/bazin_utils.py
@@ -52,6 +52,7 @@ def bazin(time, a, b, t0, tfall, trise):
X = np.exp(-(time - t0) / tfall) / (1 + np.exp(-(time - t0) / trise))
return a * X + b
+
def bazinr(time, a, b, t0, tfall, r):
"""
A wrapper function for bazin() which replaces trise by r = tfall/trise.
@@ -78,14 +79,15 @@ def bazinr(time, a, b, t0, tfall, r):
"""
- trise = tfall/r
+ trise = tfall / r
res = bazin(time, a, b, t0, tfall, trise)
-
+
if max(res) < 10e10:
return res
else:
return np.array([item if item < 10e10 else 10e10 for item in res])
+
def errfunc(params, time, flux, fluxerr):
"""
Absolute difference between theoretical and measured flux.
@@ -133,51 +135,50 @@ def fit_scipy(time, flux, fluxerr):
flux = np.asarray(flux)
imax = flux.argmax()
flux_max = flux[imax]
-
+
# Parameter bounds
a_bounds = [1.e-3, 10e10]
b_bounds = [-10e10, 10e10]
- t0_bounds = [-0.5*time.max(), 1.5*time.max()]
+ t0_bounds = [-0.5 * time.max(), 1.5 * time.max()]
tfall_bounds = [1.e-3, 10e10]
r_bounds = [1, 10e10]
# Parameter guess
- a_guess = 2*flux_max
+ a_guess = 2 * flux_max
b_guess = 0
t0_guess = time[imax]
-
- tfall_guess = time[imax-2:imax+2].std()/2
+
+ tfall_guess = time[imax - 2:imax + 2].std() / 2
if np.isnan(tfall_guess):
- tfall_guess = time[imax-1:imax+1].std()/2
+ tfall_guess = time[imax - 1:imax + 1].std() / 2
if np.isnan(tfall_guess):
- tfall_guess=50
- if tfall_guess<1:
- tfall_guess=50
+ tfall_guess = 50
+ if tfall_guess < 1:
+ tfall_guess = 50
r_guess = 2
# Clip guesses to stay in bound
- a_guess = np.clip(a=a_guess,a_min=a_bounds[0],a_max=a_bounds[1])
- b_guess = np.clip(a=b_guess,a_min=b_bounds[0],a_max=b_bounds[1])
- t0_guess = np.clip(a=t0_guess,a_min=t0_bounds[0],a_max=t0_bounds[1])
- tfall_guess = np.clip(a=tfall_guess,a_min=tfall_bounds[0],a_max=tfall_bounds[1])
- r_guess = np.clip(a=r_guess,a_min=r_bounds[0],a_max=r_bounds[1])
-
-
- guess = [a_guess,b_guess,t0_guess,tfall_guess,r_guess]
+ a_guess = np.clip(a=a_guess, a_min=a_bounds[0], a_max=a_bounds[1])
+ b_guess = np.clip(a=b_guess, a_min=b_bounds[0], a_max=b_bounds[1])
+ t0_guess = np.clip(a=t0_guess, a_min=t0_bounds[0], a_max=t0_bounds[1])
+ tfall_guess = np.clip(a=tfall_guess, a_min=tfall_bounds[0], a_max=tfall_bounds[1])
+ r_guess = np.clip(a=r_guess, a_min=r_bounds[0], a_max=r_bounds[1])
+ guess = [a_guess, b_guess, t0_guess, tfall_guess, r_guess]
bounds = [[a_bounds[0], b_bounds[0], t0_bounds[0], tfall_bounds[0], r_bounds[0]],
[a_bounds[1], b_bounds[1], t0_bounds[1], tfall_bounds[1], r_bounds[1]]]
-
- result = least_squares(errfunc, guess, args=(time, flux, fluxerr), method='trf', loss='linear',bounds=bounds)
-
- a_fit,b_fit,t0_fit,tfall_fit,r_fit = result.x
- trise_fit = tfall_fit/r_fit
- final_result = np.array([a_fit,b_fit,t0_fit,tfall_fit,trise_fit])
-
+
+ result = least_squares(errfunc, guess, args=(time, flux, fluxerr), method='trf', loss='linear', bounds=bounds)
+
+ a_fit, b_fit, t0_fit, tfall_fit, r_fit = result.x
+ trise_fit = tfall_fit / r_fit
+ final_result = np.array([a_fit, b_fit, t0_fit, tfall_fit, trise_fit])
+
return final_result
+
def main():
return None
diff --git a/resspect/bump.py b/resspect/utils/bump_utils.py
similarity index 88%
rename from resspect/bump.py
rename to resspect/utils/bump_utils.py
index f308067f7..e8e98200f 100644
--- a/resspect/bump.py
+++ b/resspect/utils/bump_utils.py
@@ -1,6 +1,6 @@
"""
# Author: Etienne Russeil and Emille E. O. Ishida
-#
+#
# created on 2 July 2022
#
# Licensed GNU General Public License v3.0;
@@ -22,6 +22,7 @@
__all__ = ['protected_exponent', 'protected_sig', 'bump', 'fit_bump']
+
def protected_exponent(x):
"""
Exponential function : cannot exceed e**10
@@ -34,38 +35,36 @@ def protected_sig(x):
"""
Sigmoid function using the protected exponential function
"""
- return 1/(1+protected_exponent(-x))
+ return 1 / (1 + protected_exponent(-x))
def bump(x, p1, p2, p3):
""" Parametric function, fit transient behavior
Need to fit normalised light curves (divided by maximum flux)
-
+
Parameters
----------
- x : np.array
+ x : np.array
Array of mjd translated to 0
p1,p2,p3 : floats
Parameters of the function
-
+
Returns
-------
np.array
Fitted flux array
"""
-
+
# The function is by construction meant to fit light curve centered on 40
x = x + 40
-
- return protected_sig(p1*x + p2 - protected_exponent(p3*x))
+ return protected_sig(p1 * x + p2 - protected_exponent(p3 * x))
def fit_bump(time, flux, fluxerr):
-
"""
Find best-fit parameters using iminuit least squares.
-
+
Parameters
----------
time : array_like
@@ -76,26 +75,26 @@ def fit_bump(time, flux, fluxerr):
error in response variable (flux)
Returns
-------
- output : np.ndarray of floats
+ output : np.ndarray of floats
Array is [p1, p2, p3, time_shift, max_flux]
-
+
p1, p2, p3 are best fit parameter values
time_shift is time at maximum flux
max_flux is the maximum flux
-
+
"""
-
+
# Center the maxflux at 0
- time_shift = -time[np.argmax(flux)]
+ time_shift = -time[np.argmax(flux)]
time = time + time_shift
-
+
# The function is by construction meant to fit light curve with flux normalised
max_flux = np.max(flux)
flux = flux / max_flux
fluxerr = fluxerr / max_flux
-
+
# Initial guess of the fit
- parameters_dict = {'p1':0.225, 'p2':-2.5, 'p3':0.038}
+ parameters_dict = {'p1': 0.225, 'p2': -2.5, 'p3': 0.038}
least_squares = LeastSquares(time, flux, fluxerr, bump)
fit = Minuit(least_squares, **parameters_dict)
@@ -106,10 +105,10 @@ def fit_bump(time, flux, fluxerr):
parameters = []
for fit_values in range(len(fit.values)):
parameters.append(fit.values[fit_values])
-
+
parameters.append(time_shift)
parameters.append(max_flux)
-
+
return parameters
diff --git a/tests/test_bazin.py b/tests/test_bazin.py
index 0d7d70f95..62c867f2e 100644
--- a/tests/test_bazin.py
+++ b/tests/test_bazin.py
@@ -14,7 +14,7 @@ def test_bazin():
Test the Bazin function evaluation.
"""
- from resspect import bazin
+ from resspect.utils.bazin_utils import bazin
time = 3
a = 1
@@ -33,7 +33,8 @@ def test_errfunc():
Test the error between calculates and observed error.
"""
- from resspect import bazin, errfunc
+ from resspect.utils.bazin_utils import bazin
+ from resspect.utils.bazin_utils import errfunc
# input for bazin
time = np.arange(0, 50, 3.5)
@@ -65,7 +66,7 @@ def test_fit_scipy(test_data_path):
"""
Test the scipy fit to Bazin parametrization.
"""
- from resspect import fit_scipy
+ from resspect.utils.bazin_utils import fit_scipy
fname = test_data_path / 'lc_mjd_flux.csv'
data = read_csv(fname)
diff --git a/tests/test_bump.py b/tests/test_bump.py
index c6dda80f7..e94eaf665 100644
--- a/tests/test_bump.py
+++ b/tests/test_bump.py
@@ -16,7 +16,8 @@ def test_bump():
Test the Bump function evaluation.
"""
-
+ from resspect.utils.bump_utils import bump
+
time = np.array([0])
p1 = 0.225
p2 = -2.5
@@ -31,6 +32,9 @@ def test_fit_bump(test_data_path):
"""
Test fit to Bump parametrization.
"""
+
+ from resspect.utils.bump_utils import fit_bump
+
fname = test_data_path / 'lc_mjd_flux.csv'
data = read_csv(fname)