diff --git a/pyproject.toml b/pyproject.toml index 90c9dca..68eef87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,9 @@ template = "{tag}" dev_template = "{tag}" dirty_template = "{tag}" +[project.scripts] +labpdfproc = "diffpy.labpdfproc.labpdfprocapp:main" + [tool.setuptools.packages.find] where = ["src"] # list of folders that contain the packages (["."] by default) include = ["*"] # package names should match these glob patterns (["*"] by default) diff --git a/src/diffpy/labpdfproc/fast_cve.py b/src/diffpy/labpdfproc/fast_cve.py deleted file mode 100644 index dedc48c..0000000 --- a/src/diffpy/labpdfproc/fast_cve.py +++ /dev/null @@ -1,81 +0,0 @@ -import os - -import numpy as np -import pandas as pd -from scipy.interpolate import interp1d - -from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object - -TTH_GRID = np.arange(1, 180.1, 0.1) -MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] -CWD = os.path.dirname(os.path.abspath(__file__)) -INVERSE_CVE_DATA = np.loadtxt(CWD + "/data/inverse_cve.xy") -COEFFICIENT_LIST = np.array(pd.read_csv(CWD + "/data/coefficient_list.csv", header=None)) -INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST] - - -def fast_compute_cve(diffraction_data, mud, wavelength): - """ - use precomputed datasets to compute the cve for given diffraction data, mud and wavelength - - Parameters - ---------- - diffraction_data Diffraction_object - the diffraction pattern - mud float - the mu*D of the diffraction object, where D is the diameter of the circle - wavelength float - the wavelength of the diffraction object - - Returns - ------- - the diffraction object with cve curves - """ - - coefficient_a, coefficient_b, coefficient_c, coefficient_d, coefficient_e = [ - interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS - ] - inverse_cve = ( - coefficient_a * INVERSE_CVE_DATA**4 - + coefficient_b * INVERSE_CVE_DATA**3 - + 
coefficient_c * INVERSE_CVE_DATA**2 - + coefficient_d * INVERSE_CVE_DATA**1 - + coefficient_e - ) - cve = 1 / np.array(inverse_cve) - - orig_grid = diffraction_data.on_tth[0] - newcve = np.interp(orig_grid, TTH_GRID, cve) - abdo = Diffraction_object(wavelength=wavelength) - abdo.insert_scattering_quantity( - orig_grid, - newcve, - "tth", - metadata=diffraction_data.metadata, - name=f"absorption correction, cve, for {diffraction_data.name}", - wavelength=diffraction_data.wavelength, - scat_quantity="cve", - ) - - return abdo - - -def apply_fast_corr(diffraction_pattern, absorption_correction): - """ - Apply absorption correction to the given diffraction object modo with the correction diffraction object abdo - - Parameters - ---------- - diffraction_pattern Diffraction_object - the input diffraction object to which the cve will be applied - absorption_correction Diffraction_object - the diffraction object that contains the cve to be applied - - Returns - ------- - a corrected diffraction object with the correction applied through multiplication - - """ - - corrected_pattern = diffraction_pattern * absorption_correction - return corrected_pattern diff --git a/src/diffpy/labpdfproc/functions.py b/src/diffpy/labpdfproc/functions.py index 494be9b..6d80a4d 100644 --- a/src/diffpy/labpdfproc/functions.py +++ b/src/diffpy/labpdfproc/functions.py @@ -1,12 +1,23 @@ import math +from pathlib import Path import numpy as np +import pandas as pd +from scipy.interpolate import interp1d from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object RADIUS_MM = 1 N_POINTS_ON_DIAMETER = 300 -TTH_GRID = np.arange(1, 141, 1) +TTH_GRID = np.arange(1, 180.1, 0.1) +CVE_METHODS = ["brute_force", "polynomial_interpolation"] + +# pre-computed datasets for polynomial interpolation (fast calculation) +MUD_LIST = [0.5, 1, 2, 3, 4, 5, 6] +CWD = Path(__file__).parent.resolve() +MULS = np.loadtxt(CWD / "data" / "inverse_cve.xy") +COEFFICIENT_LIST = np.array(pd.read_csv(CWD / 
"data" / "coefficient_list.csv", header=None)) +INTERPOLATION_FUNCTIONS = [interp1d(MUD_LIST, coefficients, kind="quadratic") for coefficients in COEFFICIENT_LIST] class Gridded_circle: @@ -27,16 +38,6 @@ def _get_grid_points(self): self.grid = {(x, y) for x in xs for y in ys if x**2 + y**2 <= self.radius**2} self.total_points_in_grid = len(self.grid) - # def get_coordinate_index(self, coordinate): # I think we probably dont need this function? - # count = 0 - # for i, target in enumerate(self.grid): - # if coordinate == target: - # return i - # else: - # count += 1 - # if count >= len(self.grid): - # raise IndexError(f"WARNING: no coordinate {coordinate} found in coordinates list") - def set_distances_at_angle(self, angle): """ given an angle, set the distances from the grid points to the entry and exit coordinates @@ -172,28 +173,10 @@ def get_path_length(self, grid_point, angle): return total_distance, primary_distance, secondary_distance -def compute_cve(diffraction_data, mud, wavelength): +def _cve_brute_force(diffraction_data, mud): """ - compute the cve for given diffraction data, mud and wavelength - - Parameters - ---------- - diffraction_data Diffraction_object - the diffraction pattern - mud float - the mu*D of the diffraction object, where D is the diameter of the circle - wavelength float - the wavelength of the diffraction object - - Returns - ------- - the diffraction object with cve curves - - it is computed as follows: - We first resample data and absorption correction to a more reasonable grid, - then calculate corresponding cve for the given mud in the resample grid - (since the same mu*D yields the same cve, we can assume that D/2=1, so mu=mud/2), - and finally interpolate cve to the original grid in diffraction_data. 
+ compute cve for the given mud on a global grid using the brute-force method + assume mu=mud/2, given that the same mu*D yields the same cve and D/2=1 """ mu_sample_invmm = mud / 2 @@ -208,9 +191,87 @@ def compute_cve(diffraction_data, mud, wavelength): muls = np.array(muls) / abs_correction.total_points_in_grid cve = 1 / muls + abdo = Diffraction_object(wavelength=diffraction_data.wavelength) + abdo.insert_scattering_quantity( + TTH_GRID, + cve, + "tth", + metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {diffraction_data.name}", + wavelength=diffraction_data.wavelength, + scat_quantity="cve", + ) + return abdo + + +def _cve_polynomial_interpolation(diffraction_data, mud): + """ + compute cve using polynomial interpolation method, raise an error if mu*D is out of the range (0.5 to 6) + """ + + if mud > 6 or mud < 0.5: + raise ValueError( + f"mu*D is out of the acceptable range (0.5 to 6) for polynomial interpolation. " + f"Please rerun with a value within this range or specifying another method from {* CVE_METHODS, }." + ) + coeff_a, coeff_b, coeff_c, coeff_d, coeff_e = [ + interpolation_function(mud) for interpolation_function in INTERPOLATION_FUNCTIONS + ] + muls = np.array(coeff_a * MULS**4 + coeff_b * MULS**3 + coeff_c * MULS**2 + coeff_d * MULS + coeff_e) + cve = 1 / muls + + abdo = Diffraction_object(wavelength=diffraction_data.wavelength) + abdo.insert_scattering_quantity( + TTH_GRID, + cve, + "tth", + metadata=diffraction_data.metadata, + name=f"absorption correction, cve, for {diffraction_data.name}", + wavelength=diffraction_data.wavelength, + scat_quantity="cve", + ) + return abdo + + +def _cve_method(method): + """ + retrieve the cve computation function for the given method + """ + methods = { + "brute_force": _cve_brute_force, + "polynomial_interpolation": _cve_polynomial_interpolation, + } + if method not in CVE_METHODS: + raise ValueError(f"Unknown method: {method}. 
Allowed methods are {*CVE_METHODS, }.") + return methods[method] + + +def compute_cve(diffraction_data, mud, method="polynomial_interpolation"): + """ + compute and interpolate the cve for the given diffraction data and mud using the selected method + + Parameters + ---------- + diffraction_data Diffraction_object + the diffraction pattern + mud float + the mu*D of the diffraction object, where D is the diameter of the circle + method str + the method used to calculate cve + + Returns + ------- + the diffraction object with cve curves + + """ + + cve_function = _cve_method(method) + abdo_on_global_tth = cve_function(diffraction_data, mud) + global_tth = abdo_on_global_tth.on_tth[0] + cve_on_global_tth = abdo_on_global_tth.on_tth[1] orig_grid = diffraction_data.on_tth[0] - newcve = np.interp(orig_grid, TTH_GRID, cve) - abdo = Diffraction_object(wavelength=wavelength) + newcve = np.interp(orig_grid, global_tth, cve_on_global_tth) + abdo = Diffraction_object(wavelength=diffraction_data.wavelength) abdo.insert_scattering_quantity( orig_grid, newcve, @@ -220,7 +281,6 @@ def compute_cve(diffraction_data, mud, wavelength): wavelength=diffraction_data.wavelength, scat_quantity="cve", ) - return abdo diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py index aedbd0a..9325f20 100644 --- a/src/diffpy/labpdfproc/labpdfprocapp.py +++ b/src/diffpy/labpdfproc/labpdfprocapp.py @@ -1,7 +1,7 @@ import sys from argparse import ArgumentParser -from diffpy.labpdfproc.functions import apply_corr, compute_cve +from diffpy.labpdfproc.functions import CVE_METHODS, apply_corr, compute_cve from diffpy.labpdfproc.tools import known_sources, load_metadata, preprocessing_args from diffpy.utils.parsers.loaddata import loadData from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object @@ -45,7 +45,7 @@ def get_args(override_cli_inputs=None): "-o", "--output-directory", help="The name of the output directory. 
If not specified " - "then corrected files will be written to the current directory." + "then corrected files will be written to the current directory. " "If the specified directory doesn't exist it will be created.", default=None, ) @@ -64,7 +64,6 @@ def get_args(override_cli_inputs=None): action="store_true", help="The absorption correction will be output to a file if this " "flag is set. Default is that it is not output.", - default="tth", ) p.add_argument( "-f", @@ -72,6 +71,13 @@ def get_args(override_cli_inputs=None): action="store_true", help="Outputs will not overwrite existing file unless --force is specified.", ) + p.add_argument( + "-m", + "--method", + help=f"The method for computing absorption correction. Allowed methods: {*CVE_METHODS, }. " + f"Default method is polynomial interpolation if not specified. ", + default="polynomial_interpolation", + ) p.add_argument( "-u", "--user-metadata", @@ -109,8 +115,8 @@ def main(): for filepath in args.input_paths: outfilestem = filepath.stem + "_corrected" corrfilestem = filepath.stem + "_cve" - outfile = args.output_directory / (outfilestem + ".chi") - corrfile = args.output_directory / (corrfilestem + ".chi") + outfile = args.output_directory / (outfilestem + ".xy") + corrfile = args.output_directory / (corrfilestem + ".xy") if outfile.exists() and not args.force_overwrite: sys.exit( @@ -134,7 +140,7 @@ def main(): metadata=load_metadata(args, filepath), ) - absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength) + absorption_correction = compute_cve(input_pattern, args.mud, args.method) corrected_data = apply_corr(input_pattern, absorption_correction) corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}" corrected_data.dump(f"{outfile}", xtype="tth") diff --git a/src/diffpy/labpdfproc/tests/test_fast_cve.py b/src/diffpy/labpdfproc/tests/test_fast_cve.py deleted file mode 100644 index 374abe4..0000000 --- a/src/diffpy/labpdfproc/tests/test_fast_cve.py +++ 
/dev/null @@ -1,48 +0,0 @@ -import numpy as np - -from diffpy.labpdfproc.fast_cve import apply_fast_corr, fast_compute_cve -from diffpy.utils.scattering_objects.diffraction_objects import Diffraction_object - - -def _instantiate_test_do(xarray, yarray, name="test", scat_quantity="x-ray"): - test_do = Diffraction_object(wavelength=1.54) - test_do.insert_scattering_quantity( - xarray, - yarray, - "tth", - scat_quantity=scat_quantity, - name=name, - metadata={"thing1": 1, "thing2": "thing2"}, - ) - return test_do - - -def test_fast_compute_cve(mocker): - xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) - expected_cve = np.array([0.5, 0.5, 0.5]) - mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - actual_abdo = fast_compute_cve(input_pattern, mud=1, wavelength=1.54) - expected_abdo = _instantiate_test_do( - xarray, - expected_cve, - name="absorption correction, cve, for test", - scat_quantity="cve", - ) - assert actual_abdo == expected_abdo - - -def test_apply_fast_corr(mocker): - xarray, yarray = np.array([90, 90.1, 90.2]), np.array([2, 2, 2]) - expected_cve = np.array([0.5, 0.5, 0.5]) - mocker.patch("numpy.interp", return_value=expected_cve) - input_pattern = _instantiate_test_do(xarray, yarray) - absorption_correction = _instantiate_test_do( - xarray, - expected_cve, - name="absorption correction, cve, for test", - scat_quantity="cve", - ) - actual_corr = apply_fast_corr(input_pattern, absorption_correction) - expected_corr = _instantiate_test_do(xarray, np.array([1, 1, 1])) - assert actual_corr == expected_corr diff --git a/src/diffpy/labpdfproc/tests/test_functions.py b/src/diffpy/labpdfproc/tests/test_functions.py index 75ef019..417a34f 100644 --- a/src/diffpy/labpdfproc/tests/test_functions.py +++ b/src/diffpy/labpdfproc/tests/test_functions.py @@ -75,7 +75,7 @@ def test_compute_cve(mocker): mocker.patch("diffpy.labpdfproc.functions.TTH_GRID", xarray) mocker.patch("numpy.interp", 
return_value=expected_cve) input_pattern = _instantiate_test_do(xarray, yarray) - actual_abdo = compute_cve(input_pattern, mud=1, wavelength=1.54) + actual_abdo = compute_cve(input_pattern, mud=1) expected_abdo = _instantiate_test_do( xarray, expected_cve, diff --git a/src/diffpy/labpdfproc/tests/test_tools.py b/src/diffpy/labpdfproc/tests/test_tools.py index 3b4847f..e17ed42 100644 --- a/src/diffpy/labpdfproc/tests/test_tools.py +++ b/src/diffpy/labpdfproc/tests/test_tools.py @@ -312,6 +312,7 @@ def test_load_metadata(mocker, user_filesystem): "wavelength": 0.71, "output_directory": str(Path.cwd().resolve()), "xtype": "tth", + "method": "polynomial_interpolation", "key": "value", "username": "cli_username", "email": "cli@email.com", diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py index 5ada419..c56aa30 100644 --- a/src/diffpy/labpdfproc/tools.py +++ b/src/diffpy/labpdfproc/tools.py @@ -17,15 +17,15 @@ def set_output_directory(args): args argparse.Namespace the arguments from the parser - Returns - ------- - pathlib.PosixPath that contains the full path of the output directory - it is determined as follows: If user provides an output directory, use it. Otherwise, we set it to the current directory if nothing is provided. We then create the directory if it does not exist. 
+ Returns + ------- + pathlib.PosixPath that contains the full path of the output directory + """ output_dir = Path(args.output_directory).resolve() if args.output_directory else Path.cwd().resolve() output_dir.mkdir(parents=True, exist_ok=True) @@ -110,13 +110,13 @@ def set_wavelength(args): args argparse.Namespace the arguments from the parser + we raise a ValueError if the input wavelength is non-positive + or if the input anode_type is not one of the known sources + Returns ------- args argparse.Namespace - we raise an ValueError if the input wavelength is non-positive - or if the input anode_type is not one of the known sources - """ if args.wavelength is not None and args.wavelength <= 0: raise ValueError(