Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/474.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a fast method `~ndcube.NDCube.axis_world_coords_limits` to find wcs extension in world coordinates.
86 changes: 84 additions & 2 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,56 @@ def wcs_3d_lt_ln_l():
return WCS(header=header)


@pytest.fixture
def wcs_3d_l_ra_dec():
header = {
'CTYPE1': 'WAVE ',
'CUNIT1': 'Angstrom',
'CDELT1': 0.2,
'CRPIX1': 0,
'CRVAL1': 10,

'CTYPE2': 'RA---TAN',
'CUNIT2': 'deg',
'CDELT2': 0.1,
'CRPIX2': 200,
'CRVAL2': 90.5,

'CTYPE3': 'DEC--TAN',
'CUNIT3': 'deg',
'CDELT3': 0.08,
'CRPIX3': 150,
'CRVAL3': 30.4,
}

return WCS(header=header)


@pytest.fixture
def wcs_3d_l_ra_pol():
header = {
'CTYPE1': 'WAVE ',
'CUNIT1': 'Angstrom',
'CDELT1': 0.2,
'CRPIX1': 0,
'CRVAL1': 10,

'CTYPE2': 'RA---TAN',
'CUNIT2': 'deg',
'CDELT2': 0.1,
'CRPIX2': 200,
'CRVAL2': 90.5,

'CTYPE3': 'DEC--TAN',
'CUNIT3': 'deg',
'CDELT3': 0.08,
'CRPIX3': 150,
'CRVAL3': 80.4,
}

return WCS(header=header)


@pytest.fixture
def wcs_2d_lt_ln():
spatial = {
Expand Down Expand Up @@ -437,6 +487,38 @@ def ndcube_3d_ln_lt_l_ec_time(wcs_3d_l_lt_ln, time_and_simple_extra_coords_2d):
return cube


@pytest.fixture
def ndcube_3d_l_ra_dec(wcs_3d_l_ra_dec, simple_extra_coords_3d):
shape = (400, 300, 10)
wcs_3d_l_ra_dec.array_shape = shape
data = data_nd(shape)
mask = data > 0
cube = NDCube(
data,
wcs_3d_l_ra_dec,
mask=mask,
uncertainty=data,
)
cube._extra_coords = simple_extra_coords_3d
return cube


@pytest.fixture
def ndcube_3d_l_ra_pol(wcs_3d_l_ra_pol, simple_extra_coords_3d):
shape = (400, 300, 10)
wcs_3d_l_ra_pol.array_shape = shape
data = data_nd(shape)
mask = data > 0
cube = NDCube(
data,
wcs_3d_l_ra_pol,
mask=mask,
uncertainty=data,
)
cube._extra_coords = simple_extra_coords_3d
return cube


@pytest.fixture
def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
data_rotated = np.array([[[1, 2, 3, 4, 6], [2, 4, 5, 3, 1], [0, -1, 2, 4, 2], [3, 5, 1, 2, 0]],
Expand All @@ -455,8 +537,8 @@ def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
@pytest.fixture
def ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l):
return gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l,
1,
Time('2000-01-01', format='fits', scale='utc'))
1, Time('2000-01-01', format='fits', scale='utc'))



@pytest.fixture
Expand Down
171 changes: 167 additions & 4 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import textwrap
import warnings
from copy import deepcopy
from itertools import product as iterprod
from collections import namedtuple
from collections.abc import Mapping

Expand All @@ -23,6 +24,7 @@
from ndcube.global_coords import GlobalCoords
from ndcube.mixins import NDCubeSlicingMixin
from ndcube.ndcube_sequence import NDCubeSequence
from ndcube.utils.wcs import get_dependent_world_axes
from ndcube.utils.wcs_high_level_conversion import values_to_high_level_objects
from ndcube.visualization import PlotterDescriptor
from ndcube.wcs.wrappers import CompoundLowLevelWCS
Expand Down Expand Up @@ -309,17 +311,37 @@ def array_axis_physical_types(self):
return [tuple(world_axis_physical_types[axis_correlation_matrix[:, i]])
for i in range(axis_correlation_matrix.shape[1])][::-1]

