Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,10 @@ dmypy.json

# Pyre type checker
.pyre/

# data
*.h5
*.hdf5
notebooks/1
notebooks/2
notebooks/3
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ repos:
hooks:
- id: docformatter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.6
rev: v0.1.13
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
types_or: [python, pyi, jupyter]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies:
Expand Down
37 changes: 37 additions & 0 deletions notebooks/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Artificially modify data for the notebooks
from __future__ import annotations

from imas2xarray import Variable, to_imas, to_xarray

variables = (
Variable(
name='ion_temperature',
ids='core_profiles',
path='profiles_1d/*/ion/*/temperature',
dims=['time', 'ion', '$rho_tor_norm'],
),
't_i_ave',
)

ids = 'core_profiles'
path = '.'

for subdir, k in (
('2', 1.2),
('3', 1.4),
):
dataset = to_xarray(
f'{path}/1/data',
ids=ids,
variables=variables,
)
print(subdir, k)
dataset['t_i_ave'] *= k
dataset['ion_temperature'] *= k

to_imas(
f'{path}/{subdir}/data',
dataset=dataset,
ids=ids,
variables=variables,
)
8 changes: 8 additions & 0 deletions notebooks/prepare.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Small script to initialize a few datasets for the notebooks using the `hdf5_testdata`:
# https://github.com/duqtools/hdf5_testdata/tree/main

git clone ~/python/hdf5_testdata 1
git clone ~/python/hdf5_testdata 2
git clone ~/python/hdf5_testdata 3

python prepare.py
217 changes: 100 additions & 117 deletions notebooks/xarray-2D.ipynb

Large diffs are not rendered by default.

249 changes: 116 additions & 133 deletions notebooks/xarray-ions.ipynb

Large diffs are not rendered by default.

891 changes: 458 additions & 433 deletions notebooks/xarray.ipynb

Large diffs are not rendered by default.

34 changes: 21 additions & 13 deletions src/imas2xarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Collection
from typing import TYPE_CHECKING, Iterable

