Source code for hypertools.datageometry

import copy
import pickle
import warnings

import pandas as pd

from .tools.normalize import normalize as normalizer
from .tools.reduce import reduce as reducer
from .tools.align import align as aligner
from .tools.format_data import format_data
from ._shared.helpers import convert_text, get_dtype
from .config import __version__


class DataGeometry(object):
    """
    Hypertools data object class

    A DataGeometry object contains the data, figure handles and transform
    functions used to create a plot.  Note: this class should not be called
    directly, but is used by the `hyp.plot` function to create a plot object.

    Parameters
    ----------

    fig : matplotlib.Figure
        The matplotlib figure handle for the plot

    ax : matplotlib.Axes
        The matplotlib axes handle for the plot

    line_ani : matplotlib.animation.FuncAnimation
        The matplotlib animation handle (if the plot is an animation)

    data : list
        A list of numpy arrays representing the raw data

    xform_data : list
        A list of numpy arrays representing the transformed data

    reduce : dict
        A dictionary containing the reduction model and parameters

    align : dict
        A dictionary containing align model and parameters

    normalize : str
        A string representing the kind of normalization

    semantic, vectorizer, corpus
        Text-related parameters.  They are stored as-is and forwarded to
        `format_data` when new data is transformed.

    kwargs : dict
        A dictionary containing all kwargs passed to the plot function

    version : str
        The version of the software used to create the class instance

    dtype
        Currently unused: `self.dtype` is always derived from `data` via
        `get_dtype`.

    """

    def __init__(self, fig=None, ax=None, line_ani=None, data=None,
                 xform_data=None, reduce=None, align=None, normalize=None,
                 semantic=None, vectorizer=None, corpus=None, kwargs=None,
                 version=__version__, dtype=None):

        # matplotlib figure handle
        self.fig = fig

        # matplotlib axis handle
        self.ax = ax

        # matplotlib line_ani handle (if its an animation)
        self.line_ani = line_ani

        # convert to numpy array if text
        if isinstance(data, list):
            data = list(map(convert_text, data))
        self.data = data
        # NOTE: the `dtype` argument is not consulted; the stored dtype is
        # always recomputed from the (possibly converted) data.
        self.dtype = get_dtype(data)

        # the transformed data
        self.xform_data = xform_data

        # dictionary of model and model_params
        self.reduce = reduce

        # 'hyper', 'SRM' or None
        self.align = align

        # 'within', 'across', 'row' or False
        self.normalize = normalize

        # text params
        self.semantic = semantic
        self.vectorizer = vectorizer
        self.corpus = corpus

        # dictionary of kwargs
        self.kwargs = kwargs

        # hypertools version
        self.version = version

    def get_data(self):
        """Return a (shallow) copy of the data"""
        return copy.copy(self.data)

    def get_formatted_data(self):
        """Return a formatted copy of the data"""
        return format_data(self.data)

    # a function to transform new data
    def transform(self, data=None):
        """
        Return transformed data, or transform new data using the same model
        parameters

        Parameters
        ----------
        data : numpy array, pandas dataframe or list of arrays/dfs
            The data to transform.  If no data is passed, the xform_data from
            the DataGeometry object will be returned.

        Returns
        ----------
        xformed_data : list of numpy arrays
            The transformed data

        """
        # if no new data passed, hand back what was computed at plot time
        if data is None:
            return self.xform_data

        formatted = format_data(
            data,
            semantic=self.semantic,
            vectorizer=self.vectorizer,
            corpus=self.corpus,
            ppca=True)
        norm = normalizer(formatted, normalize=self.normalize)

        # `reduce` defaults to None in the constructor; subscripting it
        # unconditionally raised TypeError before any reduction was tried.
        # NOTE(review): assumes the reducer accepts reduce=None / ndims=None
        # (i.e. "no reduction") -- confirm against tools.reduce.
        if self.reduce is None:
            ndims = None
        else:
            ndims = self.reduce['params']['n_components']

        reduction = reducer(norm, reduce=self.reduce, ndims=ndims)
        return aligner(reduction, align=self.align)

    # a function to plot the data
    def plot(self, data=None, **kwargs):
        """
        Plot the data

        Parameters
        ----------
        data : numpy array, pandas dataframe or list of arrays/dfs
            The data to plot.  If no data is passed, the data stored on the
            DataGeometry object is plotted again, reusing the stored
            xform_data unless one of the model keywords (`reduce`, `align`,
            `normalize`, `semantic`, `vectorizer`, `corpus`) is overridden.

        kwargs : keyword arguments
            Any keyword arguments supported by `hypertools.plot` are also
            supported by this method

        Returns
        ----------
        geo : hypertools.DataGeometry
            A new data geometry object

        """
        # import plot here to avoid circular imports
        from .plot.plot import plot as plotter

        model_keys = ('reduce', 'align', 'normalize',
                      'semantic', 'vectorizer', 'corpus')

        if data is None:
            d = copy.copy(self.data)
            if any(key in kwargs for key in model_keys):
                # a model setting changed, so the cached transform is stale
                transform = None
            else:
                transform = copy.copy(self.xform_data)
        else:
            d = data
            transform = None

        # get kwargs and update with new kwargs; `self.kwargs` defaults to
        # None in the constructor, so fall back to an empty dict
        new_kwargs = {} if self.kwargs is None else copy.copy(self.kwargs)
        new_kwargs.update(
            transform=transform,
            reduce=self.reduce,
            align=self.align,
            normalize=self.normalize,
            semantic=self.semantic,
            vectorizer=self.vectorizer,
            corpus=self.corpus)
        # caller-supplied keywords take precedence over the stored ones
        new_kwargs.update(kwargs)
        return plotter(d, **new_kwargs)

    def save(self, fname, compression=None):
        """
        Save method for the data geometry object

        The object is written to a 'geo' file by pickling the DataGeometry
        instance with its matplotlib handles (`fig`, `ax`, `line_ani`) set to
        None.

        Parameters
        ----------
        fname : str
            A name for the file.  If the file extension (.geo) is not
            specified, it will be appended.

        compression : any, optional
            Ignored.  Passing anything other than None only emits a
            FutureWarning.

        """
        if compression is not None:
            warnings.warn("Hypertools has switched from deepdish to pickle "
                          "for saving DataGeomtry objects. 'compression' "
                          "argument has no effect and will be removed in a "
                          "future version",
                          FutureWarning)

        # automatically add extension if not present
        if not fname.endswith('.geo'):
            fname += '.geo'

        # can't save/restore matplotlib objects across sessions
        curr_fig = self.fig
        curr_ax = self.ax
        curr_line_ani = self.line_ani

        curr_data = self.data
        # convert pandas DataFrames to dicts of
        # {column_name: list(column_values)} to fix I/O compatibility
        # issues across certain pandas versions. Expected self.data
        # format is restored by hypertools.load
        if isinstance(curr_data, pd.DataFrame):
            data_out_fmt = curr_data.to_dict('list')
        else:
            data_out_fmt = curr_data

        try:
            self.fig = self.ax = self.line_ani = None
            self.data = data_out_fmt
            # save
            with open(fname, 'wb') as f:
                pickle.dump(self, f)
        finally:
            # make sure we don't mutate attribute values whether or not
            # save was successful
            self.fig = curr_fig
            self.ax = curr_ax
            self.line_ani = curr_line_ani
            self.data = curr_data