Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 37 additions & 87 deletions src/xray/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
"""
# TODO: implement backend logic directly in OrderedDict subclasses, to allow
# for directly manipulating Dataset.variables and the like?
import netCDF4 as nc4
import numpy as np
import pandas as pd
import netCDF4 as nc4

from scipy.io import netcdf
from collections import OrderedDict

import xarray
import conventions
from utils import FrozenOrderedDict, Frozen, datetimeindex2num

from utils import FrozenOrderedDict, Frozen
from conventions import is_valid_nc3_name, coerce_nc3_dtype, encode_cf_variable


class AbstractDataStore(object):
Expand All @@ -30,6 +30,11 @@ def set_variables(self, variables):
for vn, v in variables.iteritems():
self.set_variable(vn, v)

def set_necessary_dimensions(self, variable):
    """Create any dimension used by ``variable`` that the store lacks.

    Each name in ``variable.dimensions`` is paired with the matching entry
    of ``variable.shape``.  Names already present in ``self.ds.dimensions``
    are left alone; the others are created through ``self.set_dimension``.
    """
    for dim_name, dim_length in zip(variable.dimensions, variable.shape):
        if dim_name in self.ds.dimensions:
            # Already defined on the underlying dataset; nothing to do.
            continue
        self.set_dimension(dim_name, dim_length)


class InMemoryDataStore(AbstractDataStore):
"""
Expand Down Expand Up @@ -59,49 +64,21 @@ def sync(self):
pass


def convert_to_cf_variable(array):
    """Return an XArray built from ``array`` that can be saved as a netCDF
    variable.

    The attributes are copied, so ``array.attributes`` is not modified.
    """
    values = array.data
    attrs = array.attributes.copy()

    if isinstance(values, pd.DatetimeIndex):
        # A DatetimeIndex has to be encoded as numbers; the units and
        # calendar describing that encoding are stored as attributes.
        values, units, calendar = datetimeindex2num(values)
        attrs['units'] = units
        attrs['calendar'] = calendar
    elif values.dtype == np.dtype('O'):
        # pandas.Index arrays often have dtype=object even when they were
        # built from an array with a sensible datatype (e.g.
        # pandas.Float64Index always has dtype=object for some reason).
        # Math on coordinates lets such object arrays propagate to other
        # variables, so the check is not limited to pandas.Index data.
        first_element = values.reshape(-1)[0]
        target_dtype = np.array(first_element).dtype
        # N.B. "astype" fails if the data cannot be cast to the type of its
        # first element (which is probably the only sensible thing to do).
        values = np.asarray(values).astype(target_dtype)

    return xarray.XArray(array.dimensions, values, attrs)


def convert_scipy_variable(var):
    """Build an ``xarray.XArray`` from a scipy netCDF variable.

    Uses the variable's ``dimensions``, ``data`` and ``_attributes``.
    """
    dims = var.dimensions
    values = var.data
    attrs = var._attributes
    return xarray.XArray(dims, values, attrs)


class ScipyDataStore(AbstractDataStore):
"""
Stores data using the scipy.io.netcdf package.
This store has the advantage of being able to
be initialized with a StringIO object, allow for
serialization.
"""
def __init__(self, fobj, *args, **kwdargs):
self.ds = netcdf.netcdf_file(fobj, *args, **kwdargs)
def __init__(self, filename_or_obj, mode='r', mmap=None, version=1):
self.ds = netcdf.netcdf_file(filename_or_obj, mode=mode, mmap=mmap,
version=version)