def _generate_world_coords(self, pixel_corners, wcs):
def _generate_world_coords(self, pixel_corners, wcs, get_sizes=False):
# TODO: We can improve this by not always generating all coordinates
# To make our lives easier here we generate all the coordinates for all
# pixels and then choose the ones we want to return to the user based
# on the axes argument. We could be smarter by integrating this logic
# into the main loop, this would potentially reduce the number of calls
# to pixel_to_world_values
"""
Create meshgrid of all pixels transformed to world coordinates.

Parameters
----------
pixel_corners: `bool`
If `True` then instead of returning the coordinates at the centers of the
pixels, the coordinates at the pixel corners will be returned. This increases
the size of the output by 1 in all dimensions as all corners are returned.

wcs : `~astropy.wcs.wcsapi.BaseHighLevelWCS`, `~astropy.wcs.wcsapi.BaseLowLevelWCS`
The WCS to use to convert pixel values to world coordinates.

get_sizes: `bool`, optional
Only calculate sizes of all separate coordinate arrays without creating and
transforming them.

Returns
-------
world_coords: `list`
An iterable of `Quantity` objects representing the real world coordinates in
all axes; or, for `get_sizes=True`, of sizes of all separate coordinate arrays.
"""

# Create meshgrid of all pixel coordinates.
# If user, wants pixel_corners, set pixel values to pixel pixel_corners.
# Else make pixel centers.
pixel_shape = self.data.shape[::-1]
if pixel_corners:
pixel_shape = tuple(np.array(pixel_shape) + 1)
Expand All @@ -334,6 +356,8 @@ def _generate_world_coords(self, pixel_corners, wcs):
if wcs is None:
return []

if get_sizes:
world_sizes = []
world_coords = [None] * wcs.world_n_dim
for (pixel_axes_indices, world_axes_indices) in _split_matrix(wcs.axis_correlation_matrix):
# First construct a range of pixel indices for this set of coupled dimensions
Expand All @@ -343,6 +367,10 @@ def _generate_world_coords(self, pixel_corners, wcs):
# And inject 0s for those coordinates
for idx in non_corr_axes:
sub_range.insert(idx, 0)
# If requested, only calculate and return sizes
if get_sizes:
world_sizes.append(np.prod([np.array([r]).size for r in sub_range]))
continue
# Generate a grid of broadcastable pixel indices for all pixel dimensions
grid = np.meshgrid(*sub_range, indexing='ij')
# Convert to world coordinates
Expand All @@ -358,6 +386,9 @@ def _generate_world_coords(self, pixel_corners, wcs):
tmp_world = world[idx][tuple(array_slice)].T
world_coords[idx] = tmp_world

if get_sizes:
return world_sizes

for i, (coord, unit) in enumerate(zip(world_coords, wcs.world_axis_units)):
world_coords[i] = coord << u.Unit(unit)

Expand Down Expand Up @@ -507,6 +538,132 @@ def axis_world_coords_values(self, *axes, pixel_corners=False, wcs=None):
CoordValues = namedtuple("CoordValues", identifiers)
return CoordValues(*axes_coords[::-1])

