# Source code for aidapy.aidaxr.graphical

"""AIDA module providing the graphical utilities for time series.
Contributors: Etienne Behar
"""

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr


@xr.register_dataset_accessor('graphical')
@xr.register_dataarray_accessor('graphical')
class AidaAccessorGraphical:
    """Xarray accessor responsible for the graphical utilities.

    Registered under the name ``graphical`` on both ``xarray.Dataset`` and
    ``xarray.DataArray``, so it is reached as ``obj.graphical``.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def peek(self):
        """Plot the time series of the wrapped xarray object and show it.

        For a Dataset every data variable is plotted; for a DataArray the
        array itself is plotted. Two-dimensional variables are drawn as one
        line per entry of their non-time dimension (labelled with that
        entry's coordinate value) against their time dimension. Any other
        variable is drawn with ``DataArray.plot``. The figure is displayed
        with ``matplotlib.pyplot.show`` at the end.

        Raises
        ------
        TimeIndexNotFound
            If a two-dimensional variable has no datetime64 dimension.
        ValueError
            If the wrapped object is neither a Dataset nor a DataArray.
        """
        if isinstance(self._obj, xr.Dataset):
            arrays = [self._obj[name] for name in self.find_columns_name()]
        elif isinstance(self._obj, xr.DataArray):
            arrays = [self._obj]
        else:
            raise ValueError

        for data in arrays:
            # ndim reads metadata only; data.values.shape would load the data.
            if data.ndim == 2:
                # TODO: Adapt the labels for multiple probes
                self._plot_2d(data, self._time_index_of(data))
            else:
                # Colours come from matplotlib's colour cycle (deterministic),
                # replacing the former random colours.
                data.plot(label=data.name, markersize=14)
        plt.show()

    @staticmethod
    def _plot_2d(data, time_dim):
        """Draw one line per entry of the non-time dimension of ``data``."""
        other_dim = next(dim for dim in data.dims if dim != time_dim)
        # data[dim] falls back to a 0..n-1 index when the dimension has no
        # coordinate variable, whereas data.coords[dim] would raise KeyError.
        labels = data[other_dim].values
        times = data[time_dim].values
        for position, label in enumerate(labels):
            # Select by position along the *named* dimension so the result is
            # correct whichever axis order the array is stored in.
            plt.plot(times, data.isel({other_dim: position}).values,
                     label=label, markersize=14)

    @staticmethod
    def _time_index_of(xarr):
        """Return the name of the datetime64 dimension of ``xarr``.

        Dimensions without a coordinate variable are skipped. If several
        dimensions qualify, the last one in ``xarr.dims`` is returned.

        Raises
        ------
        TimeIndexNotFound
            If no dimension carries a datetime64 coordinate.
        """
        candidates = [
            dim for dim in xarr.dims
            if dim in xarr.coords
            and np.issubdtype(xarr.coords[dim].dtype, np.datetime64)
        ]
        if not candidates:
            raise TimeIndexNotFound("no datetime64 dimension found")
        return candidates[-1]

    def find_columns_name(self):
        """Return the names of the "columns" of the wrapped xarray object.

        Returns
        -------
        col : list
            The data-variable names for a Dataset, or the dimension names
            for a DataArray.

        Raises
        ------
        ValueError
            If the wrapped object is neither a Dataset nor a DataArray.
        """
        if isinstance(self._obj, xr.Dataset):
            return list(self._obj.data_vars)
        if isinstance(self._obj, xr.DataArray):
            return list(self._obj.dims)
        raise ValueError

    def find_time_index(self):
        """Return the name of the time index of the wrapped xarray object.

        The time index is a dimension whose coordinate has a datetime64
        dtype. If several dimensions qualify, the last one is returned.

        Returns
        -------
        name : str
            The name of the time index.

        Raises
        ------
        TimeIndexNotFound
            If no such dimension exists.
        """
        return self._time_index_of(self._obj)

    def index_names(self):
        """Return the dimension names of the wrapped xarray object as a list."""
        return list(self._obj.dims)
class TimeIndexNotFound(Exception):
    """Raised when no datetime64 time index can be found on an xarray object.

    Derives from ``Exception`` so that ``raise TimeIndexNotFound`` works and
    callers can catch it by name. Previously it derived from ``object``, so
    raising it produced a ``TypeError`` instead.
    """