@property
def variables(self):
return FrozenOrderedDict((k, convert_scipy_variable(v))
return FrozenOrderedDict((k, xarray.XArray(v.dimensions, v.data,
v._attributes))
for k, v in self.ds.variables.iteritems())

@property
Expand All @@ -119,42 +96,26 @@ def set_dimension(self, name, length):
self.ds.createDimension(name, length)

def _validate_attr_key(self, key):
if not conventions.is_valid_name(key):
if not is_valid_nc3_name(key):
raise ValueError("Not a valid attribute name")

def _cast_attr_value(self, value):
# Strings get special handling because netCDF treats them as
# character arrays. Everything else gets coerced to a numpy
# vector. netCDF treats scalars as 1-element vectors. Arrays of
# non-numeric type are not allowed.
if isinstance(value, basestring):
# netcdf attributes should be unicode
value = unicode(value)
else:
try:
value = conventions.coerce_type(np.atleast_1d(np.asarray(value)))
except:
raise ValueError("Not a valid value for a netCDF attribute")
value = coerce_nc3_dtype(np.atleast_1d(value))
if value.ndim > 1:
raise ValueError("netCDF attributes must be vectors " +
"(1-dimensional)")
value = conventions.coerce_type(value)
if str(value.dtype) not in conventions.TYPEMAP:
# A plain string attribute is okay, but an array of
# string objects is not okay!
raise ValueError("Can not convert to a valid netCDF type")
raise ValueError("netCDF attributes must be 1-dimensional")
return value

def set_attribute(self, key, value):
    """Validate ``key``, cast ``value`` and store it on the dataset.

    ``self._validate_attr_key`` runs first, so an invalid name is rejected
    before the value is cast by ``self._cast_attr_value``.
    """
    self._validate_attr_key(key)
    cast_value = self._cast_attr_value(value)
    setattr(self.ds, key, cast_value)

def set_variable(self, name, variable):
variable = convert_to_cf_variable(variable)
data = variable.data
dtype_convert = {'int64': 'int32', 'float64': 'float32'}
if str(data.dtype) in dtype_convert:
data = np.asarray(data, dtype=dtype_convert[str(data.dtype)])
variable = encode_cf_variable(variable)
data = coerce_nc3_dtype(variable.data)
self.set_necessary_dimensions(variable)
self.ds.createVariable(name, data.dtype, variable.dimensions)
scipy_var = self.ds.variables[name]
scipy_var[:] = data[:]
Expand All @@ -169,31 +130,22 @@ def sync(self):
self.ds.flush()


def convert_nc4_variable(var):
    """Wrap a netCDF4 variable in an orthogonally indexed ``xarray.XArray``.

    Every attribute reported by ``var.ncattrs()`` is copied, in order,
    except ``scale_factor`` and ``add_offset``.
    """
    # scale_factor and add_offset are hidden because the netCDF4 package
    # automatically scales the data on read.  If they were kept and we did:
    #
    #     foo = ncdf4.Dataset('foo.nc')
    #     ncdf4.dump(foo, 'bar.nc')
    #     bar = ncdf4.Dataset('bar.nc')
    #
    # any packed variable from the original file would end up scaled twice.
    hidden = ['scale_factor', 'add_offset']
    attr = OrderedDict()
    for key in var.ncattrs():
        if key in hidden:
            continue
        attr[key] = var.getncattr(key)
    return xarray.XArray(var.dimensions, var, attr, indexing_mode='orthogonal')


class NetCDF4DataStore(AbstractDataStore):
def __init__(self, filename, *args, **kwdargs):
# TODO: set auto_maskandscale=True so we can handle the array
# packing/unpacking ourselves (using NaN instead of masked arrays)
self.ds = nc4.Dataset(filename, *args, **kwdargs)

def __init__(self, filename, mode='r', clobber=True, diskless=False,
persist=False, format='NETCDF4'):
self.ds = nc4.Dataset(filename, mode=mode, clobber=clobber,
diskless=diskless, persist=persist,
format=format)

@property
def variables(self):
return FrozenOrderedDict((k, convert_nc4_variable(v))
def convert_variable(var):
attr = OrderedDict((k, var.getncattr(k)) for k in var.ncattrs())
var.set_auto_maskandscale(False)
return xarray.XArray(var.dimensions, var,
attr, indexing_mode='orthogonal')
return FrozenOrderedDict((k, convert_variable(v))
for k, v in self.ds.variables.iteritems())

@property
Expand All @@ -203,21 +155,18 @@ def attributes(self):

@property
def dimensions(self):
return FrozenOrderedDict((k, len(v)) for k, v in self.ds.dimensions.iteritems())
return FrozenOrderedDict((k, len(v))
for k, v in self.ds.dimensions.iteritems())

def set_dimension(self, name, length):
    """Create dimension ``name`` on the underlying dataset.

    ``length`` is passed to ``createDimension`` as its ``size`` keyword.
    """
    dataset = self.ds
    dataset.createDimension(name, size=length)

def set_attribute(self, key, value):
    """Set one attribute on the underlying dataset.

    The pair is handed to ``setncatts`` as a one-item dict.
    """
    single_attribute = {key: value}
    self.ds.setncatts(single_attribute)

def _cast_data(self, data):
if isinstance(data, pd.DatetimeIndex):
data = datetimeindex2num(data)
return data

def set_variable(self, name, variable):
variable = convert_to_cf_variable(variable)
variable = encode_cf_variable(variable)
self.set_necessary_dimensions(variable)
# netCDF4 will automatically assign a fill value
# depending on the datatype of the variable. Here
# we let the package handle the _FillValue attribute
Expand All @@ -228,6 +177,7 @@ def set_variable(self, name, variable):
dimensions=variable.dimensions,
fill_value=fill_value)
nc4_var = self.ds.variables[name]
nc4_var.set_auto_maskandscale(False)
nc4_var[:] = variable.data[:]
nc4_var.setncatts(variable.attributes)

Expand Down
10 changes: 5 additions & 5 deletions src/xray/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ def __len__(self):
return len(self._data)

def __nonzero__(self):
return bool(self._data)
return bool(self.data)

def __float__(self):
return float(self._data)
return float(self.data)

def __int__(self):
return int(self._data)
return int(self.data)

def __complex__(self):
return complex(self._data)
return complex(self.data)

def __long__(self):
return long(self._data)
return long(self.data)

# adapted from pandas.NDFrame
# https://github.com/pydata/pandas/blob/master/pandas/core/generic.py#L699
Expand Down
Loading