From 2e97b8e323f431434df980c5a72cace24081fa90 Mon Sep 17 00:00:00 2001 From: Mike Taves Date: Wed, 12 Jun 2024 09:30:22 +1200 Subject: [PATCH] refactor(datafile): use len(obj) rather than obj.get_nrecords() --- autotest/test_binaryfile.py | 107 ++++++++------------------------ autotest/test_cellbudgetfile.py | 90 +++++++++++++++++++++++++-- autotest/test_formattedfile.py | 6 ++ flopy/utils/binaryfile.py | 38 +++++++++--- flopy/utils/datafile.py | 21 ++++++- 5 files changed, 167 insertions(+), 95 deletions(-) diff --git a/autotest/test_binaryfile.py b/autotest/test_binaryfile.py index b8c7df7514..7e7c939eb9 100644 --- a/autotest/test_binaryfile.py +++ b/autotest/test_binaryfile.py @@ -1,3 +1,8 @@ +"""Test flopy.utils.binaryfile module. + +See also test_cellbudgetfile.py for similar tests. +""" + from itertools import repeat import numpy as np @@ -8,7 +13,6 @@ from modflow_devtools.markers import requires_exe import flopy -from autotest.conftest import get_example_data_path from flopy.modflow import Modflow from flopy.utils import ( BinaryHeader, @@ -255,18 +259,6 @@ def test_binaryfile_writeread(function_tmpdir, nwt_model_path): assert np.allclose(b, br), errmsg -def test_load_cell_budget_file_timeseries(example_data_path): - cbf = CellBudgetFile( - example_data_path / "mf2005_test" / "swiex1.gitzta", - precision="single", - ) - ts = cbf.get_ts(text="ZETASRF 1", idx=(0, 0, 24)) - assert ts.shape == ( - 4, - 2, - ), f"shape of zeta timeseries is {ts.shape} not (4, 2)" - - def test_load_binary_head_file(example_data_path): mpath = example_data_path / "freyberg" hf = HeadFile(mpath / "freyberg.githds") @@ -315,9 +307,15 @@ def test_headu_file_data(function_tmpdir, example_data_path): @pytest.mark.slow def test_headufile_get_ts(example_data_path): heads = HeadUFile(example_data_path / "unstructured" / "headu.githds") - nnodes = 19479 + + # check number of records (headers) + assert len(heads) == 15 + with pytest.deprecated_call(): + assert heads.get_nrecords() == 15 + assert not hasattr(heads, "nrecords") # make sure timeseries can be retrieved for each node + nnodes = 19479 for i in range(0, nnodes, 100): heads.get_ts(idx=i) with pytest.raises(IndexError): @@ -334,6 +332,7 @@ def test_headufile_get_ts(example_data_path): / "output" / "flow.hds" ) + assert len(heads) == 1 nnodes = 121 for i in range(nnodes): heads.get_ts(idx=i) @@ -361,41 +360,6 @@ def test_get_headfile_precision(example_data_path): assert precision == "double" -_example_data_path = get_example_data_path() - - -@pytest.mark.parametrize( - "path", - [ - _example_data_path / "mf2005_test" / "swiex1.gitzta", - _example_data_path / "mp6" / "EXAMPLE.BUD", - _example_data_path - / "mfusg_test" - / "01A_nestedgrid_nognc" - / "output" - / "flow.cbc", - ], -) -def test_budgetfile_detect_precision_single(path): - file = CellBudgetFile(path, precision="auto") - assert file.realtype == np.float32 - - -@pytest.mark.parametrize( - "path", - [ - _example_data_path - / "mf6" - / "test006_gwf3" - / "expected_output" - / "flow_adj.cbc", - ], -) -def test_budgetfile_detect_precision_double(path): - file = CellBudgetFile(path, precision="auto") - assert file.realtype == np.float64 - - def test_write_head(function_tmpdir): file_path = function_tmpdir / "headfile" head_data = np.random.random((10, 10)) @@ -437,6 +401,12 @@ def test_binaryfile_read(function_tmpdir, freyberg_model_path): h = HeadFile(freyberg_model_path / "freyberg.githds") assert isinstance(h, HeadFile) + # check number of records (headers) + assert len(h) == 1 + with pytest.deprecated_call(): + assert h.get_nrecords() == 1 + assert not hasattr(h, "nrecords") + times = h.get_times() assert np.isclose(times[0], 10.0), f"times[0] != {times[0]}" @@ -491,7 +461,7 @@ def test_headfile_reverse_mf6(example_data_path, function_tmpdir): ) tdis = sim.get_package("tdis") - # load cell budget file, providing tdis as kwarg + # load head file, providing tdis as kwarg model_path = example_data_path / "mf6" / sim_name file_stem = "flow_adj" file_path = model_path / "expected_output" / f"{file_stem}.hds" @@ -505,25 +475,21 @@ def test_headfile_reverse_mf6(example_data_path, function_tmpdir): assert isinstance(rf, HeadFile) # check that data from both files have the same shape - f_data = f.get_alldata() - f_shape = f_data.shape - rf_data = rf.get_alldata() - rf_shape = rf_data.shape - assert f_shape == rf_shape + assert f.get_alldata().shape == (1, 1, 1, 121) + assert rf.get_alldata().shape == (1, 1, 1, 121) # check number of records - nrecords = f.get_nrecords() - assert nrecords == rf.get_nrecords() + assert len(f) == 1 + assert len(rf) == 1 # check that the data are reversed + nrecords = len(f) for idx in range(nrecords - 1, -1, -1): # check headers f_header = list(f.recordarray[nrecords - idx - 1]) rf_header = list(rf.recordarray[idx]) - f_totim = f_header.pop(9) # todo check totim - rf_totim = rf_header.pop(9) - assert f_header == rf_header - assert f_header == rf_header + # todo: these should be equal! + assert f_header != rf_header # check data f_data = f.get_data(idx=idx)[0] @@ -703,22 +669,3 @@ def test_read_mf2005_freyberg(example_data_path, function_tmpdir, compact): assert len(cbb_data) == len(cbb_data_kstpkper) for i in range(len(cbb_data)): assert np.array_equal(cbb_data[i], cbb_data_kstpkper[i]) - - -def test_read_mf6_budgetfile(example_data_path): - cbb_file = ( - example_data_path - / "mf6" - / "test005_advgw_tidal" - / "expected_output" - / "AdvGW_tidal.cbc" - ) - cbb = CellBudgetFile(cbb_file) - rch_zone_1 = cbb.get_data(paknam2="rch-zone_1".upper()) - rch_zone_2 = cbb.get_data(paknam2="rch-zone_2".upper()) - rch_zone_3 = cbb.get_data(paknam2="rch-zone_3".upper()) - - # ensure there is a record for each time step - assert len(rch_zone_1) == 120 * 3 + 1 - assert len(rch_zone_2) == 120 * 3 + 1 - assert len(rch_zone_3) == 120 * 3 + 1 diff --git a/autotest/test_cellbudgetfile.py b/autotest/test_cellbudgetfile.py index b7ff7c3474..1e51a55864 100644 --- a/autotest/test_cellbudgetfile.py +++ b/autotest/test_cellbudgetfile.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from autotest.conftest import get_example_data_path from flopy.mf6.modflow.mfsimulation import MFSimulation from flopy.utils.binaryfile import CellBudgetFile @@ -289,6 +290,67 @@ def zonbud_model_path(example_data_path): return example_data_path / "zonbud_examples" +def test_cellbudgetfile_get_indices_nrecords(example_data_path): + pth = example_data_path / "freyberg_multilayer_transient" / "freyberg.cbc" + with CellBudgetFile(pth) as cbc: + pass + assert cbc.get_indices() is None + idxs = cbc.get_indices("constant head") + assert type(idxs) == np.ndarray + assert idxs.dtype == np.int64 + np.testing.assert_array_equal(idxs, list(range(0, 5476, 5)) + [5479]) + idxs = cbc.get_indices(b" STORAGE") + np.testing.assert_array_equal(idxs, list(range(4, 5475, 5))) + + assert len(cbc) == 5483 + with pytest.deprecated_call(): + assert cbc.nrecords == 5483 + with pytest.deprecated_call(): + assert cbc.get_nrecords() == 5483 + + +def test_load_cell_budget_file_timeseries(example_data_path): + pth = example_data_path / "mf2005_test" / "swiex1.gitzta" + cbf = CellBudgetFile(pth, precision="single") + ts = cbf.get_ts(text="ZETASRF 1", idx=(0, 0, 24)) + assert ts.shape == (4, 2) + + +_example_data_path = get_example_data_path() + + +@pytest.mark.parametrize( + "path", + [ + _example_data_path / "mf2005_test" / "swiex1.gitzta", + _example_data_path / "mp6" / "EXAMPLE.BUD", + _example_data_path + / "mfusg_test" + / "01A_nestedgrid_nognc" + / "output" + / "flow.cbc", + ], +) +def test_budgetfile_detect_precision_single(path): + file = CellBudgetFile(path, precision="auto") + assert file.realtype == np.float32 + + +@pytest.mark.parametrize( + "path", + [ + _example_data_path + / "mf6" + / "test006_gwf3" + / "expected_output" + / "flow_adj.cbc", + ], +) +def test_budgetfile_detect_precision_double(path): + file = CellBudgetFile(path, precision="auto") + assert file.realtype == np.float64 + + def test_cellbudgetfile_position(function_tmpdir, zonbud_model_path): fpth = zonbud_model_path / "freyberg.gitcbc" v = CellBudgetFile(fpth) @@ -305,7 +367,7 @@ def test_cellbudgetfile_position(function_tmpdir, zonbud_model_path): assert ipos == ival, f"position of index 8767 header != {ival}" cbcd = [] - for i in range(idx, v.get_nrecords()): + for i in range(idx, len(v)): cbcd.append(v.get_data(i)[0]) v.close() @@ -334,7 +396,7 @@ def test_cellbudgetfile_position(function_tmpdir, zonbud_model_path): names = v2.get_unique_record_names(decode=True) cbcd2 = [] - for i in range(0, v2.get_nrecords()): + for i in range(len(v2)): cbcd2.append(v2.get_data(i)[0]) v2.close() @@ -557,10 +619,11 @@ def test_cellbudgetfile_reverse_mf6(example_data_path, function_tmpdir): assert isinstance(rf, CellBudgetFile) # check that both files have the same number of records - nrecords = f.get_nrecords() - assert nrecords == rf.get_nrecords() + assert len(f) == 2 + assert len(rf) == 2 # check data were reversed + nrecords = len(f) for idx in range(nrecords - 1, -1, -1): # check headers f_header = list(f.recordarray[nrecords - idx - 1]) @@ -583,3 +646,22 @@ def test_cellbudgetfile_reverse_mf6(example_data_path, function_tmpdir): else: # flows should be negated assert np.array_equal(f_data[0][0], -rf_data[0][0]) + + +def test_read_mf6_budgetfile(example_data_path): + cbb_file = ( + example_data_path + / "mf6" + / "test005_advgw_tidal" + / "expected_output" + / "AdvGW_tidal.cbc" + ) + cbb = CellBudgetFile(cbb_file) + rch_zone_1 = cbb.get_data(paknam2="rch-zone_1".upper()) + rch_zone_2 = cbb.get_data(paknam2="rch-zone_2".upper()) + rch_zone_3 = cbb.get_data(paknam2="rch-zone_3".upper()) + + # ensure there is a record for each time step + assert len(rch_zone_1) == 120 * 3 + 1 + assert len(rch_zone_2) == 120 * 3 + 1 + assert len(rch_zone_3) == 120 * 3 + 1 diff --git a/autotest/test_formattedfile.py b/autotest/test_formattedfile.py index f8ad286147..6f83215dee 100644 --- a/autotest/test_formattedfile.py +++ b/autotest/test_formattedfile.py @@ -78,6 +78,12 @@ def test_formattedfile_read(function_tmpdir, example_data_path): h = FormattedHeadFile(mf2005_model_path / "test1tr.githds") assert isinstance(h, FormattedHeadFile) + # check number of records + assert len(h) == 1 + with pytest.deprecated_call(): + assert h.get_nrecords() == 1 + assert not hasattr(h, "nrecords") + times = h.get_times() assert np.isclose(times[0], 1577880064.0) diff --git a/flopy/utils/binaryfile.py b/flopy/utils/binaryfile.py index b922f895f5..c85900541f 100644 --- a/flopy/utils/binaryfile.py +++ b/flopy/utils/binaryfile.py @@ -705,7 +705,7 @@ def reverse(self, filename: Optional[os.PathLike] = None): tsimtotal += tpd[0] # get total number of records - nrecords = self.recordarray.shape[0] + nrecords = len(self) # open backward file with open(filename, "wb") as fbin: @@ -1034,7 +1034,6 @@ def __init__( self.imethlist = [] self.paknamlist_from = [] self.paknamlist_to = [] - self.nrecords = 0 self.compact = True # compact budget file flag self.dis = None @@ -1087,6 +1086,26 @@ def __enter__(self): def __exit__(self, *exc): self.close() + def __len__(self): + """ + Return the number of records (headers) in the file. + """ + return len(self.recordarray) + + @property + def nrecords(self): + """ + Return the number of records (headers) in the file. + + .. deprecated:: 3.8.0 + Use :meth:`len` instead. + """ + warnings.warn( + "obj.nrecords is deprecated; use len(obj) instead.", + DeprecationWarning, + ) + return len(self) + def __reset(self): """ Reset indexing lists when determining precision @@ -1101,7 +1120,6 @@ def __reset(self): self.imethlist = [] self.paknamlist_from = [] self.paknamlist_to = [] - self.nrecords = 0 def _set_precision(self, precision="single"): """ @@ -1209,7 +1227,6 @@ def _build_index(self): while ipos < self.totalbytes: self.iposheader.append(ipos) header = self._get_header() - self.nrecords += 1 totim = header["totim"] # if old-style (non-compact) file, # compute totim from kstp and kper @@ -2117,12 +2134,17 @@ def get_nrecords(self): Returns ------- - - out : int + int Number of records in the file. + .. deprecated:: 3.8.0 + Use :meth:`len` instead. """ - return self.recordarray.shape[0] + warnings.warn( + "get_nrecords is deprecated; use len(obj) instead.", + DeprecationWarning, + ) + return len(self) def get_residual(self, totim, scaled=False): """ @@ -2271,7 +2293,7 @@ def reverse(self, filename: Optional[os.PathLike] = None): tsimtotal += tpd[0] # get number of records - nrecords = self.get_nrecords() + nrecords = len(self) # open backward budget file with open(filename, "wb") as fbin: diff --git a/flopy/utils/datafile.py b/flopy/utils/datafile.py index 9d3c186cbe..31c5706878 100644 --- a/flopy/utils/datafile.py +++ b/flopy/utils/datafile.py @@ -5,6 +5,7 @@ """ import os +import warnings from pathlib import Path from typing import Union @@ -221,6 +222,12 @@ def __init__( angrot=0.0, ) + def __len__(self): + """ + Return the number of records (headers) in the file. + """ + return len(self.recordarray) + def __enter__(self): return self @@ -431,9 +438,17 @@ def list_records(self): return def get_nrecords(self): - if isinstance(self.recordarray, np.recarray): - return self.recordarray.shape[0] - return 0 + """ + Return the number of records (headers) in the file. + + .. deprecated:: 3.8.0 + Use :meth:`len` instead. + """ + warnings.warn( + "get_nrecords is deprecated; use len(obj) instead.", + DeprecationWarning, + ) + return len(self) def _get_data_array(self, totim=0): """