diff --git a/src/mdio/schemas/v1/templates/seismic_3d_prestack_coca.py b/src/mdio/schemas/v1/templates/seismic_3d_prestack_coca.py new file mode 100644 index 000000000..0903fc770 --- /dev/null +++ b/src/mdio/schemas/v1/templates/seismic_3d_prestack_coca.py @@ -0,0 +1,76 @@ +"""Seismic3DPreStackCocaTemplate MDIO v1 dataset templates.""" + +from mdio.schemas.dtype import ScalarType +from mdio.schemas.metadata import UserAttributes +from mdio.schemas.v1.templates.abstract_dataset_template import AbstractDatasetTemplate +from mdio.schemas.v1.units import AllUnits + + +class Seismic3DPreStackCocaTemplate(AbstractDatasetTemplate): + """Seismic Shot pre-stack 3D time or depth Dataset template.""" + + def __init__(self, domain: str): + super().__init__(domain=domain) + + self._coord_dim_names = ["inline", "crossline", "offset", "azimuth"] + self._dim_names = [*self._coord_dim_names, self._trace_domain] + self._coord_names = ["cdp_x", "cdp_y"] + self._var_chunk_shape = [8, 8, 32, 1, 1024] + + @property + def _name(self) -> str: + return f"PreStackCocaGathers3D{self._trace_domain.capitalize()}" + + def _load_dataset_attributes(self) -> UserAttributes: + return UserAttributes( + attributes={ + "surveyDimensionality": "3D", + "ensembleType": "cdp_coca", + "processingStage": "pre-stack", + } + ) + + def _add_coordinates(self) -> None: + # Add dimension coordinates + self._builder.add_coordinate( + "inline", + dimensions=["inline"], + data_type=ScalarType.INT32, + ) + self._builder.add_coordinate( + "crossline", + dimensions=["crossline"], + data_type=ScalarType.INT32, + ) + self._builder.add_coordinate( + "offset", + dimensions=["offset"], + data_type=ScalarType.INT32, + metadata_info=[self._horizontal_coord_unit], + ) + angle_unit = AllUnits(units_v1={"angle": "deg"}) + self._builder.add_coordinate( + "azimuth", + dimensions=["azimuth"], + data_type=ScalarType.FLOAT32, + metadata_info=[angle_unit], + ) + self._builder.add_coordinate( + self.trace_domain, + dimensions=[self.trace_domain], + data_type=ScalarType.INT32, + ) + + # Add non-dimension coordinates + self._builder.add_coordinate( + "cdp_x", + dimensions=["inline", "crossline"], + data_type=ScalarType.FLOAT64, + metadata_info=[self._horizontal_coord_unit], + ) + self._builder.add_coordinate( + "cdp_y", + dimensions=["inline", "crossline"], + data_type=ScalarType.FLOAT64, + metadata_info=[self._horizontal_coord_unit], + ) diff --git a/src/mdio/schemas/v1/templates/template_registry.py b/src/mdio/schemas/v1/templates/template_registry.py index 82b0bb96f..93c572caf 100644 --- a/src/mdio/schemas/v1/templates/template_registry.py +++ b/src/mdio/schemas/v1/templates/template_registry.py @@ -9,6 +9,7 @@ from mdio.schemas.v1.templates.seismic_2d_prestack_shot import Seismic2DPreStackShotTemplate from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate from mdio.schemas.v1.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate +from mdio.schemas.v1.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate from mdio.schemas.v1.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate @@ -66,25 +67,25 @@ def register(self, instance: AbstractDatasetTemplate) -> str: def _register_default_templates(self) -> None: """Register default templates if needed. - This method can be overridden by subclasses to register default templates. + Subclasses can override this method to register default templates. """ + # Post-Stack Data self.register(Seismic2DPostStackTemplate("time")) self.register(Seismic2DPostStackTemplate("depth")) - - self.register(Seismic2DPreStackCDPTemplate("time")) - self.register(Seismic2DPreStackCDPTemplate("depth")) - - self.register(Seismic2DPreStackShotTemplate("time")) - self.register(Seismic2DPreStackShotTemplate("depth")) - self.register(Seismic3DPostStackTemplate("time")) self.register(Seismic3DPostStackTemplate("depth")) + # CDP/CMP Ordered Data + self.register(Seismic2DPreStackCDPTemplate("time")) + self.register(Seismic2DPreStackCDPTemplate("depth")) self.register(Seismic3DPreStackCDPTemplate("time")) self.register(Seismic3DPreStackCDPTemplate("depth")) + self.register(Seismic3DPreStackCocaTemplate("time")) + self.register(Seismic3DPreStackCocaTemplate("depth")) + # Field (shot) data + self.register(Seismic2DPreStackShotTemplate("time")) self.register(Seismic3DPreStackShotTemplate("time")) - self.register(Seismic3DPreStackShotTemplate("depth")) def get(self, template_name: str) -> AbstractDatasetTemplate: """Get a template from the registry by its name. diff --git a/tests/unit/v1/helpers.py b/tests/unit/v1/helpers.py index a13927974..baf8af427 100644 --- a/tests/unit/v1/helpers.py +++ b/tests/unit/v1/helpers.py @@ -56,7 +56,7 @@ def validate_variable( name: str, dims: list[tuple[str, int]], coords: list[str], - dtype: ScalarType, + dtype: ScalarType | StructuredType, ) -> Variable: """Validate existence and the structure of the created variable.""" if isinstance(container, MDIODatasetBuilder): diff --git a/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py b/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py new file mode 100644 index 000000000..089c63730 --- /dev/null +++ b/tests/unit/v1/templates/test_seismic_3d_prestack_coca.py @@ -0,0 +1,168 @@ +"""Unit tests for Seismic3DPreStackCocaTemplate.""" + +from tests.unit.v1.helpers import validate_variable + +from mdio.schemas.chunk_grid import RegularChunkGrid +from mdio.schemas.compressors import Blosc +from mdio.schemas.dtype import ScalarType +from mdio.schemas.dtype import StructuredType +from mdio.schemas.v1.dataset import Dataset +from mdio.schemas.v1.templates.seismic_3d_prestack_coca import Seismic3DPreStackCocaTemplate +from mdio.schemas.v1.units import AllUnits +from mdio.schemas.v1.units import AngleUnitEnum +from mdio.schemas.v1.units import LengthUnitEnum +from mdio.schemas.v1.units import LengthUnitModel +from mdio.schemas.v1.units import TimeUnitEnum +from mdio.schemas.v1.units import TimeUnitModel + +_UNIT_METER = AllUnits(units_v1=LengthUnitModel(length=LengthUnitEnum.METER)) +_UNIT_SECOND = AllUnits(units_v1=TimeUnitModel(time=TimeUnitEnum.SECOND)) + + +def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None: + """Validate the coordinate, headers, trace_mask variables in the dataset.""" + # Verify variables + # 5 dim coords + 2 non-dim coords + 1 data + 1 trace mask + 1 headers = 10 variables + assert len(dataset.variables) == 10 + + # Verify trace headers + validate_variable( + dataset, + name="headers", + dims=[("inline", 256), ("crossline", 256), ("offset", 100), ("azimuth", 6)], + coords=["cdp_x", "cdp_y"], + dtype=headers, + ) + + validate_variable( + dataset, + name="trace_mask", + dims=[("inline", 256), ("crossline", 256), ("offset", 100), ("azimuth", 6)], + coords=["cdp_x", "cdp_y"], + dtype=ScalarType.BOOL, + ) + + # Verify dimension coordinate variables + inline = validate_variable( + dataset, + name="inline", + dims=[("inline", 256)], + coords=["inline"], + dtype=ScalarType.INT32, + ) + assert inline.metadata is None + + crossline = validate_variable( + dataset, + name="crossline", + dims=[("crossline", 256)], + coords=["crossline"], + dtype=ScalarType.INT32, + ) + assert crossline.metadata is None + + offset = validate_variable( + dataset, + name="offset", + dims=[("offset", 100)], + coords=["offset"], + dtype=ScalarType.INT32, + ) + assert offset.metadata.units_v1.length == LengthUnitEnum.METER + + azimuth = validate_variable( + dataset, + name="azimuth", + dims=[("azimuth", 6)], + coords=["azimuth"], + dtype=ScalarType.FLOAT32, + ) + assert azimuth.metadata.units_v1.angle == AngleUnitEnum.DEGREES + + domain = validate_variable( + dataset, + name=domain, + dims=[(domain, 2048)], + coords=[domain], + dtype=ScalarType.INT32, + ) + assert domain.metadata is None + + # Verify non-dimension coordinate variables + cdp_x = validate_variable( + dataset, + name="cdp_x", + dims=[("inline", 256), ("crossline", 256)], + coords=["cdp_x"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_x.metadata.units_v1.length == LengthUnitEnum.METER + + cdp_y = validate_variable( + dataset, + name="cdp_y", + dims=[("inline", 256), ("crossline", 256)], + coords=["cdp_y"], + dtype=ScalarType.FLOAT64, + ) + assert cdp_y.metadata.units_v1.length == LengthUnitEnum.METER + + +class TestSeismic3DPreStackCocaTemplate: + """Unit tests for Seismic3DPreStackCocaTemplate.""" + + def test_configuration_time(self) -> None: + """Unit tests for Seismic3DPreStackCocaTemplate in time domain.""" + t = Seismic3DPreStackCocaTemplate(domain="time") + + # Template attributes + assert t._coord_dim_names == ["inline", "crossline", "offset", "azimuth"] + assert t._dim_names == ["inline", "crossline", "offset", "azimuth", "time"] + assert t._coord_names == ["cdp_x", "cdp_y"] + assert t._var_chunk_shape == [8, 8, 32, 1, 1024] + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == [] + assert t._horizontal_coord_unit is None + + # Verify dataset attributes + attrs = t._load_dataset_attributes() + assert attrs.attributes == { + "surveyDimensionality": "3D", + "ensembleType": "cdp_coca", + "processingStage": "pre-stack", + } + assert t.trace_variable_name == "amplitude" + + def test_build_dataset_time(self, structured_headers: StructuredType) -> None: + """Unit tests for Seismic3DPreStackShotTemplate build in time domain.""" + t = Seismic3DPreStackCocaTemplate(domain="time") + + dataset = t.build_dataset( + "Permian Basin 3D CDP Coca Gathers", + sizes=[256, 256, 100, 6, 2048], + horizontal_coord_unit=_UNIT_METER, + headers=structured_headers, + ) + + assert dataset.metadata.name == "Permian Basin 3D CDP Coca Gathers" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "cdp_coca" + assert dataset.metadata.attributes["processingStage"] == "pre-stack" + + _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + + # Verify seismic variable (prestack shot depth data) + seismic = validate_variable( + dataset, + name="amplitude", + dims=[("inline", 256), ("crossline", 256), ("offset", 100), ("azimuth", 6), ("time", 2048)], + coords=["cdp_x", "cdp_y"], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.algorithm == "zstd" + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == [8, 8, 32, 1, 1024] + assert seismic.metadata.stats_v1 is None diff --git a/tests/unit/v1/templates/test_seismic_templates.py b/tests/unit/v1/templates/test_seismic_templates.py index 1257d0303..a57882a52 100644 --- a/tests/unit/v1/templates/test_seismic_templates.py +++ b/tests/unit/v1/templates/test_seismic_templates.py @@ -10,6 +10,7 @@ from mdio.schemas.v1.templates.seismic_3d_poststack import Seismic3DPostStackTemplate from mdio.schemas.v1.templates.seismic_3d_prestack_cdp import Seismic3DPreStackCDPTemplate from mdio.schemas.v1.templates.seismic_3d_prestack_shot import Seismic3DPreStackShotTemplate +from mdio.schemas.v1.templates.template_registry import TemplateRegistry class TestSeismicTemplates: @@ -48,44 +49,24 @@ def _name(self) -> str: def test_get_name_time(self) -> None: """Test get_name with domain.""" - time_template = Seismic2DPostStackTemplate("time") - dpth_template = Seismic2DPostStackTemplate("depth") + assert Seismic2DPostStackTemplate("time").name == "PostStack2DTime" + assert Seismic2DPostStackTemplate("depth").name == "PostStack2DDepth" - assert time_template.name == "PostStack2DTime" - assert dpth_template.name == "PostStack2DDepth" + assert Seismic3DPostStackTemplate("time").name == "PostStack3DTime" + assert Seismic3DPostStackTemplate("depth").name == "PostStack3DDepth" - time_template = Seismic3DPostStackTemplate("time") - dpth_template = Seismic3DPostStackTemplate("depth") + assert Seismic3DPreStackCDPTemplate("time").name == "PreStackCdpGathers3DTime" + assert Seismic3DPreStackCDPTemplate("depth").name == "PreStackCdpGathers3DDepth" - assert time_template.name == "PostStack3DTime" - assert dpth_template.name == "PostStack3DDepth" - - time_template = Seismic3DPreStackCDPTemplate("time") - dpth_template = Seismic3DPreStackCDPTemplate("depth") - - assert time_template.name == "PreStackCdpGathers3DTime" - assert dpth_template.name == "PreStackCdpGathers3DDepth" - - time_template = Seismic3DPreStackShotTemplate("time") - dpth_template = Seismic3DPreStackShotTemplate("depth") - - assert time_template.name == "PreStackShotGathers3DTime" - assert dpth_template.name == "PreStackShotGathers3DDepth" + assert Seismic3DPreStackShotTemplate("time").name == "PreStackShotGathers3DTime" def test_all_templates_inherit_from_abstract(self) -> None: """Test that all concrete templates inherit from AbstractDatasetTemplate.""" - templates = [ - Seismic2DPostStackTemplate("time"), - Seismic3DPostStackTemplate("time"), - Seismic3DPreStackCDPTemplate("time"), - Seismic3DPreStackShotTemplate("time"), - Seismic2DPostStackTemplate("depth"), - Seismic3DPostStackTemplate("depth"), - Seismic3DPreStackCDPTemplate("depth"), - Seismic3DPreStackShotTemplate("depth"), - ] - - for template in templates: + registry = TemplateRegistry() + template_names = registry.list_all_templates() + + for template_name in template_names: + template = registry.get(template_name) assert isinstance(template, AbstractDatasetTemplate) # That each template has the required properties and methods assert hasattr(template, "name") @@ -95,5 +76,4 @@ def test_all_templates_inherit_from_abstract(self) -> None: assert hasattr(template, "coordinate_names") assert hasattr(template, "build_dataset") - names = [template.name for template in templates] - assert len(names) == len(set(names)), f"Duplicate template names found: {names}" + assert len(template_names) == len(set(template_names)), f"Duplicate template names found: {template_names}" diff --git a/tests/unit/v1/templates/test_template_registry.py b/tests/unit/v1/templates/test_template_registry.py index 7ae3e0af9..015df7e64 100644 --- a/tests/unit/v1/templates/test_template_registry.py +++ b/tests/unit/v1/templates/test_template_registry.py @@ -13,6 +13,21 @@ from mdio.schemas.v1.templates.template_registry import list_templates from mdio.schemas.v1.templates.template_registry import register_template +EXPECTED_DEFAULT_TEMPLATE_NAMES = [ + "PostStack2DTime", + "PostStack2DDepth", + "PostStack3DTime", + "PostStack3DDepth", + "PreStackCdpGathers2DTime", + "PreStackCdpGathers2DDepth", + "PreStackCdpGathers3DTime", + "PreStackCdpGathers3DDepth", + "PreStackCocaGathers3DTime", + "PreStackCocaGathers3DDepth", + "PreStackShotGathers2DTime", + "PreStackShotGathers3DTime", +] + class MockDatasetTemplate(AbstractDatasetTemplate): """Mock template for testing.""" @@ -37,21 +52,10 @@ def create_dataset(self) -> str: return f"Mock dataset created by {self.template_name}" -def _assert_default_templates(templates: list[str]) -> None: - assert len(templates) == 12 - assert "PostStack2DTime" in templates - assert "PostStack3DTime" in templates - assert "PreStackCdpGathers2DTime" in templates - assert "PreStackShotGathers2DTime" in templates - assert "PreStackCdpGathers3DTime" in templates - assert "PreStackShotGathers3DTime" in templates - - assert "PostStack2DDepth" in templates - assert "PostStack3DDepth" in templates - assert "PreStackCdpGathers3DTime" in templates - assert "PreStackShotGathers3DTime" in templates - assert "PreStackCdpGathers3DDepth" in templates - assert "PreStackShotGathers3DDepth" in templates +def _assert_default_templates(template_names: list[str]) -> None: + assert len(template_names) == len(EXPECTED_DEFAULT_TEMPLATE_NAMES) + for name in EXPECTED_DEFAULT_TEMPLATE_NAMES: + assert name in template_names class TestTemplateRegistrySingleton: @@ -208,18 +212,8 @@ def test_clear_templates(self) -> None: assert not registry.is_registered("Template1") assert not registry.is_registered("Template2") # default templates are also cleared - assert not registry.is_registered("PostStack2DTime") - assert not registry.is_registered("PostStack3DTime") - assert not registry.is_registered("PreStackCdpGathers2DTime") - assert not registry.is_registered("PreStackShotGathers2DTime") - assert not registry.is_registered("PreStackCdpGathers3DTime") - assert not registry.is_registered("PreStackShotGathers3DTime") - assert not registry.is_registered("PostStack2DDepth") - assert not registry.is_registered("PostStack3DDepth") - assert not registry.is_registered("PreStackCdpGathers2DDepth") - assert not registry.is_registered("PreStackShotGathers2DDepth") - assert not registry.is_registered("PreStackCdpGathers3DDepth") - assert not registry.is_registered("PreStackShotGathers3DDepth") + for template_name in EXPECTED_DEFAULT_TEMPLATE_NAMES: + assert not registry.is_registered(template_name) def test_reset_instance(self) -> None: """Test resetting the singleton instance.""" @@ -237,19 +231,9 @@ def test_reset_instance(self) -> None: assert not registry2.is_registered("test") # default templates are registered - assert len(registry2.list_all_templates()) == 12 - assert registry2.is_registered("PostStack2DTime") - assert registry2.is_registered("PostStack3DTime") - assert registry2.is_registered("PreStackCdpGathers2DTime") - assert registry2.is_registered("PreStackShotGathers2DTime") - assert registry2.is_registered("PreStackCdpGathers3DTime") - assert registry2.is_registered("PreStackShotGathers3DTime") - assert registry2.is_registered("PostStack2DDepth") - assert registry2.is_registered("PostStack3DDepth") - assert registry2.is_registered("PreStackCdpGathers2DDepth") - assert registry2.is_registered("PreStackShotGathers2DDepth") - assert registry2.is_registered("PreStackCdpGathers3DDepth") - assert registry2.is_registered("PreStackShotGathers3DDepth") + assert len(registry2.list_all_templates()) == len(EXPECTED_DEFAULT_TEMPLATE_NAMES) + for template_name in EXPECTED_DEFAULT_TEMPLATE_NAMES: + assert registry2.is_registered(template_name) class TestGlobalFunctions: