diff --git a/ruins/apps/weather.py b/ruins/apps/weather.py index 2c0b9b9..56b395e 100644 --- a/ruins/apps/weather.py +++ b/ruins/apps/weather.py @@ -1,3 +1,4 @@ +from typing import List, Callable, Tuple import streamlit as st import xarray as xr # TODO: these references should be moved to DataManager import pandas as pd # TODO: these references should be moved to DataManager @@ -6,26 +7,13 @@ from ruins.plotting import plt_map, kde, yrplot_hm from ruins import components +from ruins.core import build_config, DataManager, Config +from ruins.core.cache import partial_memoize #### # OLD STUFF -# -# TODO: replace with DataManager -def load_alldata(): - weather = xr.load_dataset('data/weather.nc') - climate = xr.load_dataset('data/cordex_coast.nc') - - # WARNING - bug fix for now: - # 'HadGEM2-ES' model runs are problematic and will be removed for now - # The issue is with the timestamp and requires revision of the ESGF reading routines - kys = [s for s in list(climate.keys()) if 'HadGEM2-ES' not in s] #remove all entries of HadGEM2-ES (6 entries) - climate = climate[kys] - - return weather, climate - - def applySDM(wdata, data, meth='rel', cdf_threshold=0.9999999, lower_limit=0.1): '''apply structured distribution mapping to climate data and return unbiased version of dataset''' from sdm import SDM @@ -75,7 +63,15 @@ def climate_indi(ts, indi='Summer days (Tmax ≥ 25°C)'): # TODO: document + signature # TODO: extract plotting -def climate_indices(weather: xr.Dataset, climate: xr.Dataset, stati='coast', cliproj=True): +def climate_indices(dataManager: DataManager, config: Config): + # get data + weather = dataManager['weather'].read() + climate = dataManager['cordex_coast'].read() + + # get the relevant settings + stati = config.get('selected_station', 'coast') + + cindi = ['Ice days (Tmax < 0°C)', 'Frost days (Tmin < 0°C)', 'Summer days (Tmax ≥ 25°C)', 'Hot days (Tmax ≥ 30°C)','Tropic nights (Tmin ≥ 20°C)', 'Rainy days (Precip ≥ 1mm)'] ci_topic = st.selectbox('Select Index:', cindi) @@ -92,12 +88,12 @@ def climate_indices(weather: xr.Dataset, climate: xr.Dataset, stati='coast', cli w1 = weather[stati].sel(vars=vari).to_dataframe().dropna() w1.columns = ['bla', vari] - plt.figure(figsize=(10,2.5)) + fig = plt.figure(figsize=(10,2.5)) wi = climate_indi(w1, ci_topic).astype(int) wi.plot(style='.', color='steelblue', label='Coast weather') wi.rolling(10, center=True).mean().plot(color='steelblue', label='Rolling mean\n(10 years)') - if cliproj: + if config['include_climate']: c1 = climate.sel(vars=vari).to_dataframe() c1 = c1[c1.columns[c1.columns != 'vars']] c2 = applySDM(w1[vari], c1, meth=meth) @@ -127,7 +123,7 @@ def climate_indices(weather: xr.Dataset, climate: xr.Dataset, stati='coast', cli plt.legend(ncol=2) plt.ylabel('Number of days') plt.title(ci_topic) - st.pyplot() + st.pyplot(fig) if ci_topic == 'Ice days (Tmax < 0°C)': st.markdown('''Number of days in one year which persistently remain below 0°C air temperature.''') @@ -144,178 +140,246 @@ def climate_indices(weather: xr.Dataset, climate: xr.Dataset, stati='coast', cli return -def weather_explorer(w_topic: str): - weather, climate = load_alldata() - #weather = load_data('Weather') +def data_select(dataManager: DataManager, config: Config, container=st) -> None: + """Create the user interface to control the data view + """ + # get a station list + weather = dataManager['weather'].read() + station_list = list(weather.keys()) # TODO station names krummhoern, coast, inland, niedersachsen? + selected_station = container.selectbox('Select station/group (see map in sidebar for location):', station_list) - #aspects = ['Annual', 'Monthly', 'Season'] - #w_aspect = st.sidebar.selectbox('Temporal aggegate:', aspects) + # select a temporal aggregation + aggregations = config.get('temporal_aggregations', ['Annual', 'Monthly']) + temp_agg = container.selectbox('Select temporal aggregation:', aggregations) + + # include climate projections + include_climate = container.checkbox('Include climate projections (for coastal region)', value=False) + + # add settings + st.session_state.selected_station = selected_station + st.session_state.temp_agg = temp_agg + st.session_state.include_climate = include_climate + + +@partial_memoize(hash_names=['name', 'station', 'variable', 'time', '_filter']) +def _reduce_weather_data(dataManager: DataManager, name: str, variable: str, time: str, station: str = None, _filter: dict = None) -> pd.DataFrame: + # get weather data + arr: xr.Dataset = dataManager[name].read() + + if _filter is not None: + arr = arr.filter_by_attrs(**_filter) + + if station is None: + base = arr + else: + base = arr[station] - #cliproj = st.sidebar.checkbox('add climate projections',False) + # reduce to station and variable + reduced = base.sel(vars=variable).resample(time=time) + if variable == 'Tmax': + df = reduced.max(dim='time').to_dataframe() + elif variable == 'Tmin': + df = reduced.min(dim='time').to_dataframe() + else: + df = reduced.mean(dim='time').to_dataframe() + + if station is None: + return df.loc[:, df.columns != 'vars'] + else: + return df[station] + + +def warming_data_plotter(dataManager: DataManager, config: Config): + weather: xr.Dataset = dataManager['weather'].read() + climate = dataManager['cordex_coast'].read() statios = list(weather.keys()) - stat1 = st.selectbox('Select station/group (see map in sidebar for location):', statios) + stat1 = config['selected_station'] - aspects = ['Annual', 'Monthly'] # , 'Season'] - w_aspect = st.selectbox('Select temporal aggegate:', aspects) + # TODO refactor in data-aggregator and data-plotter for different time frames - cliproj = st.checkbox('add climate projections (for coastal region)',False) - if cliproj: - plt_map(stat1, 'CORDEX') - st.sidebar.markdown( - '''Map with available stations (blue dots) and selected reference station (magenta highlight). The climate model grid is given in orange with the selected references as filled dots.''', - unsafe_allow_html=True) + # ---- + # data-aggregator controls + navi_vars = ['Maximum Air Temperature', 'Mean Air Temperature', 'Minimum Air Temperature'] + navi_var = st.sidebar.radio("Select variable:", options=navi_vars) + if navi_var[:4] == 'Mini': + vari = 'Tmin' + ag = 'min' + elif navi_var[:4] == 'Maxi': + vari = 'Tmax' + ag = 'max' else: - plt_map(stat1) - st.sidebar.markdown( - '''Map with available stations (blue dots) and selected reference station (magenta highlight).''', - unsafe_allow_html=True) + vari = 'T' + ag = 'mean' + + # controls end + # ---- + + # TODO: this produces a slider but also needs some data caching + if config['temp_agg'] == 'Annual': + wdata = _reduce_weather_data(dataManager, name='weather', station=config['selected_station'], variable=vari, time='1Y') + allw = _reduce_weather_data(dataManager, name='weather', variable=vari, time='1Y') + + dataLq = float(np.floor(allw.min().quantile(0.22))) + datamin = float(np.min([dataLq, np.round(allw.min().min(), 1)])) + + if config['include_climate']: + rcps = ['rcp26', 'rcp45', 'rcp85'] + rcp = st.selectbox( + 'RCP (Mean over all projections will be shown. For more details go to section "Climate Projections"):', + rcps) + + data = _reduce_weather_data(dataManager, name='cordex_coast', variable=vari, time='1Y', _filter=dict(RCP=rcp)) + data_ub = applySDM(wdata, data, meth='abs') + + dataUq = float(np.ceil(data_ub.max().quantile(0.76))) + datamax = float(np.max([dataUq, np.round(data_ub.max().max(), 1)])) + else: + dataUq = float(np.ceil(allw.max().quantile(0.76))) + datamax = float(np.max([dataUq,np.round(allw.max().max(), 1)])) + + datarng = st.slider('Adjust data range on x-axis of plot:', min_value=datamin, max_value=datamax, value=(dataLq, dataUq), step=0.1, key='drangew') - if w_topic == 'Warming': - - navi_vars = ['Maximum Air Temperature', 'Mean Air Temperature', 'Minimum Air Temperature'] - navi_var = st.sidebar.radio("Select variable:", options=navi_vars) - if navi_var[:4] == 'Mini': - vari = 'Tmin' - afu = np.min - ag = 'min' - elif navi_var[:4] == 'Maxi': - vari = 'Tmax' - afu = np.max - ag = 'max' + # ------------------- + # start plotting plot + if config['include_climate']: + fig, ax = kde(wdata, data_ub.mean(axis=1), split_ts=3) else: - vari = 'T' - afu = np.mean - ag = 'mean' - - if w_aspect == 'Annual': - wdata = weather[stat1].sel(vars=vari).resample(time='1Y').apply(afu).to_dataframe()[stat1] - wdata = wdata[~np.isnan(wdata)] - allw = weather.sel(vars=vari).resample(time='1Y').apply(afu).to_dataframe().iloc[:, 1:] - - dataLq = float(np.floor(allw.min().quantile(0.22))) - datamin = float(np.min([dataLq, np.round(allw.min().min(), 1)])) - if cliproj: - rcps = ['rcp26', 'rcp45', 'rcp85'] - rcp = st.selectbox( - 'RCP (Mean over all projections will be shown. For more details go to section "Climate Projections"):', - rcps) - - data = climate.filter_by_attrs(RCP=rcp).sel(vars=vari).resample(time='1Y').apply(afu).to_dataframe() - data = data[data.columns[data.columns != 'vars']] - data_ub = applySDM(wdata, data, meth='abs') + fig, ax = kde(wdata, split_ts=3) - dataUq = float(np.ceil(data_ub.max().quantile(0.76))) - datamax = float(np.max([dataUq, np.round(data_ub.max().max(), 1)])) - else: - dataUq = float(np.ceil(allw.max().quantile(0.76))) - datamax = float(np.max([dataUq,np.round(allw.max().max(), 1)])) + ax.set_title(stat1 + ' Annual ' + navi_var) + ax.set_xlabel('T (°C)') + ax.set_xlim(datarng[0],datarng[1]) + st.pyplot(fig) - datarng = st.slider('Adjust data range on x-axis of plot:', min_value=datamin, max_value=datamax, value=(dataLq, dataUq), step=0.1, key='drangew') + sndstat = st.checkbox('Show second station for comparison') - if cliproj: - ax = kde(wdata, data_ub.mean(axis=1), split_ts=3) - else: - ax = kde(wdata, split_ts=3) + if sndstat: + stat2 = st.selectbox('Select second station:', [x for x in statios if x != config['selected_station']]) + wdata2 = _reduce_weather_data(dataManager, name='weather', station=stat2, variable=vari, time='1Y') - ax.set_title(stat1 + ' Annual ' + navi_var) - ax.set_xlabel('T (°C)') - ax.set_xlim(datarng[0],datarng[1]) - st.pyplot() + fig, ax2 = kde(wdata2, split_ts=3) + ax2.set_title(stat2 + ' Annual ' + navi_var) + ax2.set_xlabel('T (°C)') + ax2.set_xlim(datarng[0],datarng[1]) + st.pyplot(fig) - sndstat = st.checkbox('Show second station for comparison') + # Re-implement this as a application wide service + # expl_md = read_markdown_file('explainer/stripes.md') + # st.markdown(expl_md, unsafe_allow_html=True) - if sndstat: - stat2 = st.selectbox('Select second station:', [x for x in statios if x != stat1]) - wdata2 = weather[stat2].sel(vars=vari).resample(time='1Y').apply(afu).to_dataframe()[stat2] + elif config['temp_agg'] == 'Monthly': + wdata = _reduce_weather_data(dataManager, name='weather', station=config['selected_station'], variable=vari, time='1M') - ax2 = kde(wdata2, split_ts=3) - ax2.set_title(stat2 + ' Annual ' + navi_var) - ax2.set_xlabel('T (°C)') - ax2.set_xlim(datarng[0],datarng[1]) - st.pyplot() + ref_yr = st.slider('Reference period for anomaly calculation:', min_value=int(wdata.index.year.min()), max_value=2020,value=(max(1980, int(wdata.index.year.min())), 2000)) - # Re-implement this as a application wide service - # expl_md = read_markdown_file('explainer/stripes.md') - # st.markdown(expl_md, unsafe_allow_html=True) + if config['include_climate']: + rcps = ['rcp26', 'rcp45', 'rcp85'] + rcp = st.selectbox('RCP (Mean over all projections will be shown. For more details go to section "Climate Projections"):', rcps) - elif w_aspect == 'Monthly': - wdata = weather[stat1].sel(vars=vari).resample(time='1M').apply(afu).to_dataframe()[stat1] - wdata = wdata[~np.isnan(wdata)] - ref_yr = st.slider('Reference period for anomaly calculation:', min_value=int(wdata.index.year.min()), max_value=2020,value=(max(1980, int(wdata.index.year.min())), 2000)) + data = _reduce_weather_data(dataManager, name='cordex_coast', variable=vari, time='1M', _filter=dict(RCP=rcp)) - if cliproj: - rcps = ['rcp26', 'rcp45', 'rcp85'] - rcp = st.selectbox('RCP (Mean over all projections will be shown. For more details go to section "Climate Projections"):', rcps) + #ub = st.sidebar.checkbox('Apply SDM bias correction',True) + ub = True # simplify here and automatically apply bias correction - data = climate.filter_by_attrs(RCP=rcp).sel(vars=vari).resample(time='1M').apply(afu).to_dataframe() - data = data[data.columns[data.columns != 'vars']] + if ub: + data_ub = applySDM(wdata, data, meth='abs') + fig = yrplot_hm(pd.concat([wdata.loc[wdata.index[0]:data.index[0] - pd.Timedelta('1M')], data_ub.mean(axis=1)]),ref_yr, ag, li=2006) + else: + fig = yrplot_hm(pd.concat([wdata.loc[wdata.index[0]:data.index[0] - pd.Timedelta('1M')], data.mean(axis=1)]), ref_yr, ag, li=2006) + + plt.title(stat1 + ' ' + navi_var + ' anomaly to ' + str(ref_yr[0]) + '-' + str(ref_yr[1])) + st.pyplot(fig) - #ub = st.sidebar.checkbox('Apply SDM bias correction',True) - ub = True # simplify here and automatically apply bias correction - if ub: - data_ub = applySDM(wdata, data, meth='abs') - yrplot_hm(pd.concat([wdata.loc[wdata.index[0]:data.index[0] - pd.Timedelta('1M')], data_ub.mean(axis=1)]),ref_yr, ag, li=2006) - else: - yrplot_hm(pd.concat([wdata.loc[wdata.index[0]:data.index[0] - pd.Timedelta('1M')], data.mean(axis=1)]), ref_yr, ag, li=2006) + # TODO: break up this as well + else: + fig = yrplot_hm(wdata,ref_yr,ag) + plt.title(stat1 + ' ' + navi_var + ' anomaly to ' + str(ref_yr[0]) + '-' + str(ref_yr[1])) + st.pyplot(fig) - plt.title(stat1 + ' ' + navi_var + ' anomaly to ' + str(ref_yr[0]) + '-' + str(ref_yr[1])) - st.pyplot() + sndstat = st.checkbox('Compare to a second station?') + if sndstat: + stat2 = st.selectbox('Select second station:', [x for x in statios if x != stat1]) + data2 = _reduce_weather_data(dataManager, name='weather', station=stat2, variable=vari, time='1M') + + ref_yr2 = list(ref_yr) + if ref_yr2[1]blue dots) and selected reference station (magenta highlight). The climate model grid is given in orange with the selected references as filled dots.''', + unsafe_allow_html=True) + else: + fig = plt_map(dataManager, sel=config['selected_station']) + st.sidebar.plotly_chart(fig) + st.sidebar.markdown( + '''Map with available stations (blue dots) and selected reference station (magenta highlight).''', + unsafe_allow_html=True) + # switch the topic + topic = config['current_topic'] + if topic == 'Warming': + warming_data_plotter(dataManager, config) + + elif topic == 'Weather Indices': + climate_indices(dataManager, config) - else: - yrplot_hm(wdata,ref_yr,ag) - plt.title(stat1 + ' ' + navi_var + ' anomaly to ' + str(ref_yr[0]) + '-' + str(ref_yr[1])) - st.pyplot() - - sndstat = st.checkbox('Compare to a second station?') - - if sndstat: - stat2 = st.selectbox('Select second station:', [x for x in statios if x != stat1]) - data2 = weather[stat2].sel(vars=vari).resample(time='1M').apply(afu).to_dataframe()[stat2] - data2 = data2[~np.isnan(data2)] - - ref_yr2 = list(ref_yr) - if ref_yr2[1] str: + +def topic_selector(config: Config, container=st, config_expander=st) -> str: """ - Select a topic from a list of topics. The selected topic is returned and - additionally published to the session cache. - - Parameters - ---------- - topic_list : List[str] - List of topics to select from. - force_topic_select : bool - If False, the dropdown will not be shown if a topic was already - selected and is present in the streamlit session cache. - Default: True - container : streamlit.st.container - Container to use for the dropdown. Defaults to main streamlit. - **kwargs - These keyword arguments are only accepted to directly inject - :class:`Config ` objects. + TODO: Alex will das dokumentieren.... """ - # check if a topic is already present - if kwargs.get('no_cache', False): - current_topic = kwargs.get('current_topic', None) - else: # pragma: no cover - current_topic = st.session_state.get('topic', kwargs.get('current_topic', None)) - if current_topic is not None and not force_topic_select: - return current_topic + current_topic = config.get('current_topic') + + # get topic list + topic_list = config['topic_list'] + + # get the policy + policy = config.get_control_policy('topic_selector') + + # create the control + if current_topic is not None: + topic = container.selectbox('Select a topic', topic_list) - # otherwise print select - topic = st.selectbox( - 'Select a topic', - topic_list, - index=0 if current_topic is None else topic_list.index(current_topic) - ) - - # store topic in session cache - if current_topic != topic and not kwargs.get('no_cache', False): # pragma: no cover - st.session_state['topic'] = topic + elif policy == 'show': # pragma: no cover + topic = container.selectbox( + 'Select a topic', + topic_list, + #index=topic_list.index(config['current_topic']) + ) - return topic + elif policy == 'hide': # pragma: no cover + topic = config_expander.selectbox( + 'Select a topic', + topic_list, + #index=topic_list.index(config['current_topic']) + ) + + else: + topic = current_topic + + # set the new topic + if current_topic != topic: + st.session_state.current_topic = topic + diff --git a/ruins/core/__init__.py b/ruins/core/__init__.py index d1db6a5..ec72ca5 100644 --- a/ruins/core/__init__.py +++ b/ruins/core/__init__.py @@ -1,2 +1,3 @@ from .config import Config -from .data_manager import DataManager \ No newline at end of file +from .data_manager import DataManager +from .build import build_config \ No newline at end of file diff --git a/ruins/core/build.py b/ruins/core/build.py index 6247315..62eef1d 100644 --- a/ruins/core/build.py +++ b/ruins/core/build.py @@ -2,16 +2,27 @@ Build a :class:`Config ` and a :class:`DataManager ` from a kwargs dict. """ -from types import Union, Tuple +from typing import Union, Tuple, Dict, List +import streamlit as st from .config import Config from .data_manager import DataManager +st.experimental_singleton +def contextualized_data_manager(**kwargs) -> DataManager: + return DataManager(**kwargs) -def build_config(omit_dataManager: bool = False, **kwargs) -> Tuple[Config, Union[None, DataManager]]: + +def build_config(omit_dataManager: bool = False, url_params: Dict[str, List[str]] = {}, **kwargs) -> Tuple[Config, Union[None, DataManager]]: """ """ + # prepare the url params, if any + # url params are always a list: https://docs.streamlit.io/library/api-reference/utilities/st.experimental_get_query_params + # TODO: This should be sanitzed to avoid injection attacks! + ukwargs = {k: v[0] if len(v) == 1 else v for k, v in url_params.items()} + kwargs.update(ukwargs) + # extract the DataManager, if it was already instantiated if 'dataManager' in kwargs: dataManager = kwargs.pop('dataManager') @@ -22,8 +33,8 @@ def build_config(omit_dataManager: bool = False, **kwargs) -> Tuple[Config, Unio config = Config(**kwargs) if omit_dataManager: - return config + return config, None else: if dataManager is None: - dataManager = DataManager(**config) + dataManager = contextualized_data_manager(**config) return config, dataManager diff --git a/ruins/core/cache.py b/ruins/core/cache.py new file mode 100644 index 0000000..50e14d0 --- /dev/null +++ b/ruins/core/cache.py @@ -0,0 +1,33 @@ +from typing import Callable, List +from functools import wraps +import hashlib + + +LOCAL = dict() + +def _hashargs(fname, argnames): + h = f'{fname};' + ','.join(argnames) + digest = hashlib.sha256(h.encode()).hexdigest() + + return digest + +def partial_memoize(hash_names: List[str], store: str = 'local'): + def func_decorator(f: Callable): + @wraps(f) + def wrapper(*args, **kwargs): + argnames = [str(a) for a in args if a in hash_names] + argnames.extend([str(v) for k, v in kwargs.items() if k in hash_names]) + + # get the parameter hash + h = _hashargs(f.__name__, argnames) + + # check if result exists + if h in LOCAL: + return LOCAL.get(h) + else: + # process + result = f(*args, **kwargs) + LOCAL[h] = result + return result + return wrapper + return func_decorator diff --git a/ruins/core/config.py b/ruins/core/config.py index d9d412c..00d6157 100644 --- a/ruins/core/config.py +++ b/ruins/core/config.py @@ -1,8 +1,16 @@ +from streamlit import session_state +import streamlit as st + import os from os.path import join as pjoin import json from collections.abc import Mapping + +# check if streamlit is running +if not st._is_running_with_streamlit: + session_state = dict() + class Config(Mapping): """ Streamlit app Config object. @@ -29,6 +37,7 @@ def __init__(self, path: str = None, **kwargs) -> None: # path self.basepath = os.path.abspath(pjoin(os.path.dirname(__file__), '..', '..')) self.datapath = pjoin(self.basepath, 'data') + self.hot_load = kwargs.get('hot_load', False) # mime readers self.default_sources = { @@ -45,8 +54,14 @@ def __init__(self, path: str = None, **kwargs) -> None: } self.sources_args.update(kwargs.get('include_args', {})) + # app management + self.layout = 'centered' + + # app content + self.topic_list = ['Warming', 'Weather Indices', 'Drought/Flood', 'Agriculture', 'Extreme Events', 'Wind Energy'] + # store the keys - self._keys = ['debug', 'basepath', 'datapath', 'default_sources', 'sources_args'] + self._keys = ['debug', 'basepath', 'datapath', 'hot_load', 'default_sources', 'sources_args', 'layout', 'topic_list'] # check if a path was provided conf_args = self.from_json(path) if path else {} @@ -71,8 +86,35 @@ def _update(self, new_settings: dict) -> None: if k not in self._keys: self._keys.append(k) + def get_control_policy(self, control_name: str) -> str: + """ + Get the control policy for the given control name. + + allowed policies are: + - show: always show the control on the main container + - hide: hide the control on the main container, but move to the expander + - ignore: don't show anything + + """ + if self.has_key(f'{control_name}_policy'): + return self.get(f'{control_name}_policy') + elif self.has_key('controls_policy'): + return self.get('controls_policy') + else: + # TODO: discuss with conrad to change this + return 'show' + + def get(self, key: str, default = None): - return getattr(self, key, default) + if hasattr(self, key): + return getattr(self, key) + elif key in session_state: + return session_state[key] + else: + return default + + def has_key(self, key) -> bool: + return hasattr(self, key) or hasattr(session_state, key) or key in session_state def __len__(self) -> int: return len(self._keys) @@ -82,4 +124,14 @@ def __iter__(self): yield k def __getitem__(self, key: str): - return getattr(self, key) + if hasattr(self, key): + return getattr(self, key) + elif key in session_state: + return session_state[key] + else: + raise KeyError(f"Key {key} not found") + + def __setitem__(self, key: str, value): + setattr(self, key, value) + if key not in self._keys: + self._keys.append(key) diff --git a/ruins/core/data_manager.py b/ruins/core/data_manager.py index 5f95e3f..2c686d4 100644 --- a/ruins/core/data_manager.py +++ b/ruins/core/data_manager.py @@ -67,10 +67,15 @@ class FileSource(DataSource, abc.ABC): Abstract base class for file sources. This provides the common interface for every data source that is based on a file. """ - def __init__(self, path: str, cache: bool = True, **kwargs): + def __init__(self, path: str, cache: bool = True, hot_load = False, **kwargs): super().__init__(**kwargs) self.path = path self.cache = cache + + # check if the dataset should be pre-loaded + if hot_load: + self.cache = True + self.data = self._load_source() @abc.abstractmethod def _load_source(self): @@ -94,8 +99,11 @@ class HDF5Source(FileSource): """ HDF5 file sources. This class is used to load HDF5 files. """ - def _load_source(self): + def _load_source(self) -> xr.Dataset: return xr.open_dataset(self.path) + + def read(self) -> xr.Dataset: + return super(HDF5Source, self).read() class CSVSource(FileSource): @@ -139,7 +147,7 @@ class DataManager(Mapping): The include_mimes can be overwritten by passing filenames directly. """ - def __init__(self, datapath: str = None, cache: bool = True, debug: bool = False, **kwargs) -> None: + def __init__(self, datapath: str = None, cache: bool = True, hot_load = False, debug: bool = False, **kwargs) -> None: """ You can pass in a Config as kwargs. """ @@ -148,9 +156,9 @@ def __init__(self, datapath: str = None, cache: bool = True, debug: bool = False from ruins.core import Config self.from_config(**Config(**kwargs)) else: - self.from_config(datapath=datapath, cache=cache, debug=debug, **kwargs) + self.from_config(datapath=datapath, cache=cache, hot_load=hot_load, debug=debug, **kwargs) - def from_config(self, datapath: str = None, cache: bool = True, debug: bool = False, **kwargs) -> None: + def from_config(self, datapath: str = None, cache: bool = True, hot_load: bool = False, debug: bool = False, **kwargs) -> None: """ Initialize the DataManager from a :class:`Config ` object. """ @@ -158,6 +166,7 @@ def from_config(self, datapath: str = None, cache: bool = True, debug: bool = Fa self._config = kwargs self._datapath = datapath self.cache = cache + self.hot_load = hot_load self.debug = debug # file settings @@ -225,7 +234,7 @@ def add_source(self, path: str, not_exists: str = 'raise') -> None: # add the source # args = self._config.get(basename, {}) - args.update({'path': path, 'cache': self.cache}) + args.update({'path': path, 'cache': self.cache, 'hot_load': self.hot_load}) self._data_sources[basename] = BaseClass(**args) else: if not_exists == 'raise': @@ -255,7 +264,7 @@ def __iter__(self): for name in self._data_sources.keys(): yield name - def __getitem__(self, key: str): + def __getitem__(self, key: str) -> DataSource: """Return the requested datasource""" return self._data_sources[key] diff --git a/ruins/plotting/kde.py b/ruins/plotting/kde.py index 18ad7fb..ff96635 100644 --- a/ruins/plotting/kde.py +++ b/ruins/plotting/kde.py @@ -122,4 +122,4 @@ def kde(data, cmdata='none', split_ts=1, cplot=True, eq_period=True): ax.set_ylabel('Occurrence (KDE)') - return ax + return fig, ax diff --git a/ruins/plotting/maps.py b/ruins/plotting/maps.py index cdbcd1a..9a18df3 100644 --- a/ruins/plotting/maps.py +++ b/ruins/plotting/maps.py @@ -1,23 +1,29 @@ import streamlit as st import plotly.graph_objs as go import plotly.express as px +import pandas as pd import numpy as np +from ruins.core import DataManager +from ruins.core.cache import partial_memoize + + +@partial_memoize(hash_names=['sel', 'cm']) +def plt_map(dataManager: DataManager, sel='all', cm='none') -> go.Figure: + # cordex_grid = xr.open_dataset('data/CORDEXgrid.nc') + # cimp_grid = xr.open_dataset('data/CMIP5grid.nc') + # stats = pd.read_csv('data/stats.csv', index_col=0) + cordex_grid = dataManager['CORDEXgrid'].read() + cimp_grid = dataManager['CMIP5grid'].read() + stats = dataManager['stats'].read() -def plt_map(sel='all',cm='none'): - # TODO remove this part - import xarray as xr - import pandas as pd - dummy = xr.open_dataset('data/CORDEXgrid.nc') - dummy5 = xr.open_dataset('data/CMIP5grid.nc') - stats = pd.read_csv('data/stats.csv', index_col=0) stats['ms'] = 15. stats['color'] = 'gray' mapbox_access_token = 'pk.eyJ1IjoiY29qYWNrIiwiYSI6IkRTNjV1T2MifQ.EWzL4Qk-VvQoaeJBfE6VSA' px.set_mapbox_access_token(mapbox_access_token) - nodexy = pd.DataFrame([dummy.lon.values.ravel(), dummy.lat.values.ravel()]).T + nodexy = pd.DataFrame([cordex_grid.lon.values.ravel(), cordex_grid.lat.values.ravel()]).T nodexy.columns = ['lon', 'lat'] nodexy['hov'] = 'CORDEX grid' @@ -127,8 +133,8 @@ def add_cmpx(cm): [13, 13], [14, 13], [15, 13], [16, 13], [17, 13], [18, 13], [19, 14], [20, 14], [21, 13]] for cc in maskcordex_coast: fig.add_trace(go.Scattermapbox( - lat=[dummy.lat.values[tuple(cc)]], - lon=[dummy.lon.values[tuple(cc)]], + lat=[cordex_grid.lat.values[tuple(cc)]], + lon=[cordex_grid.lon.values[tuple(cc)]], mode='markers', marker=go.scattermapbox.Marker( size=8, @@ -143,11 +149,11 @@ def add_cmpx(cm): #fig = px.scatter_mapbox(nodexy, lat='lat', lon='lon', center={'lat': 53.0, 'lon': 8.3}, zoom=5, opacity=0.1, hover_data=['hov']) fig = px.scatter_mapbox(stats, lat='lat', lon='lon', center={'lat': 53.0, 'lon': 8.6}, zoom=5, size='ms', opacity=0.8, color='color', hover_data=['Station name', 'lat', 'lon'], size_max=10) if cm != 'none': - fig = lin_grid(fig, dummy) - fig = lin_grid(fig, dummy5, '#2c7fb8') + fig = lin_grid(fig, cordex_grid) + fig = lin_grid(fig, cimp_grid, '#2c7fb8') fig = add_cmpx(cm) fig = add_stats(sel) fig.update_layout(showlegend=False,width=300, height=350,margin=dict(l=10, r=10, b=10, t=10)) # ,center={'lat':54.0,'lon':8.3}, zoom=7) - st.sidebar.plotly_chart(fig) - return + # st.sidebar.plotly_chart(fig) + return fig diff --git a/ruins/plotting/weather_data.py b/ruins/plotting/weather_data.py index 5a52116..77cfea0 100644 --- a/ruins/plotting/weather_data.py +++ b/ruins/plotting/weather_data.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt +import seaborn as sns def yrplot_hm(sr, ref=[1980, 2000], ag='sum', qa=0.95, cbar_title='Temperature anomaly (K)', cmx='coolwarm', cmxeq=True, li=False): @@ -40,7 +41,7 @@ def yrplot_hm(sr, ref=[1980, 2000], ag='sum', qa=0.95, cbar_title='Temperature a if ag == 'sum': dummy.iloc[:, 13] = dummy.iloc[:, 13] / 12 - plt.figure(figsize=(8,len(dummy)/15.)) + fig = plt.figure(figsize=(8,len(dummy)/15.)) ax = sns.heatmap(dummy, cmap=cmx, vmin=vxL, vmax=vxU, cbar_kws={'label': cbar_title}) if ref == None: @@ -65,7 +66,8 @@ def yrplot_hm(sr, ref=[1980, 2000], ag='sum', qa=0.95, cbar_title='Temperature a ax.set_ylabel('Year') ax.set_xlabel('Month ') - return + + return fig def monthlyx(dy, dyx=1, ylab='T (°C)', clab1='Monthly Mean in Year', clab2='Monthly Max in Year', pls='cividis_r'): diff --git a/ruins/tests/test_config.py b/ruins/tests/test_config.py index 548047b..75deb4d 100644 --- a/ruins/tests/test_config.py +++ b/ruins/tests/test_config.py @@ -3,6 +3,7 @@ """ import os import json +import pytest from ruins import core @@ -62,3 +63,11 @@ def test_config_as_dict(): # check default get behavior assert c.get('doesNotExist') is None assert c.get('doesNotExists', 'foobar') == 'foobar' + + +def test_config_key_error(): + """An unknown key should throw a key error""" + c = core.Config() + + with pytest.raises(KeyError): + c['doesNotExist'] diff --git a/ruins/tests/test_topic_selector.py b/ruins/tests/test_topic_selector.py deleted file mode 100644 index 0a091c8..0000000 --- a/ruins/tests/test_topic_selector.py +++ /dev/null @@ -1,23 +0,0 @@ -from ruins import components - - -def test_default(): - """Test default behavior""" - assert components.topic_selector(['a', 'b', 'c'], no_cache=True) == 'a' - - -def test_current_preset(): - """Test with current topic set""" - topic = components.topic_selector(['a', 'b', 'c'], current_topic='b', no_cache=True) - - assert topic == 'b' - - -def test_no_force_render(): - """Test with rendering disabled""" - # default - assert components.topic_selector(['a', 'b', 'c'], force_topic_select=False, no_cache=True) == 'a' - - # with current topic set - assert components.topic_selector(['a', 'b', 'c'], current_topic='b', force_topic_select=False, no_cache=True) == 'b' - \ No newline at end of file diff --git a/ruins/tests/test_weather.py b/ruins/tests/test_weather.py index 64fe7c7..968c003 100644 --- a/ruins/tests/test_weather.py +++ b/ruins/tests/test_weather.py @@ -1,27 +1,20 @@ from ruins.apps import weather from ruins.tests.util import get_test_config -from ruins.core import DataManager +from ruins.core import DataManager, Config - -# TODO use the config and inject the dedub config here - -# TODO only run this test when the Value Error is solved -#def test_run_app(): -# """Make sure the appp runs without failing""" -# weather.main_app() +# create the test config +config = get_test_config() def test_climate_indices(): """Test only climate indices """ - conf = get_test_config() - dm = DataManager(**conf) - - w = dm['weather'].read() - c = dm['cordex_coast'].read() + # add the include_climate config + config['include_climate'] = True + dm = DataManager(**config) # run - weather.climate_indices(w, c) + weather.climate_indices(dataManager=dm, config=config) def test_climate_indi():