import requests
import pandas as pd
import deepdish as dd
import os
import pickle
import warnings
from .analyze import analyze
from ..datageometry import DataGeometry
BASE_URL = 'https://docs.google.com/uc?export=download'
homedir = os.path.expanduser('~/')
datadir = os.path.join(homedir, 'hypertools_data')
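# The example datasets are hosted on Google Drive; `datadict` maps each dataset
# name to the Google Drive file id that is passed as the 'id' parameter to
# BASE_URL when the file is downloaded into ~/hypertools_data.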
datadict = {
'weights' : '1-zzaUMHuXHSzFcGT4vNlqcV8tMY4q7jS',
'weights_avg' : '1v_IrU6n72nTOHwD3AnT2LKgtKfHINyXt',
'weights_sample' : '1CiVSP-8sjdQN_cdn3uCrBH5lNOkvgJp1',
'spiral' : '1JB4RIgNfzGaTFWRBCzi8CQ2syTE-BnWg',
'mushrooms' : '1wRXObmwLjSHPAUWC8QvUl37iY2qRObg8',
'wiki' : '1e5lCi17bLbOXuRjiGO2eqkEWVpeCuRvM',
'sotus' : '1D2dsrLAXkC3eUUaw2VV_mldzxX5ufmkm',
'nips' : '1Vva4Xcc5kUX78R0BKkLtdCWQx9GI-FG2',
'wiki_model' : '1OrN1F39GkMPjrB2bOTgNRT1pNBmsCQsN',
'nips_model' : '1orgxWJdWYzBlU3EF2u7EDsZrp3jTNNLG',
'sotus_model' : '1g2F18WLxfFosIqhiLs79G0MpiG72mWQr'
}
def load(dataset, reduce=None, ndims=None, align=None, normalize=None):
"""
Load a .geo file or example data
Parameters
----------
dataset : string
The name of the example dataset. Can be a `.geo` file, or one of a
number of example datasets listed below.
`weights` is list of 2 numpy arrays, each containing average brain
activity (fMRI) from 18 subjects listening to the same story, fit using
Hierarchical Topographic Factor Analysis (HTFA) with 100 nodes. The rows
are fMRI measurements and the columns are parameters of the model.
`weights_sample` is a sample of 3 subjects from that dataset.
`weights_avg` is the dataset split in half and averaged into two groups.
`spiral` is numpy array containing data for a 3D spiral, used to
highlight the `procrustes` function.
`mushrooms` is a numpy array comprised of features (columns) of a
collection of 8,124 mushroomm samples (rows).
`sotus` is a collection of State of the Union speeches from 1989-2018.
`wiki` is a collection of wikipedia pages used to fit wiki-model.
`wiki-model` is a sklearn Pipeline (CountVectorizer->LatentDirichletAllocation)
trained on a sample of wikipedia articles. It can be used to transform
text to topic vectors.
normalize : str or False or None
If set to 'across', the columns of the input data will be z-scored
across lists (default). That is, the z-scores will be computed with
with respect to column n across all arrays passed in the list. If set
to 'within', the columns will be z-scored within each list that is
passed. If set to 'row', each row of the input data will be z-scored.
If set to False, the input data will be returned with no z-scoring.
reduce : str or dict
Decomposition/manifold learning model to use. Models supported: PCA,
IncrementalPCA, SparsePCA, MiniBatchSparsePCA, KernelPCA, FastICA,
FactorAnalysis, TruncatedSVD, DictionaryLearning, MiniBatchDictionaryLearning,
TSNE, Isomap, SpectralEmbedding, LocallyLinearEmbedding, and MDS. Can be
passed as a string, but for finer control of the model parameters, pass
as a dictionary, e.g. reduce={'model' : 'PCA', 'params' : {'whiten' : True}}.
See scikit-learn specific model docs for details on parameters supported
for each model.
ndims : int
Number of dimensions to reduce
align : str or dict
If str, either 'hyper' or 'SRM'. If 'hyper', alignment algorithm will be
hyperalignment. If 'SRM', alignment algorithm will be shared response
model. You can also pass a dictionary for finer control, where the 'model'
key is a string that specifies the model and the params key is a dictionary
of parameter values (default : 'hyper').
Returns
----------
data : Numpy Array
Example data
"""
    # A user-supplied .geo file: load the saved dict with deepdish and rebuild
    # the DataGeometry object from it.
    if dataset.endswith('.geo'):
geo = dd.io.load(dataset)
if 'dtype' in geo:
if 'list' in geo['dtype']:
geo['data'] = list(geo['data'])
elif 'df' in geo['dtype']:
geo['data'] = pd.DataFrame(geo['data'])
geo['xform_data'] = list(geo['xform_data'])
data = DataGeometry(**geo)
    elif dataset in datadict:
data = _load_data(dataset, datadict[dataset])
else:
        raise RuntimeError('No data loaded. Please specify a .geo file or '
                           'one of the following sample datasets: weights, '
                           'weights_avg, weights_sample, spiral, mushrooms, '
                           'wiki, nips, sotus, wiki_model, nips_model or '
                           'sotus_model.')
if data is not None:
if dataset in ('wiki_model', 'nips_model', 'sotus_model'):
return data
if isinstance(data, DataGeometry):
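            # If any transform is requested, re-run the analysis on the raw
            # data and route it through plot(show=False) so that a new,
            # transformed DataGeometry object is returned.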
if any([reduce, ndims, align, normalize]):
from ..plot.plot import plot
if ndims:
if reduce is None:
reduce='IncrementalPCA'
d = analyze(data.get_data(), reduce=reduce, ndims=ndims, align=align, normalize=normalize)
return plot(d, show=False)
else:
return data
else:
return analyze(data, reduce=reduce, ndims=ndims, align=align, normalize=normalize)
def _load_data(dataset, fileid):
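    # Datasets are cached in ~/hypertools_data: download on first use, load
    # from disk afterwards, and re-download once if the cached copy fails to
    # load.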
    fullpath = os.path.join(datadir, dataset)
if not os.path.exists(datadir):
os.makedirs(datadir)
if not os.path.exists(fullpath):
try:
_download(dataset, _load_stream(fileid))
data = _load_from_disk(dataset)
        except Exception as e:
            raise ValueError('Download failed.') from e
else:
try:
data = _load_from_disk(dataset)
        except Exception:
            try:
                _download(dataset, _load_stream(fileid))
                data = _load_from_disk(dataset)
            except Exception as e:
                raise ValueError('Download failed. Try deleting cached data '
                                 'in ~/hypertools_data.') from e
return data
def _load_stream(fileid):
def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
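    # Google Drive returns an interstitial warning page for large files; the
    # 'download_warning' cookie carries a confirmation token that must be sent
    # back as the 'confirm' parameter to get the actual file stream.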
    session = requests.Session()
    response = session.get(BASE_URL, params={'id': fileid}, stream=True)
    token = _get_confirm_token(response)
    if token:
        params = {'id': fileid, 'confirm': token}
        response = session.get(BASE_URL, params=params, stream=True)
return response
def _download(dataset, data):
    fullpath = os.path.join(datadir, dataset)
with open(fullpath, 'wb') as f:
f.write(data.content)
def _load_from_disk(dataset):
    fullpath = os.path.join(datadir, dataset)
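    # The *_model datasets are pickled sklearn Pipelines; the other datasets
    # are deepdish (HDF5) dumps of a DataGeometry's attributes, rebuilt into a
    # DataGeometry object below.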
if dataset in ('wiki_model', 'nips_model', 'sotus_model',):
try:
with open(fullpath, 'rb') as f:
return pickle.load(f)
except ValueError as e:
print(e)
else:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
geo = dd.io.load(fullpath)
if 'dtype' in geo:
if 'list' in geo['dtype']:
geo['data'] = list(geo['data'])
elif 'df' in geo['dtype']:
geo['data'] = pd.DataFrame(geo['data'])
geo['xform_data'] = list(geo['xform_data'])
return DataGeometry(**geo)