Module blechpy.datastructures.experiment
Expand source code
import os
import shutil
from tqdm import tqdm
import multiprocessing
import numpy as np
import pandas as pd
from itertools import combinations
from blechpy import dio
from blechpy.datastructures.objects import data_object, load_dataset
from blechpy.utils import userIO, print_tools as pt, write_tools as wt, spike_sorting_GUI as ssg
from blechpy.analysis import held_unit_analysis as hua, blech_clustering as bclust
from blechpy.plotting import data_plot as dplt
from blechpy.utils.decorators import Logger
class experiment(data_object):
def __init__(self, exp_dir=None, exp_name=None, shell=False, order_dict=None):
'''Setup for analysis across recording sessions
Parameters
----------
exp_dir : str (optional)
path to directory containing all recording directories
if None (default) is passed then a popup to choose
the directory will come up
shell : bool (optional)
True to use command-line interface for user input
False (default) for GUI
'''
if 'SSH_CONNECTION' in os.environ:
shell = True
super().__init__('experiment', root_dir=exp_dir, data_name=exp_name, shell=shell)
fd = [os.path.join(exp_dir, x) for x in os.listdir(exp_dir)]
file_dirs = [x for x in fd if (os.path.isdir(x) and
dio.h5io.get_h5_filename(x) is not None)]
if len(file_dirs) == 0:
q = userIO.ask_user('No recording directories with h5 files found '
'in experiment directory\nContinue creating '
'empty experiment?', shell=shell)
if q == 0:
return
self.recording_dirs = file_dirs
self._order_dirs(shell=shell, order_dict=order_dict)
rec_names = [os.path.basename(x) for x in self.recording_dirs]
el_map = None
rec_labels = {}
for rd in self.recording_dirs:
dat = load_dataset(rd)
if dat is None:
raise FileNotFoundError('No dataset.p object found in %s' % rd)
elif el_map is None:
el_map = dat.electrode_mapping.copy()
rec_labels[dat.data_name] = rd
self.rec_labels = rec_labels
self.electrode_mapping = el_map
self._setup_taste_map()
save_dir = os.path.join(self.root_dir, '%s_analysis' % self.data_name)
if not os.path.isdir(save_dir):
os.mkdir(save_dir)
self.analysis_dir = save_dir
self.save()
def _change_root(self, new_root=None):
old_root = self.root_dir
new_root = super()._change_root(new_root)
self.recording_dirs = [x.replace(old_root, new_root)
for x in self.recording_dirs]
self.rec_labels = {k: v.replace(old_root, new_root)
for k,v in self.rec_labels.items()}
return new_root
def __str__(self):
out = [super().__str__()]
out.append('Analysis Directory: %s' % self.analysis_dir)
out.append('Recording Directories :')
out.append(pt.print_dict(self.rec_labels, tabs=1))
out.append('\nTaste Mapping :')
out.append(pt.print_dict(self.taste_map, tabs=1))
out.append('\nElectrode Mapping\n----------------')
out.append(pt.print_dataframe(self.electrode_mapping))
if hasattr(self, 'held_units'):
out.append('\nHeld Units :')
out.append(pt.print_dataframe(self.held_units.drop(columns=['J3'])))
return '\n'.join(out)
def _order_dirs(self, shell=False, order_dict=None):
'''Set the order of recording directories
'''
if 'SSH_CONNECTION' in os.environ:
shell = True
if self.recording_dirs == []:
return
if order_dict is None:
self.recording_dirs = [x[:-1] if x.endswith(os.sep) else x
for x in self.recording_dirs]
top_dirs = {os.path.basename(x): os.path.dirname(x)
for x in self.recording_dirs}
file_dirs = list(top_dirs.keys())
order_dict = dict.fromkeys(file_dirs, 0)
tmp = userIO.dictIO(order_dict, shell=shell)
order_dict = userIO.fill_dict(order_dict,
('Set order of recordings (1-%i)\n'
'Leave blank to delete directory'
' from list') % len(file_dirs),
shell)
if order_dict is None:
return
file_dirs = [k for k, v in order_dict.items()
if v is not None and v != 0]
file_dirs = sorted(file_dirs, key=order_dict.get)
file_dirs = [os.path.join(top_dirs.get(x), x) for x in file_dirs]
else:
file_dirs = sorted(self.recording_dirs, key=order_dict.get)
self.recording_dirs = file_dirs
def _setup_taste_map(self):
rec_dirs = self.recording_dirs
rec_labels = self.rec_labels
tastants = []
for rd in rec_dirs:
dat = load_dataset(rd)
tmp = dat.dig_in_mapping
tastants.extend(tmp['name'].to_list())
tastants = np.unique(tastants)
taste_map = {}
for rl, rd in rec_labels.items():
dat = load_dataset(rd)
din = dat.dig_in_mapping
for t in tastants:
if taste_map.get(t) is None:
taste_map[t] = {}
tmp = din['channel'][din['name'] == t]
if not tmp.empty:
taste_map[t][rl] = tmp.values[0]
self.taste_map = taste_map
def add_recording(self, new_dir=None, shell=None):
'''Add recording directory to experiment
Parameters
----------
new_dir : str (optional)
full path to new directory to add to recording dirs
shell : bool (optional)
True for command-line interface for user input
False (default) for GUI
If not passed then the preference set upon object creation is used
'''
if 'SSH_CONNECTION' in os.environ:
shell = True
elif shell is None:
shell = False
if new_dir is None:
new_dir = userIO.get_filedirs('Select recording directory',
root=self.root_dir, shell=shell)
if not os.path.isdir(new_dir):
raise NotADirectoryError('%s must be a valid directory' % new_dir)
if not any([x.endswith('.h5') for x in os.listdir(new_dir)]):
raise FileNotFoundError('No .h5 file found in %s' % new_dir)
if not any([x.endswith('dataset.p') for x in os.listdir(new_dir)]):
raise FileNotFoundError('*_dataset.p file not found in %s' % new_dir)
if new_dir.endswith('/'):
new_dir = new_dir[:-1]
label = userIO.get_user_input('Enter label for recording %s' %
os.path.basename(new_dir), shell=shell)
self.recording_dirs.append(new_dir)
self.rec_labels[label] = new_dir
self._order_dirs(shell=shell)
self._setup_taste_map()
print('Added recording: %s' % new_dir)
self.save()
def remove_recording(self, rec_dir=None, shell=None):
'''Remove recording directory from experiment
Parameters
----------
rec_dir : str (optional)
full path or label of the directory to remove
shell : bool (optional)
True for command-line interface for user input. Default for SSH
False (default) for GUI
If not passed then the preference set upon object creation is used
'''
if 'SSH_CONNECTION' in os.environ:
shell = True
elif shell is None:
shell = False
if rec_dir is None:
rec_dir = userIO.select_from_list('Choose recording to remove\n'
'Leave blank to cancel',
list(self.rec_labels.keys()),
shell=shell)
if rec_dir is None:
return
if os.path.isabs(rec_dir):
if rec_dir.endswith('/'):
rec_dir = rec_dir[:-1]
idx = list(self.rec_labels.values()).index(rec_dir) # throws ValueError
key = list(self.rec_labels.keys())[idx]
else:
key = rec_dir
rec_dir = self.rec_labels.get(key)
if rec_dir is None:
raise ValueError('%s is not in recording dirs' % key)
self.rec_labels.pop(key)
self.recording_dirs.remove(rec_dir)
self._setup_taste_map()
print('Removed recording: %s' % rec_dir)
self.save()
@Logger('Detecting held units')
def detect_held_units(self, percent_criterion=95, raw_waves=False, shell=False):
'''Determine which units are held across recording sessions
Grabs single units from each recording and compares consecutive
recordings to determine if units were held
Parameters
----------
percent_criterion : float
percentile (0-100) of intra_J3 below which to accept units as held
95 (default) uses the 95th percentile
a lower number is a stricter criterion
shell : bool (optional)
True for command-line interface for user input
False (default) for GUI
'''
if 'SSH_CONNECTION' in os.environ:
shell = True
save_dir = os.path.join(self.analysis_dir, 'held_unit_detection')
if os.path.isdir(save_dir):
shutil.rmtree(save_dir)
os.mkdir(save_dir)
rec_dirs = self.recording_dirs
rec_labels = self.rec_labels
rec_names = list(rec_labels.keys())
print('Detecting held units for :')
print('\t' + '\n\t'.join(rec_names))
print('Saving output to : %s' % save_dir)
held_df, intra_J3, inter_J3 = hua.find_held_units(rec_dirs,
percent_criterion,
raw_waves=raw_waves)
rl_dict = {os.path.basename(v) : k for k, v in self.rec_labels.items()}
held_df = held_df.rename(columns=rl_dict)
em = self.electrode_mapping
held_df = held_df.apply(lambda x: self._assign_area(x), axis=1)
self.held_units = held_df
self.J3_values = {'intra_J3': intra_J3,
'inter_J3': inter_J3}
# Write dataframe of held units to text file
df_file = os.path.join(save_dir, 'held_units_table.txt')
json_file = os.path.join(save_dir, 'held_units.json')
held_df.to_json(json_file, orient='records', lines=True)
# Print table of held units to text file, with a separate table for each pair of consecutive recordings
rec_pairs = [(rec_names[i], rec_names[i+1]) for i in range(len(rec_names) - 1)]
with open(df_file, 'w') as f:
for rec1, rec2 in rec_pairs:
tmp_df = held_df.copy()
exc = [x for x in rec_names if x not in [rec1, rec2]]
tmp_df = tmp_df.drop(columns=['J3', *exc]).dropna()
print('Units held from %s to %s\n----------' % (rec1, rec2), file=f)
print(pt.print_dataframe(tmp_df, tabs=1), file=f)
print('', file=f)
np.save(os.path.join(save_dir, 'intra_J3'), np.array(intra_J3))
np.save(os.path.join(save_dir, 'inter_J3'), np.array(inter_J3))
# For each held unit, plot waveforms side by side
dplt.plot_held_units(rec_dirs, held_df, save_dir, rec_names=rec_names)
# Plot intra and inter J3
dplt.plot_J3s(intra_J3, inter_J3, save_dir, percent_criterion)
self.save()
def _assign_area(self, row):
data_dir = self.root_dir
em = self.electrode_mapping
rec = None
unit = None
for k, v in self.rec_labels.items():
if not pd.isna(row[k]):
rec = v
unit = row[k]
break
if rec is None:
row['area'] = float('nan')
row['electrode'] = float('nan')
return row
unit_num = dio.h5io.parse_unit_number(unit)
descrip = dio.h5io.get_unit_descriptor(rec, unit_num)
electrode = descrip['electrode_number']
area = em.query('Electrode == @electrode')['area'].values[0]
row['electrode'] = electrode
row['area'] = area
return row
@Logger('Running Spike Clustering')
def cluster_spikes(self, data_quality=None, multi_process=True,
n_cores=None, custom_params=None, umap=False):
'''Write clustering parameters to each recording directory and run
spike clustering on every live electrode, optionally in parallel
Parameters
----------
data_quality : {'clean', 'noisy', None (default)}
set to change the data quality parameters for cutoff and spike
detection before running clustering. These parameters are
automatically set as "clean" during initial parameter setup
multi_process : bool, True (default)
run clustering for each electrode in parallel using a multiprocessing pool
n_cores : int (optional)
number of cores for parallel clustering; defaults to one less than
the number of available cores
custom_params : dict (optional)
clustering parameters to use in place of the stored or preset parameters
umap : bool, False (default)
use the UMAP-based data transform (bclust.UMAP_METRICS) when clustering
'''
clustering_params = None
if custom_params:
clustering_params = custom_params
elif data_quality:
tmp = dio.params.load_params('clustering_params', self.root_dir,
default_keyword=data_quality)
if tmp:
clustering_params = tmp
else:
raise ValueError('%s is not a valid data_quality preset. Must '
'be "clean" or "noisy" or None.')
# Get electrodes, throw out 'dead' electrodes
em = self.electrode_mapping
if 'dead' in em.columns:
electrodes = em.Electrode[em['dead'] == False].tolist()
else:
electrodes = em.Electrode.tolist()
# Setup progress bar
pbar = tqdm(total = len(electrodes))
def update_pbar(ans):
pbar.update()
# get clustering params
rec_dirs = list(self.rec_labels.values())
if clustering_params is None:
dat = load_dataset(rec_dirs[0])
clustering_params = dat.clustering_params.copy()
print('\nRunning Blech Clust\n-------------------')
print('Parameters\n%s' % pt.print_dict(clustering_params))
# Write clustering params to recording directories & check for spike detection
spike_detect = True
for rd in rec_dirs:
dat = load_dataset(rd)
if dat.process_status['spike_detection'] == False:
raise FileNotFoundError('Spike detection has not been run on %s' % rd)
dat.clustering_params = clustering_params
wt.write_params_to_json('clustering_params', rd, clustering_params)
# dat.save()
# Run clustering
if not umap:
clust_objs = [bclust.BlechClust(rec_dirs, x, params=clustering_params)
for x in electrodes]
else:
clust_objs = [bclust.BlechClust(rec_dirs, x,
params=clustering_params,
data_transform=bclust.UMAP_METRICS,
n_pc=5)
for x in electrodes]
if multi_process:
if n_cores is None or n_cores > multiprocessing.cpu_count():
n_cores = multiprocessing.cpu_count() - 1
pool = multiprocessing.get_context('spawn').Pool(n_cores)
for x in clust_objs:
pool.apply_async(x.run, callback=update_pbar)
pool.close()
pool.join()
else:
for x in clust_objs:
res = x.run()
update_pbar(res)
pbar.close()
for rd in rec_dirs:
dat = load_dataset(rd)
dat.process_status['spike_clustering'] = True
dat.process_status['cleanup_clustering'] = False
# dat.save()
# self.save()
print('Clustering Complete\n------------------')
def sort_spikes(self, electrode=None, shell=False):
if electrode is None:
electrode = userIO.get_user_input('Electrode #: ', shell=shell)
if electrode is None or not str(electrode).isnumeric():
return
electrode = int(electrode)
rec_dirs = list(self.rec_labels.values())
for rd in rec_dirs:
dat = load_dataset(rd)
if not dat.process_status['cleanup_clustering']:
dat.cleanup_clustering()
dat.process_status['sort_units'] = True
sorter = bclust.SpikeSorter(rec_dirs, electrode, shell=shell)
if not shell:
root, sorter_GUI = ssg.launch_sorter_GUI(sorter)
return root, sorter_GUI
else:
print('No shell UI yet')
return
Classes
class experiment (exp_dir=None, exp_name=None, shell=False, order_dict=None)
-
Setup for analysis across recording sessions
Parameters
exp_dir : str (optional)
- path to directory containing all recording directories; if None (default) is passed, a popup to choose the directory will come up
shell : bool (optional)
- True to use the command-line interface for user input; False (default) for GUI
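A minimal usage sketch for creating an experiment object (the directory path and name below are hypothetical; each recording subdirectory is assumed to already contain an .h5 file and a *_dataset.p file):
from blechpy.datastructures.experiment import experiment
# Build an experiment from a folder of recording directories (hypothetical path and name)
exp = experiment('/data/my_experiment', exp_name='my_experiment', shell=True)
print(exp)  # summary of recordings, taste map, and electrode mapping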
Ancestors
- blechpy.datastructures.objects.data_object
Methods
def add_recording(self, new_dir=None, shell=None)
-
Add recording directory to experiment
Parameters
new_dir : str (optional)
- full path to the new directory to add to the recording dirs
shell : bool (optional)
- True for command-line interface for user input; False (default) for GUI. If not passed, the preference set upon object creation is used
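Given an experiment object exp (see the constructor example above), a short sketch of adding another session (hypothetical path; the directory must contain an .h5 file and a *_dataset.p file, and you will be prompted for a label and the new recording order):
exp.add_recording('/data/my_experiment/session_4', shell=True)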
def cluster_spikes(self, data_quality=None, multi_process=True, n_cores=None, custom_params=None, umap=False)
-
Write clustering parameters to each recording directory and run spike clustering on every live electrode, optionally in parallel
Parameters
data_quality : {'clean', 'noisy', None (default)}
- set to change the data quality parameters for cutoff and spike detection before running clustering. These parameters are automatically set as "clean" during initial parameter setup
multi_process : bool, True (default)
- run clustering for each electrode in parallel using a multiprocessing pool
n_cores : int (optional)
- number of cores for parallel clustering; defaults to one less than the number of available cores
custom_params : dict (optional)
- clustering parameters to use in place of the stored or preset parameters
umap : bool, False (default)
- use the UMAP-based data transform (bclust.UMAP_METRICS) when clustering
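A sketch of running clustering across all recordings (the preset and core count are illustrative; spike detection must already have been run on every recording):
# Re-cluster with the 'noisy' preset on four worker processes
exp.cluster_spikes(data_quality='noisy', n_cores=4)
# Or cluster with the stored parameters using the UMAP-based transform
exp.cluster_spikes(umap=True)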
def detect_held_units(self, percent_criterion=95, raw_waves=False, shell=False)
-
Determine which units are held across recording sessions. Grabs single units from each recording and compares consecutive recordings to determine whether units were held.
Parameters
percent_criterion : float
- percentile (0-100) of intra_J3 below which to accept units as held; 95 (default) uses the 95th percentile, and a lower number is a stricter criterion
shell : bool (optional)
- True for command-line interface for user input; False (default) for GUI
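A sketch of held-unit detection and inspection of the results (attribute names are taken from the source above; output files land in the held_unit_detection folder inside the analysis directory):
exp.detect_held_units(percent_criterion=95, raw_waves=False)
print(exp.held_units)              # DataFrame of units held across consecutive recordings
print(exp.J3_values['intra_J3'])   # J3 distribution used for the percentile criterion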
def remove_recording(self, rec_dir=None, shell=None)
-
Remove recording directory from experiment
Parameters
rec_dir : str (optional)
- full path or label of the directory to remove
shell : bool (optional)
- True for command-line interface for user input (default over SSH); False (default) for GUI. If not passed, the preference set upon object creation is used
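A sketch of removing a recording by label or by absolute path (values are hypothetical):
exp.remove_recording('session_4')                       # by label
exp.remove_recording('/data/my_experiment/session_4')   # or by absolute path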
def sort_spikes(self, electrode=None, shell=False)
-
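sort_spikes runs cleanup_clustering on each recording if needed and launches the spike-sorting GUI for one electrode across all recordings; a minimal sketch (the electrode number is hypothetical, and only the GUI path is currently implemented):
root, sorter_gui = exp.sort_spikes(electrode=5)   # returns the GUI root window and sorter GUI objects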