# Source code for hypertools.tools.load

import pickle
from os.path import expanduser, expandvars
from pathlib import Path

import numpy as np
import pandas as pd
import requests

from ..datageometry import DataGeometry
from .._shared.exceptions import HypertoolsIOError
from .analyze import analyze


# Endpoint the example datasets are downloaded from; the file to fetch is
# selected with an ``id`` query parameter (see ``_download_example_data``).
BASE_URL = 'https://docs.google.com/uc?export=download'
# Local cache directory for downloaded example datasets
# (``hypertools_data`` inside the user's home directory).
DATA_DIR = Path.home().joinpath('hypertools_data')

# Maps each example dataset name accepted by ``load`` to the remote file ID
# passed as the ``id`` query parameter when downloading it.
# NOTE(review): 'nips_model' has the same ID as 'sotus' -- this looks like a
# copy/paste slip; confirm the correct ID for 'nips_model'.
EXAMPLE_DATA = {
    'weights': '1ZXLao5Rxkr45KUMkv08Y1eAedTkpivsd',
    'weights_avg': '1gfI1WB7QqogdYgdclqznhUfxsrhobueO',
    'weights_sample': '1ub-xlYW1D_ASzbLcALcPJuhHUxRwHdIs',
    'spiral': '1nHAusn2VsQinJk35xvJSd7CtWPC1uOwK',
    'mushrooms': '12hmCIZp1tyUoPRHwpiAsm1GDBxiJS8ji',
    'wiki': '1NUqm3svfu2rrFH04xmLbOh0u5WyTe9mh',
    'sotus': '1J0MBhpRwdT2WChfWJ4HXYq6jU4XpyJPm',
    'nips': '1FV7xT2hVgZ1sXfMvAdP1jRsK_dWhp49I',
    'wiki_model': '1T-UAU-6KVGUBcUWqz7yG59vXnThu9T0H',
    'nips_model': '1J0MBhpRwdT2WChfWJ4HXYq6jU4XpyJPm',
    'sotus_model': '16_n9r82pwxzZh-0qdS4a6l0z3v__Q91C'
}


def load(
        dataset,
        reduce=None,
        ndims=None,
        align=None,
        normalize=None,
        *,
        legacy=False
):
    """
    Load a .geo file or example data

    Parameters
    ----------
    dataset : string
        The name of the example dataset.  Can be a path to a `.geo` file, or
        one of a number of example datasets listed below.

        `weights` is list of 2 numpy arrays, each containing average brain
        activity (fMRI) from 18 subjects listening to the same story, fit
        using Hierarchical Topographic Factor Analysis (HTFA) with 100 nodes.
        The rows are fMRI measurements and the columns are parameters of the
        model.

        `weights_sample` is a sample of 3 subjects from that dataset.

        `weights_avg` is the dataset split in half and averaged into two
        groups.

        `spiral` is numpy array containing data for a 3D spiral, used to
        highlight the `procrustes` function.

        `mushrooms` is a numpy array comprised of features (columns) of a
        collection of 8,124 mushroom samples (rows).

        `sotus` is a collection of State of the Union speeches from
        1989-2018.

        `wiki` is a collection of wikipedia pages used to fit `wiki_model`.

        `wiki_model` is a sklearn Pipeline
        (CountVectorizer->LatentDirichletAllocation) trained on a sample of
        wikipedia articles. It can be used to transform text to topic
        vectors.

        `nips`, `nips_model` and `sotus_model` are also accepted (see the
        keys of `EXAMPLE_DATA`).

    normalize : str or False or None
        If set to 'across', the columns of the input data will be z-scored
        across lists (default). That is, the z-scores will be computed with
        respect to column n across all arrays passed in the list. If set
        to 'within', the columns will be z-scored within each list that is
        passed. If set to 'row', each row of the input data will be z-scored.
        If set to False, the input data will be returned with no z-scoring.

    reduce : str or dict
        Decomposition/manifold learning model to use.  Models supported: PCA,
        IncrementalPCA, SparsePCA, MiniBatchSparsePCA, KernelPCA, FastICA,
        FactorAnalysis, TruncatedSVD, DictionaryLearning,
        MiniBatchDictionaryLearning, TSNE, Isomap, SpectralEmbedding,
        LocallyLinearEmbedding, and MDS. Can be passed as a string, but for
        finer control of the model parameters, pass as a dictionary, e.g.
        reduce={'model' : 'PCA', 'params' : {'whiten' : True}}. See
        scikit-learn specific model docs for details on parameters supported
        for each model.

    ndims : int
        Number of dimensions to reduce

    align : str or dict
        If str, either 'hyper' or 'SRM'.  If 'hyper', alignment algorithm will
        be hyperalignment. If 'SRM', alignment algorithm will be shared
        response model.  You can also pass a dictionary for finer control,
        where the 'model' key is a string that specifies the model and the
        params key is a dictionary of parameter values (default : 'hyper').

    legacy : bool
        Pass legacy=True to load DataGeometry objects created with
        hypertools<0.8.0

    Returns
    ----------
    data : object
        For `*_model` example datasets, the loaded object is returned as-is
        (a sklearn Pipeline rather than a DataGeometry).  Otherwise the
        loaded DataGeometry is returned, unless any of `reduce`, `ndims`,
        `align` or `normalize` is truthy, in which case the data is passed
        through `analyze` and the result of `plot(..., show=False)` is
        returned.

    Raises
    ------
    HypertoolsIOError
        If `dataset` is neither an example dataset name nor an existing
        file, or if the file cannot be unpickled.

    """
    if dataset in EXAMPLE_DATA:
        geo_data = _load_example_data(dataset)
        if dataset.endswith('_model'):
            # geo_data is a sklearn.pipeline.Pipeline, not a DataGeometry
            return geo_data
    else:
        dataset_path = Path(expanduser(expandvars(dataset))).resolve()
        if not dataset_path.is_file():
            raise HypertoolsIOError(
                f"Dataset not found at {dataset_path}. Please specify a .geo "
                "file or one of the following sample files: 'weights', "
                "'weights_avg', 'weights_sample', 'spiral', 'mushrooms', "
                "'wiki', 'nips', 'sotus', 'wiki_model', 'nips_model', "
                "'sotus_model'"
            )
        elif legacy:
            geo_data = _load_legacy(dataset_path)
        else:
            try:
                # SECURITY: unpickling executes arbitrary code embedded in
                # the file -- only load .geo files from trusted sources.
                geo_data = pickle.loads(dataset_path.read_bytes())
            except pickle.UnpicklingError as e:
                raise HypertoolsIOError(
                    "Failed to load DataGeometry object from "
                    f"{dataset_path}. If {dataset_path.name} was created "
                    "with hypertools<0.8.0, pass legacy=True to load it."
                ) from e

    # dict-formatted data is normalized to a DataFrame
    if isinstance(geo_data.data, dict):
        geo_data.data = pd.DataFrame(geo_data.data)

    # Use a tuple rather than a set here: `reduce` and `align` may be dicts,
    # which are unhashable and would raise TypeError inside a set literal.
    if any((reduce, ndims, align, normalize)):
        # imported lazily, only when a transformation was requested
        from ..plot.plot import plot

        reduce = reduce or 'IncrementalPCA'
        d = analyze(
            geo_data.get_data(),
            reduce=reduce,
            ndims=ndims,
            align=align,
            normalize=normalize
        )
        return plot(d, show=False)

    return geo_data


