diff --git a/src/diffpy/labpdfproc/fast_cve.py b/src/diffpy/labpdfproc/fast_cve.py deleted file mode 100644 index dedc48c..0000000 --- a/src/diffpy/labpdfproc/fast_cve.py +++ /dev/null @@ -1,81 +0,0 @@ -import os - -import numpy as np -import pandas as pd -from scipy.interpolate import interp1d - -from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object - -TTH_GRID = np.arange(1, 180.1, 0.1) -MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] -CWD = os.path.dirname(os.path.abspath(__file__)) -INVERSE_CVE_DATA = np.loadtxt(CWD + "/data/inverse_cve.xy") -COEFFICIENT_LIST = np.array(pd.read_csv(CWD + "/data/coefficient_list.csv", header=None)) -INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST] - - -def fast_compute_cve(diffraction_data, mud, wavelength): - """ - use precomputed datasets to compute the cve for given diffraction data, mud and wavelength - - Parameters - ---------- - diffraction_data Diffraction_object - the diffraction pattern - mud float - the mu*D of the diffraction object, where D is the diameter of the circle - wavelength float - the wavelength of the diffraction object - - Returns - ------- - the diffraction object with cve curves - """ - - coefficient_a, coefficient_b, coefficient_c, coefficient_d, coefficient_e = [ - interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS - ] - inverse_cve = ( - coefficient_a * INVERSE_CVE_DATA**4 - + coefficient_b * INVERSE_CVE_DATA**3 - + coefficient_c * INVERSE_CVE_DATA**2 - + coefficient_d * INVERSE_CVE_DATA**1 - + coefficient_e - ) - cve = 1 / np.array(inverse_cve) - - orig_grid = diffraction_data.on_tth[0] - newcve = np.interp(orig_grid, TTH_GRID, cve) - abdo = Diffraction_object(wavelength=wavelength) - abdo.insert_scattering_quantity( - orig_grid, - newcve, - "tth", - metadata=diffraction_data.metadata, - name=f"absorption correction, cve, for {diffraction_data.name}", - wavelength=diffraction_data.wavelength, - scat_quantity="cve", - ) - - return abdo - - -def apply_fast_corr(diffraction_pattern, absorption_correction): - """ - Apply absorption correction to the given diffraction object modo with the correction diffraction object abdo - - Parameters - ---------- - diffraction_pattern Diffraction_object - the input diffraction object to which the cve will be applied - absorption_correction Diffraction_object - the diffraction object that contains the cve to be applied - - Returns - ------- - a corrected diffraction object with the correction applied through multiplication - - """ - - corrected_pattern = diffraction_pattern * absorption_correction - return corrected_pattern diff --git a/src/diffpy/labpdfproc/functions.py b/src/diffpy/labpdfproc/functions.py index 03b7061..d423468 100644 --- a/src/diffpy/labpdfproc/functions.py +++ b/src/diffpy/labpdfproc/functions.py @@ -1,12 +1,23 @@ import math +from pathlib import Path import numpy as np +import pandas as pd +from scipy.interpolate import interp1d from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object RADIUS_MM = 1 N_POINTS_ON_DIAMETER = 300 -TTH_GRID = np.arange(1, 141, 1) +TTH_GRID = np.arange(1, 180.1, 0.1) +CVE_METHODS = ["brute_force", "polynomial_interpolation"] + +# pre-computed datasets for polynomial interpolation (fast calculation) +MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] +CWD = Path(__file__).parent.resolve() +MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy") +COEFFICIENT_LIST = np.array(pd.read_csv(CWD / "data" / "coefficient_list.csv", header=None)) +INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST] class Gridded_circle: @@ -162,28 +173,10 @@ def get_path_length(self, grid_point, angle): return total_distance, primary_distance, secondary_distance -def compute_cve(diffraction_data, mud, wavelength): +def _cve_brute_force(diffraction_data, mud): """ - compute the cve for given diffraction data, mud and wavelength - - Parameters - ---------- - diffraction_data Diffraction_object - the diffraction pattern - mud float - the mu*D of the diffraction object, where D is the diameter of the circle - wavelength float - the wavelength of the diffraction object - - Returns - ------- - the diffraction object with cve curves - - it is computed as follows: - We first resample data and absorption correction to a more reasonable grid, - then calculate corresponding cve for the given mud in the resample grid - (since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2), - and finally interpolate cve to the original grid in diffraction_data. + compute cve for the given mud on a global grid using the brute-force method + assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1 """ mu_sample_invmm = mud / 2 @@ -198,10 +191,86 @@ def compute_cve(diffraction_data, mud, wavelength): muls = np.array(muls) / abs_correction.total_points_in_grid cve = 1 / muls + cve_do = Diffraction_object(wavelength=diffraction_data.wavelength) + cve_do.insert_scattering_quantity( + TTH_GRID, + cve, + "tth", + metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {diffraction_data.name}", + wavelength=diffraction_data.wavelength, + scat_quantity="cve", + ) + return cve_do + + +def _cve_polynomial_interpolation(diffraction_data, mud): + """ + compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6) + """ + + if mud > 6 or mud < 0.5: + raise ValueError( + f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " + f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }." + ) + coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [ + interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS + ] + muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e) + cve = 1 / muls + + cve_do = Diffraction_object(wavelength=diffraction_data.wavelength) + cve_do.insert_scattering_quantity( + TTH_GRID, + cve, + "tth", + metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {diffraction_data.name}", + wavelength=diffraction_data.wavelength, + scat_quantity="cve", + ) + return cve_do + + +def _cve_method(method): + """ + retrieve the cve computation function for the given method + """ + methods = { + "brute_force": _cve_brute_force, + "polynomial_interpolation": _cve_polynomial_interpolation, + } + if method not in CVE_METHODS: + raise ValueError(f"Unknown method: {method}. Allowed methods are {*CVE_METHODS, }.") + return methods[method] + + +def compute_cve(diffraction_data, mud, method="polynomial_interpolation"): + f""" + compute and interpolate the cve for the given diffraction data and mud using the selected method + Parameters + ---------- + diffraction_data Diffraction_object + the diffraction pattern + mud float + the mu*D of the diffraction object, where D is the diameter of the circle + method str + the method used to calculate cve, must be one of {* CVE_METHODS, } + + Returns + ------- + the diffraction object with cve curves + """ + + cve_function = _cve_method(method) + abdo_on_global_tth = cve_function(diffraction_data, mud) + global_tth = abdo_on_global_tth.on_tth[0] + cve_on_global_tth = abdo_on_global_tth.on_tth[1] orig_grid = diffraction_data.on_tth[0] - newcve = np.interp(orig_grid, TTH_GRID, cve) - abdo = Diffraction_object(wavelength=wavelength) - abdo.insert_scattering_quantity( + newcve = np.interp(orig_grid, global_tth, cve_on_global_tth) + cve_do = Diffraction_object(wavelength=diffraction_data.wavelength) + cve_do.insert_scattering_quantity( orig_grid, newcve, "tth", @@ -211,7 +280,7 @@ def compute_cve(diffraction_data, mud, wavelength): scat_quantity="cve", ) - return abdo + return cve_do def apply_corr(diffraction_pattern, absorption_correction): diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py index 986a95d..a9f5d58 100644 --- a/src/diffpy/labpdfproc/labpdfprocapp.py +++ b/src/diffpy/labpdfproc/labpdfprocapp.py @@ -1,7 +1,7 @@ import sys from argparse import ArgumentParser -from diffpy.labpdfproc.functions import apply_corr, compute_cve +from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args from diffpy.utils.parsers.loaddata import loadData from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object @@ -13,57 +13,69 @@ def get_args(override_cli_inputs=None): p.add_argument( "input", nargs="+", - help="The filename(s) or folder(s) of the datafile(s) to load. " - "Required.\nSupply a space-separated list of files or directories. " - "Long lists can be supplied, one per line, in a file with name " - "file_list.txt. If one or more directory is provided, all valid " - "data-files in that directory will be processed. Examples of valid " - "inputs are 'file.xy', 'data/file.xy', 'file.xy, data/file.xy', " - "'.' (load everything in the current directory), 'data' (load " - "everything in the folder ./data), 'data/file_list.txt' (load " - "the list of files contained in the text-file called " - "file_list.txt that can be found in the folder ./data), " - "'./*.chi', 'data/*.chi' (load all files with extension .chi in the " - "folder ./data).", + help=( + "The filename(s) or folder(s) of the datafile(s) to load. " + "Required.\nSupply a space-separated list of files or directories. " + "Long lists can be supplied, one per line, in a file with name " + "file_list.txt. If one or more directory is provided, all valid " + "data-files in that directory will be processed. Examples of valid " + "inputs are 'file.xy', 'data/file.xy', 'file.xy, data/file.xy', " + "'.' (load everything in the current directory), 'data' (load " + "everything in the folder ./data), 'data/file_list.txt' (load " + "the list of files contained in the text-file called " + "file_list.txt that can be found in the folder ./data), " + "'./*.chi', 'data/*.chi' (load all files with extension .chi in the " + "folder ./data)." + ), ) p.add_argument( "-a", "--anode-type", - help=f"The type of the x-ray source. Allowed values are " - f"{*[known_sources], }. Either specify a known x-ray source or specify wavelength.", + help=( + f"The type of the x-ray source. Allowed values are " + f"{*[known_sources], }. Either specify a known x-ray source or specify wavelength." + ), default="Mo", ) p.add_argument( "-w", "--wavelength", - help="X-ray source wavelength in angstroms. Not needed if the anode-type " - "is specified. This wavelength will override the anode wavelength if both are specified.", + help=( + "X-ray source wavelength in angstroms. Not needed if the anode-type " + "is specified. This wavelength will override the anode wavelength if both are specified." + ), default=None, type=float, ) p.add_argument( "-o", "--output-directory", - help="The name of the output directory. If not specified " - "then corrected files will be written to the current directory. " - "If the specified directory doesn't exist it will be created.", + help=( + "The name of the output directory. If not specified " + "then corrected files will be written to the current directory. " + "If the specified directory doesn't exist it will be created." + ), default=None, ) p.add_argument( "-x", "--xtype", - help=f"The quantity on the independent variable axis. Allowed " - f"values: {*XQUANTITIES, }. If not specified then two-theta " - f"is assumed for the independent variable. Only implemented for " - f"tth currently.", + help=( + f"The quantity on the independent variable axis. Allowed " + f"values: {*XQUANTITIES, }. If not specified then two-theta " + f"is assumed for the independent variable. Only implemented for " + f"tth currently." + ), default="tth", ) p.add_argument( "-c", "--output-correction", action="store_true", - help="The absorption correction will be output to a file if this " - "flag is set. Default is that it is not output.", + help=( + "The absorption correction will be output to a file if this " + "flag is set. Default is that it is not output." + ), ) p.add_argument( "-f", @@ -71,30 +83,45 @@ def get_args(override_cli_inputs=None): action="store_true", help="Outputs will not overwrite existing file unless --force is specified.", ) + p.add_argument( + "-m", + "--method", + help=( + f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. " + f"Default method is polynomial interpolation if not specified. " + ), + default="polynomial_interpolation", + ) p.add_argument( "-u", "--user-metadata", metavar="KEY=VALUE", nargs="+", - help="Specify key-value pairs to be loaded into metadata using the format key=value. " - "Separate pairs with whitespace, and ensure no whitespaces before or after the = sign. " - "Avoid using = in keys. If multiple = signs are present, only the first separates the key and value. " - "If a key or value contains whitespace, enclose it in quotes. " - "For example, facility='NSLS II', 'facility=NSLS II', beamline=28ID-2, " - "'beamline'='28ID-2', 'favorite color'=blue, are all valid key=value items. ", + help=( + "Specify key-value pairs to be loaded into metadata using the format key=value. " + "Separate pairs with whitespace, and ensure no whitespaces before or after the = sign. " + "Avoid using = in keys. If multiple = signs are present, only the first separates the key and value. " + "If a key or value contains whitespace, enclose it in quotes. " + "For example, facility='NSLS II', 'facility=NSLS II', beamline=28ID-2, " + "'beamline'='28ID-2', 'favorite color'=blue, are all valid key=value items. " + ), ) p.add_argument( "-n", "--username", - help="Username will be loaded from config files. Specify here " - "only if you want to override that behavior at runtime. ", + help=( + "Username will be loaded from config files. Specify here " + "only if you want to override that behavior at runtime. " + ), default=None, ) p.add_argument( "-e", "--email", - help="Email will be loaded from config files. Specify here " - "only if you want to override that behavior at runtime. ", + help=( + "Email will be loaded from config files. Specify here " + "only if you want to override that behavior at runtime. " + ), default=None, ) args = p.parse_args(override_cli_inputs) @@ -133,7 +160,7 @@ def main(): metadata=load_metadata(args, filepath), ) - absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength) + absorption_correction = compute_cve(input_pattern, args.mud, args.method) corrected_data = apply_corr(input_pattern, absorption_correction) corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" corrected_data.dump(f"{outfile}", xtype="tth") diff --git a/src/diffpy/labpdfproc/tests/test_fast_cve.py b/src/diffpy/labpdfproc/tests/test_fast_cve.py deleted file mode 100644 index 374abe4..0000000 --- a/src/diffpy/labpdfproc/tests/test_fast_cve.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np - -from diffpy.labpdfproc.fast_cve import apply_fast_corr, fast_compute_cve -from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object - - -def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"): - test_do = Diffraction_object(wavelength=1.54) - test_do.insert_scattering_quantity( - xarray, - yarray, - "tth", - scat_quantity=scat_quantity, - name=name, - metadata={"thing1": 1, "thing2": "thing2"}, - ) - return test_do - - -def test_fast_compute_cve(mocker): - xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) - expected_cve = np.array([0.5, 0.5, 0.5]) - mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - actual_abdo = fast_compute_cve(input_pattern, mud=1, wavelength=1.54) - expected_abdo = _instantiate_test_do( - xarray, - expected_cve, - name="absorption correction, cve, for test", - scat_quantity="cve", - ) - assert actual_abdo == expected_abdo - - -def test_apply_fast_corr(mocker): - xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) - expected_cve = np.array([0.5, 0.5, 0.5]) - mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - absorption_correction = _instantiate_test_do( - xarray, - expected_cve, - name="absorption correction, cve, for test", - scat_quantity="cve", - ) - actual_corr = apply_fast_corr(input_pattern, absorption_correction) - expected_corr = _instantiate_test_do(xarray, np.array([1, 1, 1])) - assert actual_corr == expected_corr diff --git a/src/diffpy/labpdfproc/tests/test_functions.py b/src/diffpy/labpdfproc/tests/test_functions.py index 75ef019..ac086a5 100644 --- a/src/diffpy/labpdfproc/tests/test_functions.py +++ b/src/diffpy/labpdfproc/tests/test_functions.py @@ -1,7 +1,9 @@ +import re + import numpy as np import pytest -from diffpy.labpdfproc.functions import Gridded_circle, apply_corr, compute_cve +from diffpy.labpdfproc.functions import CVE_METHODS, Gridded_circle, apply_corr, compute_cve from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object params1 = [ @@ -75,14 +77,38 @@ def test_compute_cve(mocker): mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) mocker.patch("numpy.interp", return_value=expected_cve) input_pattern = _instantiate_test_do(xarray, yarray) - actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54) - expected_abdo = _instantiate_test_do( + actual_cve_do = compute_cve(input_pattern, mud=1) + expected_cve_do = _instantiate_test_do( xarray, expected_cve, name="absorption correction, cve, for test", scat_quantity="cve", ) - assert actual_abdo == expected_abdo + assert actual_cve_do == expected_cve_do + + +params_cve_bad = [ + ( + [7, "polynomial_interpolation"], + [ + f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " + f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }." + ], + ), + ([1, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]), + ([7, "invalid_method"], [f"Unknown method: invalid_method. Allowed methods are {*CVE_METHODS, }."]), +] + + +@pytest.mark.parametrize("inputs, msg", params_cve_bad) +def test_compute_cve_bad(mocker, inputs, msg): + xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) + expected_cve = np.array([0.5, 0.5, 0.5]) + mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) + mocker.patch("numpy.interp", return_value=expected_cve) + input_pattern = _instantiate_test_do(xarray, yarray) + with pytest.raises(ValueError, match=re.escape(msg[0])): + compute_cve(input_pattern, mud=inputs[0], method=inputs[1]) def test_apply_corr(mocker): diff --git a/src/diffpy/labpdfproc/tests/test_tools.py b/src/diffpy/labpdfproc/tests/test_tools.py index 3b4847f..e17ed42 100644 --- a/src/diffpy/labpdfproc/tests/test_tools.py +++ b/src/diffpy/labpdfproc/tests/test_tools.py @@ -312,6 +312,7 @@ def test_load_metadata(mocker, user_filesystem): "wavelength": 0.71, "output_directory": str(Path.cwd().resolve()), "xtype": "tth", + "method": "polynomial_interpolation", "key": "value", "username": "cli_username", "email": "cli@email.com",