@utils.cube.sanitize_wcs
def axis_world_coords_limits(self, *axes, pixel_corners=False, wcs=None, max_size=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial thought is I wonder if it would be better if this starts off as private API? I am not sure of it's utility to the general users?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, at the start I was relatively indifferent whether it should be public or private, just thought some general uses might spring from it as it develops. The speed differences probably won't matter there, but the output can be much more compact than axis_world_coords.

"""
Returns (estimated) extrema of the WCS coordinate values for all axes.

Parameters
----------
axes: `int` or `str`, or multiple `int` or `str`, optional
Axis number in numpy ordering or unique substring of
`~ndcube.NDCube.world_axis_physical_types`
of axes for which real world coordinates are desired.
axes=None implies all axes will be returned.

pixel_corners: `bool`, optional
If `True` then instead of returning the limits for the centers of the pixels,
the limits for the pixel corners will be returned. This increases the resulting
size over the limits in all dimensions as all corner positions are included.

wcs: `astropy.wcs.wcsapi.BaseHighLevelWCS`, optional
The WCS object used to calculate the world coordinates.
Although technically this can be any valid WCS, it will typically be
``self.wcs``, ``self.extra_coords``, or ``self.combined_wcs`` which combines both
the WCS and extra coords.
Defaults to the ``.wcs`` property.

max_size: `int`, optional
Sets the maximum size of the pixel grid for which world coordinates will be
calculated in full to determine the extrema. If this size is exceeded, for
faster evaluation only the corners in pixel space, plus if possible, the axes
at the reference pixel values as given by ``wcs.wcs.crpix``, are considered.

Returns
-------
axes_coords: `list`
An iterable of "high level" objects containing the minima and maxima in
real world coordinates for the axes requested by user.
For example, a tuple of `~astropy.coordinates.SkyCoord` objects.
The types returned are determined by the WCS object.
These objects will have length 2 along each axis.

Example
-------
>>> NDCube.axis_world_coords_limits('lat', 'lon', max_size=100000) # doctest: +SKIP
>>> NDCube.axis_world_coords_limits(2) # doctest: +SKIP

"""
# Cannot use naxis and array_shape to construct bounding box for
# extra_coords or combined_wcs, so for now force using the full wcs for those.
if isinstance(wcs, (ExtraCoords, HighLevelWCSWrapper)) or wcs.array_shape is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are potentially many situations where we would be passed a HighLevelWCSWrapper, not just those which you list in the comment (mainly a wrapped SlicedLowLevelWCS). A more nuanced check would be to extract the low level object first and then dispatch on the type of that instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to get an overview of what could potentially be passed there.

max_size = None
if isinstance(wcs, BaseHighLevelWCS):
wcs = wcs.low_level_wcs

if max_size is not None:
full_size = np.sum(self._generate_world_coords(pixel_corners, wcs, get_sizes=True))

# Check if we only probe pixel bounding box to speed up computation.
if max_size is None or full_size <= max_size:
axes_coords = self._generate_world_coords(pixel_corners, wcs)
else:
if pixel_corners:
lower = np.ones(wcs.naxis) * -0.5
upper = np.array(wcs.array_shape[::-1]) - 0.5
else:
lower = np.zeros(wcs.naxis)
upper = np.array(wcs.array_shape[::-1]) - 1
bbox = np.array(self._bounding_box_to_points(lower, upper, wcs)).T
# If wcs has a FITS-type Wcsprm, try to include CRPIX axes
try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I am following exactly what this block is doing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The one below?
It takes the world coordinate at crpix[i] in the respective dimension, and creates a grid over the other dependent axes (will probably be usually only one) through that point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example in

coords = ndcube_3d_l_ra_pol.axis_world_coords_limits(max_size=10000)
assert u.allclose(coords[0], [1.02e-09, 1.20e-09] * u.m)
assert u.allclose(coords[1].ra, [26.1484, 333.4424] * u.deg)
assert u.allclose(coords[1].dec, [61.8572, 89.98945] * u.deg)
coords = ndcube_3d_l_ra_pol.axis_world_coords_limits(max_size=10)
assert u.allclose(coords[0], [1.02e-09, 1.20e-09] * u.m)
assert u.allclose(coords[1].ra, [26.1484, 333.4424] * u.deg)
assert u.allclose(coords[1].dec, [61.8572, 80.429] * u.deg)

which is working on wcs_3d_l_ra_pol with CRPIX2=200 at CRVAL2=90.5 and CRPIX3 at CRVAL3=80.4, it is spanning a line through all declinations at RA=90.5 deg and all Right Ascensions at Dec=80.4 deg each.
This is covering the point closest to the pole, or very nearly, while the last call, which is skipping this block as it requires more than max_size=10 points, only returns the world coordinates at the bbox corners.
It's of course arguable which is really more useful as limit values...

bbox_l = [bbox]
for ax, pix in enumerate(wcs.wcs.crpix):
sub_range = np.maximum(wcs.wcs.crpix - 1, 0).tolist()
for dwa in get_dependent_world_axes(ax, wcs.axis_correlation_matrix):
if dwa != ax:
sub_range[ax] = np.arange(lower[ax], upper[ax])
# Generate a grid of broadcastable pixel indices for dependent axes
grid = np.meshgrid(*sub_range, indexing='ij')
# Check if size of the subgrid now exceeds limit; in that case cut all
# dependent axes to edge values (perhaps should undersample range instead?)
if grid[0].size > max_size:
for i, r in enumerate(sub_range):
if np.size(r) > 2:
sub_range[i] = r[[0, -1]]
grid = np.meshgrid(*sub_range, indexing='ij')
grid = np.array(grid).squeeze()
if grid.ndim == bbox.ndim:
bbox_l.append(grid)
bbox = np.concatenate(bbox_l, axis=1)
except AttributeError:
pass
# axes_coords = [None] * wcs.world_n_dim
# Convert to world coordinates
axes_coords = list(wcs.pixel_to_world_values(*bbox))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will have to get @DanRyanIrish to check me on this, but I think there is a subset of cases where this _bounding_box_to_points method generates invalid inputs to pixel_to_world.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's from the suggestion in #413 (comment), but the link there must have pointed to an older revision.

for i, (coord, unit) in enumerate(zip(axes_coords, wcs.world_axis_units)):
axes_coords[i] = coord << u.Unit(unit)

axes_limits = []
for ac in axes_coords:
ac = ac[np.isfinite(ac)]
axes_limits.append(u.Quantity([ac.min(), ac.max()]))

if isinstance(wcs, ExtraCoords):
wcs = wcs.wcs

axes_coords = values_to_high_level_objects(*axes_limits, low_level_wcs=wcs)

if not axes:
return tuple(axes_coords)

object_names = np.array([wao_comp[0] for wao_comp in wcs.world_axis_object_components])
unique_obj_names = utils.misc.unique_sorted(object_names)
world_axes_for_obj = [np.where(object_names == name)[0] for name in unique_obj_names]

# Create a mapping from world index in the WCS to object index in axes_coords
world_index_to_object_index = {}
for object_index, world_axes in enumerate(world_axes_for_obj):
for world_index in world_axes:
world_index_to_object_index[world_index] = object_index

world_indices = utils.wcs.calculate_world_indices_from_axes(wcs, axes)
object_indices = utils.misc.unique_sorted(
[world_index_to_object_index[world_index] for world_index in world_indices]
)

return tuple(axes_coords[i] for i in object_indices)
Comment on lines +650 to +665
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this identical to the logic at the end of axis_world_coords? if so should we move it into a shared method?

Copy link
Contributor Author

@dhomeier dhomeier Oct 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much of the code is; I was just hesitant to overload axis_world_coords (like implementing the whole thing as axis_world_coords(limits=True)).


def crop(self, *points, wcs=None):
# The docstring is defined in NDCubeABC
# Calculate the array slice item corresponding to bounding box and return sliced cube.
Expand Down Expand Up @@ -556,6 +713,12 @@ def _get_crop_by_values_item(self, *points, units=None, wcs=None):

return utils.cube.get_crop_item_from_points(points, wcs, True)

def _bounding_box_to_points(self, lower_corner_values, upper_corner_values, wcs):
"""
Convert two corners of a bounding box to the points of all corners.
"""
return tuple(iterprod(*zip(lower_corner_values, upper_corner_values)))

def __str__(self):
return textwrap.dedent(f"""\
NDCube
Expand Down
Loading