def _load_legacy(dataset_path):
    """
    Load a DataGeometry saved in the legacy (hypertools<0.8.0) format.

    Parameters
    ----------
    dataset_path : pathlib.Path
        Path to the legacy-format file, read with ``deepdish.io.load``.

    Returns
    -------
    geo_data : DataGeometry
        Object built from the loaded dictionary of fields.

    Raises
    ------
    HypertoolsIOError
        If the optional ``deepdish`` dependency is not installed.

    """
    try:
        import deepdish as dd
    except ImportError as e:
        # catches ModuleNotFoundError since it's a subclass
        raise HypertoolsIOError(
            "To load legacy-format datasets, install the 'deepdish' module"
        ) from e

    data_dict = dd.io.load(dataset_path)
    raw_data = data_dict['data']
    if isinstance(raw_data, dict):
        data_dict['data'] = pd.DataFrame(raw_data)
    elif isinstance(raw_data, np.ndarray):
        # split a stacked array into a list of its first-axis entries
        data_dict['data'] = list(raw_data)

    data_dict['xform_data'] = list(data_dict['xform_data'])
    return DataGeometry(**data_dict)


def _load_example_data(dataset):
    """
    Load an example dataset, downloading it into the local cache if needed.

    Parameters
    ----------
    dataset : str
        Key of ``EXAMPLE_DATA`` naming the dataset.

    Returns
    -------
    geo_data : object
        The unpickled object.  For 'mushrooms', its ``data`` attribute is
        converted to a pandas DataFrame.

    Raises
    ------
    HypertoolsIOError
        If the download fails or the cached file cannot be unpickled.

    """
    dataset_path = DATA_DIR.joinpath(dataset)
    if not dataset_path.is_file():
        # exist_ok avoids a race between checking for and creating the
        # cache directory when several processes load data concurrently
        DATA_DIR.mkdir(parents=True, exist_ok=True)
        _download_example_data(dataset_path)

    try:
        geo_data = pickle.loads(dataset_path.read_bytes())
    except Exception as e:
        raise HypertoolsIOError(
            f"Failed to load '{dataset}' data. Try deleting cached file at "
            f"{dataset_path} and reloading."
        ) from e

    if dataset == 'mushrooms':
        # format mushrooms dataset as a pandas DataFrame
        geo_data.data = pd.DataFrame(geo_data.data)

    return geo_data


def _download_example_data(dataset_path):
    """
    Download an example dataset and write it to ``dataset_path``.

    Parameters
    ----------
    dataset_path : pathlib.Path
        Destination file.  Its ``name`` must be a key of ``EXAMPLE_DATA``;
        the corresponding value is sent as the ``id`` query parameter.

    Raises
    ------
    HypertoolsIOError
        If the request fails, the server answers with an HTTP error status,
        or writing the file fails.  Any partially written file is removed.

    """
    file_id = EXAMPLE_DATA[dataset_path.name]
    params = {'id': file_id}
    try:
        # the context manager releases pooled connections when done
        with requests.Session() as session:
            response = session.get(BASE_URL, params=params, stream=True)
            for key, value in response.cookies.items():
                if key.startswith('download_warning'):
                    # Google Drive requires confirmation for large files
                    params['confirm'] = value
                    response.close()
                    response = session.get(
                        BASE_URL, params=params, stream=True
                    )
                    break

            # fail before touching the disk so that an HTTP error page is
            # never cached as if it were the dataset
            response.raise_for_status()

            with dataset_path.open('wb') as f:
                # write stream in chunks to avoid loading whole file into
                # memory
                for chunk in response.iter_content(chunk_size=32768):
                    if chunk:
                        f.write(chunk)
    except Exception as e:
        # clean up partial file in case of error while writing stream
        dataset_path.unlink(missing_ok=True)
        raise HypertoolsIOError(
            f"Failed to download '{dataset_path.name}' dataset"
        ) from e