Merged (changes from all 25 commits)
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -26,10 +26,10 @@ dependencies = [
     "psutil>=7.0.0",
     "pydantic>=2.11.9",
     "rich>=14.1.0",
-    "segy>=0.5.0",
+    "segy>=0.5.1.post1",
    "tqdm>=4.67.1",
     "universal-pathlib>=0.2.6",
-    "xarray>=2025.9.0",
+    "xarray>=2025.9.1",
     "zarr>=3.1.3",
 ]

2 changes: 1 addition & 1 deletion src/mdio/api/io.py
@@ -10,7 +10,7 @@
 from upath import UPath
 from xarray import Dataset as xr_Dataset
 from xarray import open_zarr as xr_open_zarr
-from xarray.backends.api import to_zarr as xr_to_zarr
+from xarray.backends.writers import to_zarr as xr_to_zarr

 from mdio.constants import ZarrFormat
 from mdio.core.zarr_io import zarr_warnings_suppress_unstable_structs_v3
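Note: `to_zarr` moved out of `xarray.backends.api` in recent xarray releases, which is why this import changes together with the `xarray>=2025.9.1` pin above. If a codebase had to tolerate both layouts instead of pinning, a fallback import is one option. A sketch only, not what this PR does:

```python
# Hypothetical compatibility shim; this PR instead pins xarray>=2025.9.1
# and imports from the new location directly.
try:
    from xarray.backends.writers import to_zarr as xr_to_zarr  # newer xarray
except ImportError:
    from xarray.backends.api import to_zarr as xr_to_zarr  # older xarray
```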
76 changes: 17 additions & 59 deletions src/mdio/segy/_disaster_recovery_wrapper.py
@@ -4,73 +4,31 @@

 from typing import TYPE_CHECKING

-from segy.schema import Endianness
-from segy.transforms import ByteSwapTransform
-from segy.transforms import IbmFloatTransform
-
 if TYPE_CHECKING:
     from numpy.typing import NDArray
     from segy import SegyFile
-    from segy.transforms import Transform
-    from segy.transforms import TransformPipeline


-def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
-    """Reverse a single transform operation."""
-    if isinstance(transform, ByteSwapTransform):
-        # Reverse the endianness conversion
-        if endianness == Endianness.LITTLE:
-            return data
+class SegyFileTraceDataWrapper:
+    def __init__(self, segy_file: SegyFile, indices: int | list[int] | NDArray | slice):
+        self.segy_file = segy_file
+        self.indices = indices

-        reverse_transform = ByteSwapTransform(Endianness.BIG)
-        return reverse_transform.apply(data)
+        self.idx = self.segy_file.trace.normalize_and_validate_query(self.indices)
+        self.traces = self.segy_file.trace.fetch(self.idx, raw=True)

-    # TODO(BrianMichell): #0000 Do we actually need to worry about IBM/IEEE transforms here?
-    if isinstance(transform, IbmFloatTransform):
-        # Reverse IBM float conversion
-        reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
-        reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
-        return reverse_transform.apply(data)
+        self.raw_view = self.traces.view(self.segy_file.spec.trace.dtype)
+        self.decoded_traces = self.segy_file.accessors.trace_decode_pipeline.apply(self.raw_view.copy())

-    # For unknown transforms, return data unchanged
-    return data
+    @property
+    def raw_header(self) -> NDArray:
+        return self.raw_view.header.view("|V240")
+
+    @property
+    def header(self) -> NDArray:
+        return self.decoded_traces.header

-def get_header_raw_and_transformed(
-    segy_file: SegyFile, indices: int | list[int] | NDArray | slice, do_reverse_transforms: bool = True
-) -> tuple[NDArray | None, NDArray, NDArray]:
-    """Get both raw and transformed header data.
-
-    Args:
-        segy_file: The SegyFile instance
-        indices: Which headers to retrieve
-        do_reverse_transforms: Whether to apply the reverse transform to get raw data
-
-    Returns:
-        Tuple of (raw_headers, transformed_headers, traces)
-    """
-    traces = segy_file.trace[indices]
-    transformed_headers = traces.header
-
-    # Reverse transforms to get raw data
-    if do_reverse_transforms:
-        raw_headers = _reverse_transforms(
-            transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness
-        )
-    else:
-        raw_headers = None
-
-    return raw_headers, transformed_headers, traces
-
-
-def _reverse_transforms(
-    transformed_data: NDArray, transform_pipeline: TransformPipeline, endianness: Endianness
-) -> NDArray:
-    """Reverse the transform pipeline to get raw data."""
-    raw_data = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data
-
-    # Apply transforms in reverse order
-    for transform in reversed(transform_pipeline.transforms):
-        raw_data = _reverse_single_transform(raw_data, transform, endianness)
-
-    return raw_data
+    @property
+    def sample(self) -> NDArray:
+        return self.decoded_traces.sample
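The wrapper replaces the reverse-transform machinery: instead of undoing the decode pipeline to recover raw bytes, it fetches the traces raw once, keeps an undecoded `|V240` view, and decodes a copy. It mirrors the `.header`/`.sample` surface of `segy_file.trace[...]` so call sites stay unchanged. A minimal usage sketch, assuming `segy_file` is an already-open `segy.SegyFile` and `live_trace_indexes` is bound as in the worker:

```python
from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper

traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes)  # int, list, NDArray, or slice
raw = traces.raw_header  # undecoded on-disk bytes, one 240-byte record per trace
hdr = traces.header      # headers decoded by the trace-decode pipeline
smp = traces.sample      # decoded sample values
```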
41 changes: 20 additions & 21 deletions src/mdio/segy/_workers.py
@@ -12,7 +12,7 @@

 from mdio.api.io import to_mdio
 from mdio.builder.schemas.dtype import ScalarType
-from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed
+from mdio.segy._disaster_recovery_wrapper import SegyFileTraceDataWrapper

 if TYPE_CHECKING:
     from segy.arrays import HeaderArray
@@ -126,28 +126,39 @@ def trace_worker(  # noqa: PLR0913
     header_key = "headers"
     raw_header_key = "raw_headers"

-    # Used to disable the reverse transforms if we aren't going to write the raw headers
-    do_reverse_transforms = False
-
     # Get subset of the dataset that has not yet been saved
     # The headers might not be present in the dataset
     worker_variables = [data_variable_name]
     if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
         worker_variables.append(header_key)
     if raw_header_key in dataset.data_vars:
-        do_reverse_transforms = True
         worker_variables.append(raw_header_key)

-    raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
-        segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms
-    )
+    # traces = segy_file.trace[live_trace_indexes]
+    # Raw headers are not intended to remain as a feature of the SEGY ingestion.
+    # For that reason, we have wrapped the accessors to provide an interface that can be removed
+    # and not require additional changes to the below code.
+    # NOTE: The `raw_header_key` code block should be removed in full as it will become dead code.
+    traces = SegyFileTraceDataWrapper(segy_file, live_trace_indexes)

     ds_to_write = dataset[worker_variables]

+    if raw_header_key in worker_variables:
+        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
+        tmp_raw_headers[not_null] = traces.raw_header
+
+        ds_to_write[raw_header_key] = Variable(
+            ds_to_write[raw_header_key].dims,
+            tmp_raw_headers,
+            attrs=ds_to_write[raw_header_key].attrs,
+            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
+        )
+
     if header_key in worker_variables:
         # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code
         # https://github.com/TGSAI/mdio-python/issues/584
         tmp_headers = np.zeros_like(dataset[header_key])
-        tmp_headers[not_null] = transformed_headers
+        tmp_headers[not_null] = traces.header
         # Create a new Variable object to avoid copying the temporary array
         # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
         # but Xarray appears to be copying memory instead of doing direct assignment.
@@ -159,19 +170,7 @@ def trace_worker(  # noqa: PLR0913
             attrs=ds_to_write[header_key].attrs,
             encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
         )
-    del transformed_headers  # Manage memory
-    if raw_header_key in worker_variables:
-        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
-        tmp_raw_headers[not_null] = raw_headers.view("|V240")
-
-        ds_to_write[raw_header_key] = Variable(
-            ds_to_write[raw_header_key].dims,
-            tmp_raw_headers,
-            attrs=ds_to_write[raw_header_key].attrs,
-            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
-        )
-
-    del raw_headers  # Manage memory
     data_variable = ds_to_write[data_variable_name]
     fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
     tmp_samples = np.full_like(data_variable, fill_value=fill_value)
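The worker's write path uses a fill-then-scatter pattern: allocate a full-grid buffer, scatter the fetched traces into it through the `not_null` live-trace mask, then wrap the buffer in a `Variable`. The pattern in isolation, as a toy sketch with hypothetical shapes rather than MDIO API:

```python
import numpy as np

not_null = np.array([True, False, True, True])      # live-trace mask over the grid
decoded = np.array([11, 22, 33])                    # one record per live trace, in order
tmp = np.zeros_like(not_null, dtype=decoded.dtype)  # full-grid buffer of fill values
tmp[not_null] = decoded                             # scatter live traces into grid order
print(tmp)                                          # [11  0 22 33]
```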
40 changes: 38 additions & 2 deletions src/mdio/segy/creation.py
@@ -28,6 +28,38 @@
 logger = logging.getLogger(__name__)


+def _filter_raw_unspecified_fields(headers: NDArray) -> NDArray:
+    """Filter out __MDIO_RAW_UNSPECIFIED_Field_* fields from headers array.
+
+    These fields are added during SEGY import to preserve raw header bytes,
+    but they cause dtype mismatches during export. This function removes them.
+
+    Args:
+        headers: Header array that may contain raw unspecified fields.
+
+    Returns:
+        Header array with raw unspecified fields removed.
+    """
+    if headers.dtype.names is None:
+        return headers
+
+    # Find field names that don't start with __MDIO_RAW_UNSPECIFIED_
+    field_names = [name for name in headers.dtype.names if not name.startswith("__MDIO_RAW_UNSPECIFIED_")]
+
+    if len(field_names) == len(headers.dtype.names):
+        # No raw unspecified fields found, return as-is
+        return headers
+
+    # Create new structured array with only the non-raw fields
+    new_dtype = [(name, headers.dtype.fields[name][0]) for name in field_names]
+    filtered_headers = np.empty(headers.shape, dtype=new_dtype)
+
+    for name in field_names:
+        filtered_headers[name] = headers[name]
+
+    return filtered_headers
+
+
 def make_segy_factory(spec: SegySpec, binary_header: dict[str, int]) -> SegyFactory:
     """Generate SEG-Y factory from MDIO metadata."""
     sample_interval = binary_header["sample_interval"]
@@ -167,7 +199,9 @@ def serialize_to_segy_stack(  # noqa: PLR0913
         samples = samples[live_mask]
         headers = headers[live_mask]

-        buffer = segy_factory.create_traces(headers, samples)
+        # Filter out raw unspecified fields that cause dtype mismatches
+        filtered_headers = _filter_raw_unspecified_fields(headers)
+        buffer = segy_factory.create_traces(filtered_headers, samples)

         global_index = block_start[0]
         record_id_str = str(global_index)
@@ -199,7 +233,9 @@ def serialize_to_segy_stack(  # noqa: PLR0913
             rec_samples = samples[rec_index][rec_live_mask]
             rec_headers = headers[rec_index][rec_live_mask]

-            buffer = segy_factory.create_traces(rec_headers, rec_samples)
+            # Filter out raw unspecified fields that cause dtype mismatches
+            filtered_headers = _filter_raw_unspecified_fields(rec_headers)
+            buffer = segy_factory.create_traces(filtered_headers, rec_samples)

             global_index = tuple(block_start[i] + rec_index[i] for i in range(record_ndim))
             record_id_str = "/".join(map(str, global_index))
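A quick check of what `_filter_raw_unspecified_fields` does to a structured dtype. This is a hypothetical two-field example, not taken from the test suite:

```python
import numpy as np

from mdio.segy.creation import _filter_raw_unspecified_fields

dtype = np.dtype([("inline", "i4"), ("__MDIO_RAW_UNSPECIFIED_Field_1", "V8")])
headers = np.zeros(3, dtype=dtype)

filtered = _filter_raw_unspecified_fields(headers)
print(filtered.dtype.names)  # ('inline',): the raw placeholder field is dropped
```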
109 changes: 105 additions & 4 deletions tests/integration/test_segy_import_export_masked.py
@@ -282,11 +282,16 @@ def generate_selection_mask(selection_conf: SelectionMaskConfig, grid_conf: Grid


 @pytest.fixture
-def export_masked_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
+def export_masked_path(tmp_path_factory: pytest.TempPathFactory, raw_headers_env: None) -> Path:  # noqa: ARG001
     """Fixture that generates temp directory for export tests."""
+    # Create path suffix based on current raw headers environment variable
+    # raw_headers_env dependency ensures the environment variable is set before this runs
+    raw_headers_enabled = os.getenv("MDIO__DO_RAW_HEADERS") == "1"
+    path_suffix = "with_raw_headers" if raw_headers_enabled else "without_raw_headers"
+
     if DEBUG_MODE:
-        return Path("TMP/export_masked")
-    return tmp_path_factory.getbasetemp() / "export_masked"
+        return Path(f"TMP/export_masked_{path_suffix}")
+    return tmp_path_factory.getbasetemp() / f"export_masked_{path_suffix}"


 @pytest.fixture
@@ -300,9 +305,39 @@ def raw_headers_env(request: pytest.FixtureRequest) -> None:

     yield

-    # Cleanup after test
+    # Cleanup after test - both environment variable and template state
     os.environ.pop("MDIO__DO_RAW_HEADERS", None)

+    # Clean up any template modifications to ensure test isolation
+    registry = TemplateRegistry.get_instance()
+
+    # Reset any templates that might have been modified with raw headers
+    template_names = [
+        "PostStack2DTime",
+        "PostStack3DTime",
+        "PreStackCdpOffsetGathers2DTime",
+        "PreStackCdpOffsetGathers3DTime",
+        "PreStackShotGathers2DTime",
+        "PreStackShotGathers3DTime",
+        "PreStackCocaGathers3DTime",
+    ]
+
+    for template_name in template_names:
+        try:
+            template = registry.get(template_name)
+            # Remove raw headers enhancement if present
+            if hasattr(template, "_mdio_raw_headers_enhanced"):
+                delattr(template, "_mdio_raw_headers_enhanced")
+            # The enhancement is applied by monkey-patching _add_variables
+            # We need to restore it to the original method from the class
+            # Since we can't easily restore the exact original, we'll get a fresh instance
+            template_class = type(template)
+            if hasattr(template_class, "_add_variables"):
+                template._add_variables = template_class._add_variables.__get__(template, template_class)
+        except KeyError:
+            # Template not found, skip
+            continue
+

 @pytest.mark.parametrize(
     "test_conf",
@@ -471,3 +506,69 @@ def test_export_masked(
         # https://github.com/TGSAI/mdio-python/issues/610
         assert_array_equal(actual_sgy.trace[:].header, expected_sgy.trace[expected_trc_idx].header)
         assert_array_equal(actual_sgy.trace[:].sample, expected_sgy.trace[expected_trc_idx].sample)
+
+    def test_raw_headers_byte_preservation(
+        self,
+        test_conf: MaskedExportConfig,
+        export_masked_path: Path,
+        raw_headers_env: None,  # noqa: ARG002
+    ) -> None:
+        """Test that raw headers are preserved byte-for-byte when MDIO__DO_RAW_HEADERS=1."""
+        grid_conf, segy_factory_conf, _, _ = test_conf
+        segy_path = export_masked_path / f"{grid_conf.name}.sgy"
+        mdio_path = export_masked_path / f"{grid_conf.name}.mdio"
+
+        # Open MDIO dataset
+        ds = open_mdio(mdio_path)
+
+        # Check if raw_headers should exist based on environment variable
+        has_raw_headers = "raw_headers" in ds.data_vars
+        if os.getenv("MDIO__DO_RAW_HEADERS") == "1":
+            assert has_raw_headers, "raw_headers should be present when MDIO__DO_RAW_HEADERS=1"
+        else:
+            assert not has_raw_headers, f"raw_headers should not be present when MDIO__DO_RAW_HEADERS is not set\n {ds}"
+            return  # Exit early if raw_headers are not expected
+
+        # Get data (only if raw_headers exist)
+        raw_headers_data = ds.raw_headers.values
+        trace_mask = ds.trace_mask.values
+
+        # Verify 240-byte headers
+        assert raw_headers_data.dtype.itemsize == 240, (
+            f"Expected 240-byte headers, got {raw_headers_data.dtype.itemsize}"
+        )
+
+        # Read raw bytes directly from SEG-Y file
+        def read_segy_trace_header(trace_index: int) -> bytes:
+            """Read 240-byte trace header directly from SEG-Y file."""
+            with Path.open(segy_path, "rb") as f:
+                # Skip text (3200) + binary (400) headers = 3600 bytes
+                f.seek(3600)
+                # Each trace: 240 byte header + (num_samples * 4) byte samples
+                trace_size = 240 + (segy_factory_conf.num_samples * 4)
+                trace_offset = trace_index * trace_size
+                f.seek(trace_offset, 1)  # Seek relative to current position
+                return f.read(240)
+
+        # Compare all valid traces byte-by-byte
+        segy_trace_idx = 0
+        flat_mask = trace_mask.ravel()
+        flat_raw_headers = raw_headers_data.ravel()  # Flatten to 1D array of 240-byte header records
+
+        for grid_idx in range(flat_mask.size):
+            if not flat_mask[grid_idx]:
+                # Masked (dead) traces have no counterpart in the SEG-Y file
+                continue
+
+            # Get MDIO header as bytes - convert single header record to bytes
+            header_record = flat_raw_headers[grid_idx]
+            mdio_header_bytes = np.frombuffer(header_record.tobytes(), dtype=np.uint8)
+
+            # Get SEG-Y header as raw bytes directly from file
+            segy_raw_header_bytes = read_segy_trace_header(segy_trace_idx)
+            segy_header_bytes = np.frombuffer(segy_raw_header_bytes, dtype=np.uint8)
+
+            assert_array_equal(mdio_header_bytes, segy_header_bytes)
+
+            segy_trace_idx += 1
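The offset arithmetic in `read_segy_trace_header` assumes the classic fixed layout the test fixtures generate: a 3200-byte textual header, a 400-byte binary header, then back-to-back traces of 240 header bytes plus 4 bytes per sample. Worked through with a hypothetical 1500-sample trace:

```python
def trace_header_offset(trace_index: int, num_samples: int) -> int:
    """Byte offset of a trace header in a fixed-layout SEG-Y file."""
    trace_size = 240 + num_samples * 4      # header + 4-byte samples
    return 3600 + trace_index * trace_size  # skip text (3200) + binary (400)

print(trace_header_offset(0, 1500))  # 3600
print(trace_header_offset(2, 1500))  # 16080
```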