diff --git a/pyproject.toml b/pyproject.toml index 4a0dd004b..7d5a7e0bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,10 +26,10 @@ dependencies = [ "psutil>=7.0.0", "pydantic>=2.11.9", "rich>=14.1.0", - "segy>=0.5.0", + "segy>=0.5.1.post1", "tqdm>=4.67.1", "universal-pathlib>=0.2.6", - "xarray>=2025.9.0", + "xarray>=2025.9.1", "zarr>=3.1.3", ] diff --git a/src/mdio/api/io.py b/src/mdio/api/io.py index 862c66edd..2654be315 100644 --- a/src/mdio/api/io.py +++ b/src/mdio/api/io.py @@ -10,7 +10,7 @@ from upath import UPath from xarray import Dataset as xr_Dataset from xarray import open_zarr as xr_open_zarr -from xarray.backends.api import to_zarr as xr_to_zarr +from xarray.backends.writers import to_zarr as xr_to_zarr from mdio.constants import ZarrFormat from mdio.core.zarr_io import zarr_warnings_suppress_unstable_structs_v3 diff --git a/src/mdio/segy/_disaster_recovery_wrapper.py b/src/mdio/segy/_disaster_recovery_wrapper.py index ad53a5400..282dfbedd 100644 --- a/src/mdio/segy/_disaster_recovery_wrapper.py +++ b/src/mdio/segy/_disaster_recovery_wrapper.py @@ -4,73 +4,31 @@ from typing import TYPE_CHECKING -from segy.schema import Endianness -from segy.transforms import ByteSwapTransform -from segy.transforms import IbmFloatTransform if TYPE_CHECKING: from numpy.typing import NDArray from segy import SegyFile - from segy.transforms import Transform - from segy.transforms import TransformPipeline -def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray: - """Reverse a single transform operation.""" - if isinstance(transform, ByteSwapTransform): - # Reverse the endianness conversion - if endianness == Endianness.LITTLE: - return data +class SegyFileTraceDataWrapper: + def __init__(self, segy_file: SegyFile, indices: int | list[int] | NDArray | slice): + self.segy_file = segy_file + self.indices = indices - reverse_transform = ByteSwapTransform(Endianness.BIG) - return reverse_transform.apply(data) + self.idx = self.segy_file.trace.normalize_and_validate_query(self.indices) + self.traces = self.segy_file.trace.fetch(self.idx, raw=True) - # TODO(BrianMichell): #0000 Do we actually need to worry about IBM/IEEE transforms here? - if isinstance(transform, IbmFloatTransform): - # Reverse IBM float conversion - reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee" - reverse_transform = IbmFloatTransform(reverse_direction, transform.keys) - return reverse_transform.apply(data) + self.raw_view = self.traces.view(self.segy_file.spec.trace.dtype) + self.decoded_traces = self.segy_file.accessors.trace_decode_pipeline.apply(self.raw_view.copy()) - # For unknown transforms, return data unchanged - return data + @property + def raw_header(self) -> NDArray: + return self.raw_view.header.view("|V240") + @property + def header(self) -> NDArray: + return self.decoded_traces.header -def get_header_raw_and_transformed( - segy_file: SegyFile, indices: int | list[int] | NDArray | slice, do_reverse_transforms: bool = True -) -> tuple[NDArray | None, NDArray, NDArray]: - """Get both raw and transformed header data. - - Args: - segy_file: The SegyFile instance - indices: Which headers to retrieve - do_reverse_transforms: Whether to apply the reverse transform to get raw data - - Returns: - Tuple of (raw_headers, transformed_headers, traces) - """ - traces = segy_file.trace[indices] - transformed_headers = traces.header - - # Reverse transforms to get raw data - if do_reverse_transforms: - raw_headers = _reverse_transforms( - transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness - ) - else: - raw_headers = None - - return raw_headers, transformed_headers, traces - - -def _reverse_transforms( - transformed_data: NDArray, transform_pipeline: TransformPipeline, endianness: Endianness -) -> NDArray: - """Reverse the transform pipeline to get raw data.""" - raw_data = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data - - # Apply transforms in reverse order - for transform in reversed(transform_pipeline.transforms): - raw_data = _reverse_single_transform(raw_data, transform, endianness) - - return raw_data + @property + def sample(self) -> NDArray: + return self.decoded_traces.sample diff --git a/src/mdio/segy/_workers.py b/src/mdio/segy/_workers.py index 56a461912..184aa488a 100644 --- a/src/mdio/segy/_workers.py +++ b/src/mdio/segy/_workers.py @@ -12,7 +12,7 @@ from mdio.api.io import to_mdio from mdio.builder.schemas.dtype import ScalarType -from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed +from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper if TYPE_CHECKING: from segy.arrays import HeaderArray @@ -126,28 +126,39 @@ def trace_worker( # noqa: PLR0913 header_key = "headers" raw_header_key = "raw_headers" - # Used to disable the reverse transforms if we aren't going to write the raw headers - do_reverse_transforms = False - # Get subset of the dataset that has not yet been saved # The headers might not be present in the dataset worker_variables = [data_variable_name] if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations worker_variables.append(header_key) if raw_header_key in dataset.data_vars: - do_reverse_transforms = True worker_variables.append(raw_header_key) - raw_headers, transformed_headers, traces = get_header_raw_and_transformed( - segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms - ) + # traces = segy_file.trace[live_trace_indexes] + # Raw headers are not intended to remain as a feature of the SEGY ingestion. + # For that reason, we have wrapped the accessors to provide an interface that can be removed + # and not require additional changes to the below code. + # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code. + traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes) + ds_to_write = dataset[worker_variables] + if raw_header_key in worker_variables: + tmp_raw_headers = np.zeros_like(dataset[raw_header_key]) + tmp_raw_headers[not_null] = traces.raw_header + + ds_to_write[raw_header_key] = Variable( + ds_to_write[raw_header_key].dims, + tmp_raw_headers, + attrs=ds_to_write[raw_header_key].attrs, + encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it. + ) + if header_key in worker_variables: # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code # https://github.com/TGSAI/mdio-python/issues/584 tmp_headers = np.zeros_like(dataset[header_key]) - tmp_headers[not_null] = transformed_headers + tmp_headers[not_null] = traces.header # Create a new Variable object to avoid copying the temporary array # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers` # but Xarray appears to be copying memory instead of doing direct assignment. @@ -159,19 +170,7 @@ def trace_worker( # noqa: PLR0913 attrs=ds_to_write[header_key].attrs, encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it. ) - del transformed_headers # Manage memory - if raw_header_key in worker_variables: - tmp_raw_headers = np.zeros_like(dataset[raw_header_key]) - tmp_raw_headers[not_null] = raw_headers.view("|V240") - - ds_to_write[raw_header_key] = Variable( - ds_to_write[raw_header_key].dims, - tmp_raw_headers, - attrs=ds_to_write[raw_header_key].attrs, - encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it. - ) - del raw_headers # Manage memory data_variable = ds_to_write[data_variable_name] fill_value = _get_fill_value(ScalarType(data_variable.dtype.name)) tmp_samples = np.full_like(data_variable, fill_value=fill_value) diff --git a/src/mdio/segy/creation.py b/src/mdio/segy/creation.py index 8b10ad486..4c5e25a98 100644 --- a/src/mdio/segy/creation.py +++ b/src/mdio/segy/creation.py @@ -28,6 +28,38 @@ logger = logging.getLogger(__name__) +def _filter_raw_unspecified_fields(headers: NDArray) -> NDArray: + """Filter out __MDIO_RAW_UNSPECIFIED_Field_* fields from headers array. + + These fields are added during SEGY import to preserve raw header bytes, + but they cause dtype mismatches during export. This function removes them. + + Args: + headers: Header array that may contain raw unspecified fields. + + Returns: + Header array with raw unspecified fields removed. + """ + if headers.dtype.names is None: + return headers + + # Find field names that don't start with __MDIO_RAW_UNSPECIFIED_ + field_names = [name for name in headers.dtype.names if not name.startswith("__MDIO_RAW_UNSPECIFIED_")] + + if len(field_names) == len(headers.dtype.names): + # No raw unspecified fields found, return as-is + return headers + + # Create new structured array with only the non-raw fields + new_dtype = [(name, headers.dtype.fields[name][0]) for name in field_names] + filtered_headers = np.empty(headers.shape, dtype=new_dtype) + + for name in field_names: + filtered_headers[name] = headers[name] + + return filtered_headers + + def make_segy_factory(spec: SegySpec, binary_header: dict[str, int]) -> SegyFactory: """Generate SEG-Y factory from MDIO metadata.""" sample_interval = binary_header["sample_interval"] @@ -167,7 +199,9 @@ def serialize_to_segy_stack( # noqa: PLR0913 samples = samples[live_mask] headers = headers[live_mask] - buffer = segy_factory.create_traces(headers, samples) + # Filter out raw unspecified fields that cause dtype mismatches + filtered_headers = _filter_raw_unspecified_fields(headers) + buffer = segy_factory.create_traces(filtered_headers, samples) global_index = block_start[0] record_id_str = str(global_index) @@ -199,7 +233,9 @@ def serialize_to_segy_stack( # noqa: PLR0913 rec_samples = samples[rec_index][rec_live_mask] rec_headers = headers[rec_index][rec_live_mask] - buffer = segy_factory.create_traces(rec_headers, rec_samples) + # Filter out raw unspecified fields that cause dtype mismatches + filtered_headers = _filter_raw_unspecified_fields(rec_headers) + buffer = segy_factory.create_traces(filtered_headers, rec_samples) global_index = tuple(block_start[i] + rec_index[i] for i in range(record_ndim)) record_id_str = "/".join(map(str, global_index)) diff --git a/tests/integration/test_segy_import_export_masked.py b/tests/integration/test_segy_import_export_masked.py index 2094ca4fc..fcd0df467 100644 --- a/tests/integration/test_segy_import_export_masked.py +++ b/tests/integration/test_segy_import_export_masked.py @@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid @pytest.fixture -def export_masked_path(tmp_path_factory: pytest.TempPathFactory) -> Path: +def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path: # noqa: ARG001 """Fixture that generates temp directory for export tests.""" + # Create path suffix based on current raw headers environment variable + # raw_headers_env dependency ensures the environment variable is set before this runs + raw_headers_enabled = os.getenv("MDIO__DO_RAW_HEADERS") == "1" + path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers" + if DEBUG_MODE: - return Path("TMP/export_masked") - return tmp_path_factory.getbasetemp() / "export_masked" + return Path(f"TMP/export_masked_{path_suffix}") + return tmp_path_factory.getbasetemp() / f"export_masked_{path_suffix}" @pytest.fixture @@ -300,9 +305,39 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None: yield - # Cleanup after test + # Cleanup after test - both environment variable and template state os.environ.pop("MDIO__DO_RAW_HEADERS", None) + # Clean up any template modifications to ensure test isolation + registry = TemplateRegistry.get_instance() + + # Reset any templates that might have been modified with raw headers + template_names = [ + "PostStack2DTime", + "PostStack3DTime", + "PreStackCdpOffsetGathers2DTime", + "PreStackCdpOffsetGathers3DTime", + "PreStackShotGathers2DTime", + "PreStackShotGathers3DTime", + "PreStackCocaGathers3DTime", + ] + + for template_name in template_names: + try: + template = registry.get(template_name) + # Remove raw headers enhancement if present + if hasattr(template, "_mdio_raw_headers_enhanced"): + delattr(template, "_mdio_raw_headers_enhanced") + # The enhancement is applied by monkey-patching _add_variables + # We need to restore it to the original method from the class + # Since we can't easily restore the exact original, we'll get a fresh instance + template_class = type(template) + if hasattr(template_class, "_add_variables"): + template._add_variables = template_class._add_variables.__get__(template, template_class) + except KeyError: + # Template not found, skip + continue + @pytest.mark.parametrize( "test_conf", @@ -471,3 +506,69 @@ def test_export_masked( # https://github.com/TGSAI/mdio-python/issues/610 assert_array_equal(actual_sgy.trace[:].header, expected_sgy.trace[expected_trc_idx].header) assert_array_equal(actual_sgy.trace[:].sample, expected_sgy.trace[expected_trc_idx].sample) + + def test_raw_headers_byte_preservation( + self, + test_conf: MaskedExportConfig, + export_masked_path: Path, + raw_headers_env: None, # noqa: ARG002 + ) -> None: + """Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1.""" + grid_conf, segy_factory_conf, _, _ = test_conf + segy_path = export_masked_path / f"{grid_conf.name}.sgy" + mdio_path = export_masked_path / f"{grid_conf.name}.mdio" + + # Open MDIO dataset + ds = open_mdio(mdio_path) + + # Check if raw_headers should exist based on environment variable + has_raw_headers = "raw_headers" in ds.data_vars + if os.getenv("MDIO__DO_RAW_HEADERS") == "1": + assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1" + else: + assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n {ds}" + return # Exit early if raw_headers are not expected + + # Get data (only if raw_headers exist) + raw_headers_data = ds.raw_headers.values + trace_mask = ds.trace_mask.values + + # Verify 240-byte headers + assert raw_headers_data.dtype.itemsize == 240, ( + f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}" + ) + + # Read raw bytes directly from SEG-Y file + def read_segy_trace_header(trace_index: int) -> bytes: + """Read 240-byte trace header directly from SEG-Y file.""" + # with open(segy_path, "rb") as f: + with Path.open(segy_path, "rb") as f: + # Skip text (3200) + binary (400) headers = 3600 bytes + f.seek(3600) + # Each trace: 240 byte header + (num_samples * 4) byte samples + trace_size = 240 + (segy_factory_conf.num_samples * 4) + trace_offset = trace_index * trace_size + f.seek(trace_offset, 1) # Seek relative to current position + return f.read(240) + + # Compare all valid traces byte-by-byte + segy_trace_idx = 0 + flat_mask = trace_mask.ravel() + flat_raw_headers = raw_headers_data.ravel() # Flatten to 1D array of 240-byte header records + + for grid_idx in range(flat_mask.size): + if not flat_mask[grid_idx]: + print(f"Skipping trace {grid_idx} because it is masked") + continue + + # Get MDIO header as bytes - convert single header record to bytes + header_record = flat_raw_headers[grid_idx] + mdio_header_bytes = np.frombuffer(header_record.tobytes(), dtype=np.uint8) + + # Get SEG-Y header as raw bytes directly from file + segy_raw_header_bytes = read_segy_trace_header(segy_trace_idx) + segy_header_bytes = np.frombuffer(segy_raw_header_bytes, dtype=np.uint8) + + assert_array_equal(mdio_header_bytes, segy_header_bytes) + + segy_trace_idx += 1 diff --git a/tests/unit/test_disaster_recovery_wrapper.py b/tests/unit/test_disaster_recovery_wrapper.py index 9ed5e045d..4edee6752 100644 --- a/tests/unit/test_disaster_recovery_wrapper.py +++ b/tests/unit/test_disaster_recovery_wrapper.py @@ -9,7 +9,6 @@ import tempfile from pathlib import Path -from typing import TYPE_CHECKING import numpy as np import pytest @@ -20,10 +19,7 @@ from segy.schema import SegySpec from segy.standards import get_segy_standard -from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed - -if TYPE_CHECKING: - from numpy.typing import NDArray +from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper SAMPLES_PER_TRACE = 1501 @@ -118,27 +114,8 @@ def create_test_segy_file( # noqa: PLR0913 return spec - def extract_header_bytes_from_file( - self, segy_path: Path, trace_index: int, byte_start: int, byte_length: int - ) -> NDArray: - """Extract specific bytes from a trace header in the SEGY file.""" - with segy_path.open("rb") as f: - # Skip text header (3200 bytes) + binary header (400 bytes) - header_offset = 3600 - - # Each trace: 240 byte header + samples - trace_size = 240 + SAMPLES_PER_TRACE * 4 # samples * 4 bytes each - trace_offset = header_offset + trace_index * trace_size - - f.seek(trace_offset + byte_start - 1) # SEGY is 1-based - header_bytes = f.read(byte_length) - - return np.frombuffer(header_bytes, dtype=np.uint8) - - def test_header_validation_configurations( - self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict - ) -> None: - """Test header validation with different SEGY configurations.""" + def test_wrapper_basic_functionality(self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict) -> None: + """Test basic functionality of SegyFileTraceDataWrapper.""" config_name = segy_config["name"] endianness = segy_config["endianness"] data_format = segy_config["data_format"] @@ -161,75 +138,33 @@ def test_header_validation_configurations( # Load the SEGY file segy_file = SegyFile(segy_path, spec=spec) - # Test a few traces - test_indices = [0, 3, 7] - - for trace_idx in test_indices: - # Get raw and transformed headers - raw_headers, transformed_headers, traces = get_header_raw_and_transformed( - segy_file=segy_file, indices=trace_idx, do_reverse_transforms=True - ) - - # Extract bytes from disk for inline (bytes 189-192) and crossline (bytes 193-196) - inline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 189, 4) - crossline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 193, 4) - - # Convert raw headers to bytes for comparison - if raw_headers is not None: - # Extract from raw headers - # Note: We need to extract bytes directly from the structured array to preserve endianness - # Getting a scalar and calling .tobytes() loses endianness information - if raw_headers.ndim == 0: - # Single trace case - raw_data_bytes = raw_headers.tobytes() - inline_offset = raw_headers.dtype.fields["inline"][1] - crossline_offset = raw_headers.dtype.fields["crossline"][1] - inline_size = raw_headers.dtype.fields["inline"][0].itemsize - crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize - - raw_inline_bytes = np.frombuffer( - raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8 - ) - raw_crossline_bytes = np.frombuffer( - raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8 - ) - else: - # Multiple traces case - this test uses single trace index, so extract that trace - raw_data_bytes = raw_headers[0:1].tobytes() # Extract first trace - inline_offset = raw_headers.dtype.fields["inline"][1] - crossline_offset = raw_headers.dtype.fields["crossline"][1] - inline_size = raw_headers.dtype.fields["inline"][0].itemsize - crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize - - raw_inline_bytes = np.frombuffer( - raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8 - ) - raw_crossline_bytes = np.frombuffer( - raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8 - ) - - print(f"Transformed headers: {transformed_headers.tobytes()}") - print(f"Raw headers: {raw_headers.tobytes()}") - print(f"Inline bytes disk: {inline_bytes_disk.tobytes()}") - print(f"Crossline bytes disk: {crossline_bytes_disk.tobytes()}") - - # Compare bytes - assert np.array_equal(raw_inline_bytes, inline_bytes_disk), ( - f"Inline bytes mismatch for trace {trace_idx} in {config_name}" - ) - assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), ( - f"Crossline bytes mismatch for trace {trace_idx} in {config_name}" - ) - - def test_header_validation_no_transforms( - self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict - ) -> None: - """Test header validation when transforms are disabled.""" + # Test single trace + trace_idx = 3 + wrapper = SegyFileTraceDataWrapper(segy_file, trace_idx) + + # Test that properties are accessible + assert wrapper.header is not None + assert wrapper.raw_header is not None + assert wrapper.sample is not None + + # Test header properties + transformed_header = wrapper.header + raw_header = wrapper.raw_header + + # Raw header should be bytes (240 bytes per trace header) + assert raw_header.dtype == np.dtype("|V240") + + # Transformed header should have the expected fields + assert "inline" in transformed_header.dtype.names + assert "crossline" in transformed_header.dtype.names + + def test_wrapper_with_multiple_traces(self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict) -> None: + """Test wrapper with multiple traces.""" config_name = segy_config["name"] endianness = segy_config["endianness"] data_format = segy_config["data_format"] - segy_path = temp_dir / f"test_no_transforms_{config_name}.segy" + segy_path = temp_dir / f"test_multiple_{config_name}.segy" # Create test SEGY file num_traces = 5 @@ -247,31 +182,26 @@ def test_header_validation_no_transforms( # Load the SEGY file segy_file = SegyFile(segy_path, spec=spec) - # Get headers without reverse transforms - raw_headers, transformed_headers, traces = get_header_raw_and_transformed( - segy_file=segy_file, - indices=slice(None), # All traces - do_reverse_transforms=False, - ) + # Test with list of indices + trace_indices = [0, 2, 4] + wrapper = SegyFileTraceDataWrapper(segy_file, trace_indices) - # When transforms are disabled, raw_headers should be None - assert raw_headers is None + # Test that properties work with multiple traces + assert wrapper.header is not None + assert wrapper.raw_header is not None + assert wrapper.sample is not None - # Transformed headers should still be available - assert transformed_headers is not None - assert transformed_headers.size == num_traces + # Check that we got the expected number of traces + assert wrapper.header.size == len(trace_indices) + assert wrapper.raw_header.size == len(trace_indices) - def test_multiple_traces_validation(self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict) -> None: - """Test validation with multiple traces at once.""" + def test_wrapper_with_slice_indices(self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict) -> None: + """Test wrapper with slice indices.""" config_name = segy_config["name"] endianness = segy_config["endianness"] data_format = segy_config["data_format"] - print(f"Config name: {config_name}") - print(f"Endianness: {endianness}") - print(f"Data format: {data_format}") - - segy_path = temp_dir / f"test_multiple_traces_{config_name}.segy" + segy_path = temp_dir / f"test_slice_{config_name}.segy" # Create test SEGY file with more traces num_traces = 25 # 5x5 grid @@ -291,70 +221,18 @@ def test_multiple_traces_validation(self, temp_dir: Path, basic_segy_spec: SegyS # Load the SEGY file segy_file = SegyFile(segy_path, spec=spec) - # Get all traces - raw_headers, transformed_headers, traces = get_header_raw_and_transformed( - segy_file=segy_file, - indices=slice(None), # All traces - do_reverse_transforms=True, - ) + # Test with slice + wrapper = SegyFileTraceDataWrapper(segy_file, slice(5, 15)) - first = True - - # Validate each trace - for trace_idx in range(num_traces): - # Extract bytes from disk - inline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 189, 4) - crossline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 193, 4) - - if first: - print(raw_headers.dtype) - print(raw_headers.shape) - first = False - - # Extract from raw headers - # Note: We need to extract bytes directly from the structured array to preserve endianness - # Getting a scalar and calling .tobytes() loses endianness information - if raw_headers.ndim == 0: - # Single trace case - raw_data_bytes = raw_headers.tobytes() - inline_offset = raw_headers.dtype.fields["inline"][1] - crossline_offset = raw_headers.dtype.fields["crossline"][1] - inline_size = raw_headers.dtype.fields["inline"][0].itemsize - crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize - - raw_inline_bytes = np.frombuffer( - raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8 - ) - raw_crossline_bytes = np.frombuffer( - raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8 - ) - else: - # Multiple traces case - raw_data_bytes = raw_headers[trace_idx : trace_idx + 1].tobytes() - inline_offset = raw_headers.dtype.fields["inline"][1] - crossline_offset = raw_headers.dtype.fields["crossline"][1] - inline_size = raw_headers.dtype.fields["inline"][0].itemsize - crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize - - raw_inline_bytes = np.frombuffer( - raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8 - ) - raw_crossline_bytes = np.frombuffer( - raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8 - ) - - print(f"Raw inline bytes: {raw_inline_bytes.tobytes()}") - print(f"Inline bytes disk: {inline_bytes_disk.tobytes()}") - print(f"Raw crossline bytes: {raw_crossline_bytes.tobytes()}") - print(f"Crossline bytes disk: {crossline_bytes_disk.tobytes()}") - - # Compare - assert np.array_equal(raw_inline_bytes, inline_bytes_disk), ( - f"Inline bytes mismatch for trace {trace_idx} in {config_name}" - ) - assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), ( - f"Crossline bytes mismatch for trace {trace_idx} in {config_name}" - ) + # Test that properties work with slice + assert wrapper.header is not None + assert wrapper.raw_header is not None + assert wrapper.sample is not None + + # Check that we got the expected number of traces (10 traces from slice(5, 15)) + expected_count = 10 + assert wrapper.header.size == expected_count + assert wrapper.raw_header.size == expected_count @pytest.mark.parametrize( "trace_indices", @@ -367,7 +245,7 @@ def test_multiple_traces_validation(self, temp_dir: Path, basic_segy_spec: SegyS def test_different_index_types( self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict, trace_indices: int | list[int] | slice ) -> None: - """Test with different types of trace indices.""" + """Test wrapper with different types of trace indices.""" config_name = segy_config["name"] endianness = segy_config["endianness"] data_format = segy_config["data_format"] @@ -390,15 +268,13 @@ def test_different_index_types( # Load the SEGY file segy_file = SegyFile(segy_path, spec=spec) - # Get headers with different index types - raw_headers, transformed_headers, traces = get_header_raw_and_transformed( - segy_file=segy_file, indices=trace_indices, do_reverse_transforms=True - ) + # Create wrapper with different index types + wrapper = SegyFileTraceDataWrapper(segy_file, trace_indices) # Basic validation that we got results - assert raw_headers is not None - assert transformed_headers is not None - assert traces is not None + assert wrapper.header is not None + assert wrapper.raw_header is not None + assert wrapper.sample is not None # Check that the number of results matches expectation if isinstance(trace_indices, int): @@ -410,4 +286,4 @@ def test_different_index_types( else: expected_count = 1 - assert transformed_headers.size == expected_count + assert wrapper.header.size == expected_count diff --git a/uv.lock b/uv.lock index 1de9b17fb..f73fa0d2a 100644 --- a/uv.lock +++ b/uv.lock @@ -1922,10 +1922,10 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.11.9" }, { name = "rich", specifier = ">=14.1.0" }, { name = "s3fs", marker = "extra == 'cloud'", specifier = ">=2025.9.0" }, - { name = "segy", specifier = ">=0.5.0" }, + { name = "segy", specifier = ">=0.5.1.post1" }, { name = "tqdm", specifier = ">=4.67.1" }, { name = "universal-pathlib", specifier = ">=0.2.6" }, - { name = "xarray", specifier = ">=2025.9.0" }, + { name = "xarray", specifier = ">=2025.9.1" }, { name = "zarr", specifier = ">=3.1.3" }, { name = "zfpy", marker = "extra == 'lossy'", specifier = ">=1.0.1" }, ] @@ -3198,7 +3198,7 @@ wheels = [ [[package]] name = "segy" -version = "0.5.0.post1" +version = "0.5.1.post1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fsspec" }, @@ -3210,9 +3210,9 @@ dependencies = [ { name = "rapidfuzz" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/65/c2/aae81f9f9ae43c28c2d6b543719e6f1805d50d9565f3616af9ce29e3fbc0/segy-0.5.0.post1.tar.gz", hash = "sha256:b8c140fb10cfd4807bc6aab46a6f09d98b82c4995e045f568be3bbf6c044aba6", size = 43037, upload-time = "2025-09-15T13:33:42.348Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/c5/c71d4f52eb1587bdeb8401445ac65b08603fb6f77ada46933dec5fbbd6f8/segy-0.5.1.post1.tar.gz", hash = "sha256:655d1b26aa7a698084d190c8b5c7d12802cfbc9627067614606b1d69c5f0f4ae", size = 43354, upload-time = "2025-09-30T20:35:19.879Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/27/f0/b67a8a89dbb331d55e9b37c779c270a48ff09ca83a0055a65a84f33dc100/segy-0.5.0.post1-py3-none-any.whl", hash = "sha256:158661da578147fa5cfbcf335047a2459f86aa5522e1acc4249bb8252d26be55", size = 55408, upload-time = "2025-09-15T13:33:40.571Z" }, + { url = "https://files.pythonhosted.org/packages/71/ff/ee1b5c982ddfb7185fac41b85ce7a8bd2d5604d6129183a63c2a851109d3/segy-0.5.1.post1-py3-none-any.whl", hash = "sha256:6f36a0795c459d77a3d715d7e5b1444be4cb8368720f89111d452be93d1cf7f1", size = 55757, upload-time = "2025-09-30T20:35:18.665Z" }, ] [[package]] @@ -3611,7 +3611,7 @@ wheels = [ [[package]] name = "typer" -version = "0.16.1" +version = "0.19.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3619,9 +3619,9 @@ dependencies = [ { name = "shellingham" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/78/d90f616bf5f88f8710ad067c1f8705bf7618059836ca084e5bb2a0855d75/typer-0.16.1.tar.gz", hash = "sha256:d358c65a464a7a90f338e3bb7ff0c74ac081449e53884b12ba658cbd72990614", size = 102836, upload-time = "2025-08-18T19:18:22.898Z" } +sdist = { url = "https://files.pythonhosted.org/packages/21/ca/950278884e2ca20547ff3eb109478c6baf6b8cf219318e6bc4f666fad8e8/typer-0.19.2.tar.gz", hash = "sha256:9ad824308ded0ad06cc716434705f691d4ee0bfd0fb081839d2e426860e7fdca", size = 104755, upload-time = "2025-09-23T09:47:48.256Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/76/06dbe78f39b2203d2a47d5facc5df5102d0561e2807396471b5f7c5a30a1/typer-0.16.1-py3-none-any.whl", hash = "sha256:90ee01cb02d9b8395ae21ee3368421faf21fa138cb2a541ed369c08cec5237c9", size = 46397, upload-time = "2025-08-18T19:18:21.663Z" }, + { url = "https://files.pythonhosted.org/packages/00/22/35617eee79080a5d071d0f14ad698d325ee6b3bf824fc0467c03b30e7fa8/typer-0.19.2-py3-none-any.whl", hash = "sha256:755e7e19670ffad8283db353267cb81ef252f595aa6834a0d1ca9312d9326cb9", size = 46748, upload-time = "2025-09-23T09:47:46.777Z" }, ] [[package]] @@ -3876,16 +3876,16 @@ wheels = [ [[package]] name = "xarray" -version = "2025.9.0" +version = "2025.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "pandas" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4e/0b/bbb76e05c8e2099baf90e259c29cafe6a525524b1d1da8bfbc39577c043e/xarray-2025.9.0.tar.gz", hash = "sha256:7dd6816fe0062c49c5e9370dd483843bc13e5ed80a47a9ff10baff2b51e070fb", size = 3040318, upload-time = "2025-09-04T04:20:26.296Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/5d/e139112a463336c636d4455494f3227b7f47a2e06ca7571e6b88158ffc06/xarray-2025.9.1.tar.gz", hash = "sha256:f34a27a52c13d1f3cceb7b27276aeec47021558363617dd7ef4f4c8b379011c0", size = 3057322, upload-time = "2025-09-30T05:28:53.084Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/f0/73c24457c941b8b08f7d090853e40f4b2cdde88b5da721f3f28e98df77c9/xarray-2025.9.0-py3-none-any.whl", hash = "sha256:79f0e25fb39571f612526ee998ee5404d8725a1db3951aabffdb287388885df0", size = 1349595, upload-time = "2025-09-04T04:20:24.36Z" }, + { url = "https://files.pythonhosted.org/packages/0e/a7/6eeb32e705d510a672f74135f538ad27f87f3d600845bfd3834ea3a77c7e/xarray-2025.9.1-py3-none-any.whl", hash = "sha256:3e9708db0d7915c784ed6c227d81b398dca4957afe68d119481f8a448fc88c44", size = 1364411, upload-time = "2025-09-30T05:28:51.294Z" }, ] [[package]]