import h5py
import numpy as np
Expand Down Expand Up @@ -64,7 +64,7 @@ def _var_path_to_hdf5_key_and_slices(path: str) -> tuple[str, tuple[slice | int,

def _mapping_to_xarray(
data_file: h5py.File,
variables: Collection[str | IDSVariableModel],
variables: Iterable[str | IDSVariableModel],
missing_ok: bool = False,
empty_ok: bool = False,
) -> xr.Dataset:
Expand All @@ -74,7 +74,7 @@ def _mapping_to_xarray(
----------
data_file : h5py.File
Open hdf5 file
variables : Collection[str | IDSVariableModel]]
variables : Collection[(str | IDSVariableModel)]]
List of data variables
missing_ok : bool
Ignore missing variables from dataset
Expand Down Expand Up @@ -116,7 +116,7 @@ def _mapping_to_xarray(


def to_xarray(
path: str | Path, *, ids: str, variables: None | Collection[str] = None
path: str | Path, *, ids: str, variables: None | Iterable[str | IDSVariableModel] = None
) -> xr.Dataset:
"""Load IDS from given path to IMAS data into an xarray dataset.

Expand All @@ -128,7 +128,7 @@ def to_xarray(
Path to the data
ids : str
The IDS to load (i.e. 'core_profiles')
variables : None | list[str], optional
variables : None | Iterable[str | Variable], optional
List of variables to load. If None, attempt to load
all variables known to `imas2xarray`

Expand All @@ -146,7 +146,11 @@ def to_xarray(


def to_imas(
path: str | Path, dataset: xr.Dataset, *, ids: str, variables: None | Collection[str] = None
path: str | Path,
dataset: xr.Dataset,
*,
ids: str,
variables: None | Iterable[str | IDSVariableModel] = None,
):
"""Write variables in xarray dataset back to IMAS data at given path.

Expand All @@ -160,7 +164,7 @@ def to_imas(
Input dataset
ids : str
The IDS to write to (i.e. 'core_profiles')
variables : Collection[str]
variables : Iterable[str | Variable]
List of variables to write back. If None, attempt to write back
all variables known to `imas2xarray`
"""
Expand Down Expand Up @@ -196,7 +200,7 @@ def get_all_variables(
self,
*,
ids: str,
extra_variables: None | Collection[IDSVariableModel] = None,
extra_variables: None | Iterable[IDSVariableModel] = None,
squash: bool = True,
**kwargs,
) -> xr.Dataset:
Expand All @@ -209,7 +213,7 @@ def get_all_variables(
----------
ids : str
The IDS to write to (i.e. 'core_profiles')
extra_variables : Collection[IDSVariableModel]
extra_variables : Iterable[Variable]
Extra variables to load in addition to the ones known through the config
squash : bool
Squash placeholder variables
Expand All @@ -229,7 +233,7 @@ def get_all_variables(

def get_variables(
self,
variables: Collection[str | IDSVariableModel],
variables: Iterable[str | IDSVariableModel],
*,
ids: str,
squash: bool = True,
Expand All @@ -243,7 +247,7 @@ def get_variables(

Parameters
----------
variables : Collection[Union[str, IDSVariableModel]]
variables : Iterable[str | Variable]
Variable names of the data to load.
ids : str
The IDS to write to (i.e. 'core_profiles')
Expand Down Expand Up @@ -281,7 +285,11 @@ def get_variables(
return ds

def set_variables(
self, dataset: xr.Dataset, *, ids: str, variables: None | Collection[str] = None
self,
dataset: xr.Dataset,
*,
ids: str,
variables: None | Iterable[str | IDSVariableModel] = None,
):
"""Update variables in corresponding ids datafile.

Expand All @@ -292,7 +300,7 @@ def set_variables(
target dataset.
ids : str
IDS to write to.
variables : Collection[str], optional
variables : Iterable[str | Variable], optional
List of data variables to write.
"""
if not variables:
Expand Down
6 changes: 3 additions & 3 deletions src/imas2xarray/_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
from collections import UserDict
from pathlib import Path, PosixPath
from typing import Any, Collection, Hashable
from typing import Any, Hashable, Iterable

from pydantic_yaml import parse_yaml_raw_as

Expand Down Expand Up @@ -80,13 +80,13 @@ def groupby_ids(self) -> dict[Hashable, list[IDSVariableModel]]:
return grouped_ids_vars

def lookup(
self, variables: Collection[(str | IDSVariableModel)], skip_missing: bool = False
self, variables: Iterable[(str | IDSVariableModel)], skip_missing: bool = False
) -> set[IDSVariableModel]:
"""Helper function to look up a bunch of variables.

Parameters
----------
variables : Collection[(str | IDSVariableModel)]
variables : Iterable[(str | IDSVariableModel)]
List of variables to load. If str, look up the variable from the `var_lookup`.
Else, ensure the variable is an `IDSVariableModel`.
skip_missing : bool
Expand Down
6 changes: 4 additions & 2 deletions src/imas2xarray/_rebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)


def rezero_time(ds: xr.Dataset, *, start: int = 0) -> None:
def rezero_time(ds: xr.Dataset, *, start: int = 0, key: str = 'time') -> None:
"""Standardize the time within a dataset by setting the first timestep to
0.

Expand All @@ -21,10 +21,12 @@ def rezero_time(ds: xr.Dataset, *, start: int = 0) -> None:
----------
ds : xr.Dataset
Source dataset
key : str
Name of the time dimension
start : int, optional
Where to start the returned time series
"""
ds['time'] = ds['time'] - ds['time'][0] + start
ds[key] = ds[key] - ds[key][0] + start


def squash_placeholders(ds: xr.Dataset) -> xr.Dataset:
Expand Down