Module blechpy.datastructures.dataset
Expand source code
import numpy as np
import pandas as pd
import datetime as dt
import pickle
import os
import shutil
import sys
import multiprocessing
import subprocess
from tqdm import tqdm
from copy import deepcopy
from blechpy.utils import print_tools as pt, write_tools as wt, userIO
from blechpy.utils.decorators import Logger
from blechpy.analysis import palatability_analysis as pal_analysis
from blechpy.analysis import spike_sorting as ss, spike_analysis, circus_interface as circ
from blechpy.analysis import blech_clustering as clust
from blechpy.plotting import palatability_plot as pal_plt, data_plot as datplt
from blechpy import dio
from blechpy.datastructures.objects import data_object
from blechpy.utils import spike_sorting_GUI as ssg
class dataset(data_object):
    '''Stores information related to an Intan recording directory, allows
    executing basic processing and analysis scripts, and stores parameter data
    for those analyses
    Parameters
    ----------
    file_dir : str (optional)
        absolute path to a recording directory; if left empty a file chooser
        will pop up
    '''
    PROCESSING_STEPS = ['initialize parameters',
                        'extract_data', 'create_trial_list',
                        'mark_dead_channels',
                        'common_average_reference', 'spike_detection',
                        'spike_clustering', 'cleanup_clustering',
                        'sort_units', 'make_unit_plots',
                        'units_similarity', 'make_unit_arrays',
                        'make_psth_arrays', 'plot_psths',
                        'palatability_calculate', 'palatability_plot',
                        'overlay_psth']
    def __init__(self, file_dir=None, data_name=None, shell=False):
        '''Initialize dataset object from file_dir, grabs basename from name of
        directory and initializes basic analysis parameters
        Parameters
        ----------
        file_dir : str (optional), file directory for intan recording data
        Throws
        ------
        ValueError
            if file_dir is not provided and no directory is chosen
            when prompted
        NotADirectoryError : if file_dir does not exist
        '''
        super().__init__('dataset', file_dir, data_name=data_name, shell=shell)
        h5_file = dio.h5io.get_h5_filename(self.root_dir)
        if h5_file is None:
            h5_file = os.path.join(self.root_dir, '%s.h5' % self.data_name)
        self.h5_file = h5_file
        self.dataset_creation_date = dt.datetime.today()
        # Outline standard processing pipeline and status check
        self.processing_steps = dataset.PROCESSING_STEPS.copy()
        self.process_status = dict.fromkeys(self.processing_steps, False)
    def _change_root(self, new_root=None):
        old_root = self.root_dir
        new_root = super()._change_root(new_root)
        self.h5_file = self.h5_file.replace(old_root, new_root)
        return new_root
    @Logger('Initializing Parameters')
    def initParams(self, data_quality='clean', emg_port=None,
                   emg_channels=None, car_keyword=None,
                   car_group_areas=None,
                   shell=False, dig_in_names=None,
                   dig_out_names=None, accept_params=False):
        '''
        Initializes basic default analysis parameters and allows customization
        of parameters
        Parameters (all optional)
        -------------------------
        data_quality : {'clean', 'noisy'}
            keyword defining which default set of parameters to use to detect
            headstage disconnection during clustering
            default is 'clean'. Best practice is to run blech_clust as 'clean'
            and re-run as 'noisy' if too many early cutoffs occur
        emg_port : str
            Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default)
            will query user. False indicates no EMG port and not to query user
        emg_channels : list of int
            channel or channels of EMGs on port specified
            default is None
        car_keyword : str
            Specifies default common average reference groups
            defaults are found in CAR_defaults.json
            Currently 'bilateral32' is the only keyword available
            If left as None (default) the user will be queried to select common
            average reference groups
        shell : bool
            False (default) for GUI. True for command-line interface
        dig_in_names : list of str
            Names of digital inputs. Must match number of digital inputs used
            in recording.
            None (default) queries user to name each dig_in
        dig_out_names : list of str
            Names of digital outputs. Must match number of digital outputs in
            recording.
            None (default) queries user to name each dig_out
        accept_params : bool
            True automatically accepts default parameters where possible,
            decreasing user queries
            False (default) will query user to confirm or edit parameters for
            clustering, spike array and psth creation and palatability/identity
            calculations
        '''
        # Get parameters from info.rhd
        file_dir = self.root_dir
        rec_info = dio.rawIO.read_rec_info(file_dir)
        ports = rec_info.pop('ports')
        channels = rec_info.pop('channels')
        sampling_rate = rec_info['amplifier_sampling_rate']
        self.rec_info = rec_info
        self.sampling_rate = sampling_rate
        # Get default parameters from files
        clustering_params = dio.params.load_params('clustering_params', file_dir,
                                                   default_keyword=data_quality)
        spike_array_params = dio.params.load_params('spike_array_params', file_dir)
        psth_params = dio.params.load_params('psth_params', file_dir)
        pal_id_params = dio.params.load_params('pal_id_params', file_dir)
        spike_array_params['sampling_rate'] = sampling_rate
        clustering_params['file_dir'] = file_dir
        clustering_params['sampling_rate'] = sampling_rate
        # Setup digital input mapping
        if rec_info.get('dig_in'):
            self._setup_digital_mapping('in', dig_in_names, shell)
            dim = self.dig_in_mapping.copy()
            spike_array_params['laser_channels'] = dim.channel[dim['laser']].to_list()
            spike_array_params['dig_ins_to_use'] = dim.channel[dim['spike_array']].to_list()
        else:
            self.dig_in_mapping = None
        if rec_info.get('dig_out'):
            self._setup_digital_mapping('out', dig_out_names, shell)
            dom = self.dig_out_mapping.copy()
        else:
            self.dig_out_mapping = None
        # Setup electrode and emg mapping
        self._setup_channel_mapping(ports, channels, emg_port,
                                    emg_channels, shell=shell)
        # Set CAR groups
        self._set_CAR_groups(group_keyword=car_keyword, group_areas=car_group_areas, shell=shell)
        # Confirm parameters
        self.spike_array_params = spike_array_params
        if not accept_params:
            conf = userIO.confirm_parameter_dict
            clustering_params = conf(clustering_params,
                                     'Clustering Parameters', shell=shell)
            self.edit_spike_array_params(shell=shell)
            psth_params = conf(psth_params,
                               'PSTH Parameters', shell=shell)
            pal_id_params = conf(pal_id_params,
                                 'Palatability/Identity Parameters\n'
                                 'Valid unit_type is Single, Multi or All',
                                 shell=shell)
        # Store parameters
        self.clustering_params = clustering_params
        self.pal_id_params = pal_id_params
        self.psth_params = psth_params
        self._write_all_params_to_json()
        self.process_status['initialize parameters'] = True
        self.save()
    def _set_CAR_groups(self, group_keyword=None, shell=False, group_areas=None):
        '''Sets the electrode groups for common average referencing and
        defines which brain region each electrode ended up in
        Parameters
        ----------
        group_keyword : str or int
            Keyword corresponding to a preset electrode grouping in CAR_params.json
            Or integer indicating number of CAR groups
        group_areas : list of str (optional)
            brain area label for each CAR group, in group order
        shell : bool
            True for command-line interface, False (default) for GUI
        '''
        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('Set electrode mapping before setting CAR groups')
        em = self.electrode_mapping.copy()
        car_param_file = os.path.join(self.root_dir, 'analysis_params',
                                      'CAR_params.json')
        if os.path.isfile(car_param_file):
            tmp = dio.params.load_params('CAR_params', self.root_dir)
            if tmp is not None:
                group_electrodes = tmp
            else:
                raise ValueError('CAR_params file exists in recording dir, but is empty')
        else:
            if group_keyword is None:
                group_keyword = userIO.get_user_input(
                    'Input keyword for CAR parameters or number of CAR groups',
                    shell=shell)
                if group_keyword is None:
                    raise ValueError('Must provide a keyword or number of groups')
            if str(group_keyword).isnumeric():
                num_groups = int(group_keyword)
                group_electrodes = dio.params.select_CAR_groups(num_groups, em,
                                                                shell=shell)
            else:
                group_electrodes = dio.params.load_params('CAR_params',
                                                          self.root_dir,
                                                          default_keyword=group_keyword)
        num_groups = len(group_electrodes)
        if group_areas is not None and len(group_areas) == num_groups:
            for i, x in enumerate(zip(group_electrodes, group_areas)):
                em.loc[x[0], 'area'] = x[1]
                em.loc[x[0], 'CAR_group'] = i
        else:
            group_names = ['Group %i' % i for i in range(num_groups)]
            area_dict = dict.fromkeys(group_names, '')
            area_dict = userIO.fill_dict(area_dict, 'Set Areas for CAR groups',
                                         shell=shell)
            for k, v in area_dict.items():
                i = int(k.replace('Group', ''))
                em.loc[group_electrodes[i], 'area'] = v
                em.loc[group_electrodes[i], 'CAR_group'] = i
        self.CAR_electrodes = group_electrodes
        self.electrode_mapping = em.copy()
    @Logger('Re-labelling CAR group areas')
    def set_electrode_areas(self, areas):
        '''sets the electrode area for each CAR group.
        Parameters
        ----------
        areas : list of str
            number of elements must match number of CAR groups
        Throws
        ------
        ValueError
        '''
        em = self.electrode_mapping.copy()
        if len(em['CAR_group'].unique()) != len(areas):
            raise ValueError('Number of items in areas must match number of CAR groups')
        em['area'] = em['CAR_group'].apply(lambda x: areas[int(x)])
        self.electrode_mapping = em.copy()
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        self.save()
    def _setup_digital_mapping(self, dig_type, dig_in_names=None, shell=False):
        '''Sets up the dig_in_mapping or dig_out_mapping dataframe and queries
        the user to fill in the columns
        Parameters
        ----------
        dig_type : {'in', 'out'}
            whether to map digital inputs or digital outputs
        dig_in_names : list of str (optional)
            names to assign to the digital channels
        shell : bool (optional)
            True for command-line interface
            False (default) for GUI
        '''
        rec_info = self.rec_info
        df = pd.DataFrame()
        df['channel'] = rec_info.get('dig_%s' % dig_type)
        n_dig_in = len(df)
        # Names
        if dig_in_names:
            df['name'] = dig_in_names
        else:
            df['name'] = ''
        # Parameters to query
        if dig_type == 'in':
            df['palatability_rank'] = 0
            df['laser'] = False
            df['spike_array'] = True
        df['exclude'] = False
        # Re-format for query
        idx = df.index
        df.index = ['dig_%s_%i' % (dig_type, x) for x in df.channel]
        dig_str = dig_type + 'put'
        # Query for user input
        prompt = ('Digital %s Parameters\nSet palatability ranks from 1 to %i'
                  '\nor blank to exclude from pal_id analysis') % (dig_str, len(df))
        tmp = userIO.fill_dict(df.to_dict(), prompt=prompt, shell=shell)
        # Reformat for storage
        df2 = pd.DataFrame.from_dict(tmp)
        df2 = df2.sort_values(by=['channel'])
        df2.index = idx
        if dig_type == 'in':
            df2['palatability_rank'] = df2['palatability_rank'].fillna(-1).astype('int')
        if dig_type == 'in':
            self.dig_in_mapping = df2.copy()
        else:
            self.dig_out_mapping = df2.copy()
        if os.path.isfile(self.h5_file):
            dio.h5io.write_digital_map_to_h5(self.h5_file, df2, dig_type)
    def _setup_channel_mapping(self, ports, channels, emg_port, emg_channels, shell=False):
        '''Creates electrode_mapping and emg_mapping DataFrames with columns:
        - Electrode
        - Port
        - Channel
        Parameters
        ----------
        ports : list of str, item corresponding to each channel
        channels : list of int, channels on each port
        emg_port : str
        emg_channels : list of int
        '''
        if emg_port is None:
            q = userIO.ask_user('Do you have an EMG?', shell=shell)
            if q==1:
                emg_port = userIO.select_from_list('Select EMG Port:',
                                                   ports, 'EMG Port',
                                                   shell=shell)
                emg_channels = userIO.select_from_list(
                    'Select EMG Channels:',
                    [y for x, y in
                     zip(ports, channels)
                     if x == emg_port],
                    title='EMG Channels',
                    multi_select=True, shell=shell)
        el_map, em_map = dio.params.flatten_channels(ports, channels,
                                                     emg_port, emg_channels)
        self.electrode_mapping = el_map
        self.emg_mapping = em_map
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
    def edit_spike_array_params(self, shell=False):
        '''Edit spike array parameters and adjust dig_in_mapping accordingly
        Parameters
        ----------
        shell : bool, whether to use CLI or GUI
        '''
        if not hasattr(self, 'dig_in_mapping'):
            self.spike_array_params = None
            return
        sa = deepcopy(self.spike_array_params)
        tmp = userIO.fill_dict(sa, 'Spike Array Parameters\n(Times in ms)',
                               shell=shell)
        if tmp is None:
            return
        dim = self.dig_in_mapping
        dim['spike_array'] = False
        if tmp['dig_ins_to_use'] != ['']:
            tmp['dig_ins_to_use'] = [int(x) for x in tmp['dig_ins_to_use']]
            dim.loc[[x in tmp['dig_ins_to_use'] for x in dim.channel],
                    'spike_array'] = True
        dim['laser'] = False
        if tmp['laser_channels'] != ['']:
            tmp['laser_channels'] = [int(x) for x in tmp['laser_channels']]
            dim.loc[[x in tmp['laser_channels'] for x in dim.channel],
                    'laser'] = True
        self.spike_array_params = tmp.copy()
        wt.write_params_to_json('spike_array_params',
                                self.root_dir, tmp)
    def edit_clustering_params(self, shell=False):
        '''Allows user interface for editing clustering parameters
        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.clustering_params,
                               'Clustering Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.clustering_params = tmp
            wt.write_params_to_json('clustering_params', self.root_dir, tmp)
    def edit_psth_params(self, shell=False):
        '''Allows user interface for editing psth parameters
        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.psth_params,
                               'PSTH Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.psth_params = tmp
            wt.write_params_to_json('psth_params', self.root_dir, tmp)
    def edit_pal_id_params(self, shell=False):
        '''Allows user interface for editing palatability/identity parameters
        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.pal_id_params,
                               'Palatability/Identity Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.pal_id_params = tmp
            wt.write_params_to_json('pal_id_params', self.root_dir, tmp)
    def __str__(self):
        '''Put all information about dataset in string format
        Returns
        -------
        str : representation of dataset object
        '''
        out1 = super().__str__()
        out = [out1]
        out.append('\nObject creation date: '
                   + self.dataset_creation_date.strftime('%m/%d/%y'))
        if hasattr(self, 'raw_h5_file'):
            out.append('Deleted Raw h5 file: '+self.raw_h5_file)
        out.append('h5 File: '+self.h5_file)
        out.append('')
        out.append('--------------------')
        out.append('Processing Status')
        out.append('--------------------')
        out.append(pt.print_dict(self.process_status))
        out.append('')
        if not hasattr(self, 'rec_info'):
            return '\n'.join(out)
        info = self.rec_info
        out.append('--------------------')
        out.append('Recording Info')
        out.append('--------------------')
        out.append(pt.print_dict(self.rec_info))
        out.append('')
        out.append('--------------------')
        out.append('Electrodes')
        out.append('--------------------')
        out.append(pt.print_dataframe(self.electrode_mapping))
        out.append('')
        if hasattr(self, 'CAR_electrodes'):
            out.append('--------------------')
            out.append('CAR Groups')
            out.append('--------------------')
            headers = ['Group %i' % x for x in range(len(self.CAR_electrodes))]
            out.append(pt.print_list_table(self.CAR_electrodes, headers))
            out.append('')
        if not self.emg_mapping.empty:
            out.append('--------------------')
            out.append('EMG')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.emg_mapping))
            out.append('')
        if info.get('dig_in'):
            out.append('--------------------')
            out.append('Digital Input')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_in_mapping))
            out.append('')
        if info.get('dig_out'):
            out.append('--------------------')
            out.append('Digital Output')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_out_mapping))
            out.append('')
        out.append('--------------------')
        out.append('Clustering Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.clustering_params))
        out.append('')
        out.append('--------------------')
        out.append('Spike Array Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.spike_array_params))
        out.append('')
        out.append('--------------------')
        out.append('PSTH Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.psth_params))
        out.append('')
        out.append('--------------------')
        out.append('Palatability/Identity Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.pal_id_params))
        out.append('')
        return '\n'.join(out)
    @Logger('Writing parameters to JSON')
    def _write_all_params_to_json(self):
        '''Writes all parameters to json files in analysis_params folder in the
        recording directory
        '''
        print('Writing all parameters to json file in analysis_params folder...')
        clustering_params = self.clustering_params
        spike_array_params = self.spike_array_params
        psth_params = self.psth_params
        pal_id_params = self.pal_id_params
        CAR_params = self.CAR_electrodes
        rec_dir = self.root_dir
        wt.write_params_to_json('clustering_params', rec_dir, clustering_params)
        wt.write_params_to_json('spike_array_params', rec_dir, spike_array_params)
        wt.write_params_to_json('psth_params', rec_dir, psth_params)
        wt.write_params_to_json('pal_id_params', rec_dir, pal_id_params)
        wt.write_params_to_json('CAR_params', rec_dir, CAR_params)
    @Logger('Extracting Data')
    def extract_data(self, filename=None, shell=False):
        '''Create hdf5 store for data and read in Intan .dat files. Also create
        subfolders for processing outputs
        Parameters
        ----------
        filename : str (optional)
            absolute path for the hdf5 store to create. Defaults to the
            dataset's h5_file
        shell : bool (optional)
            True for command-line interface, False (default) for GUI
        '''
        if self.rec_info['file_type'] is None:
            raise ValueError('Unsupported recording type. Cannot extract yet.')
        if filename is None:
            filename = self.h5_file
        print('\nExtract Intan Data\n--------------------')
        # Create h5 file
        tmp = dio.h5io.create_empty_data_h5(filename, shell)
        if tmp is None:
            return
        # Create arrays for raw data in hdf5 store
        dio.h5io.create_hdf_arrays(filename, self.rec_info,
                                   self.electrode_mapping, self.emg_mapping)
        # Read in data to arrays
        dio.h5io.read_files_into_arrays(filename,
                                        self.rec_info,
                                        self.electrode_mapping,
                                        self.emg_mapping)
        # Write electrode and digital input mapping into h5 file
        # TODO: write EMG and digital output mapping into h5 file
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        if self.dig_in_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'in')
        if self.dig_out_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_out_mapping, 'out')
        # update status
        self.h5_file = filename
        self.process_status['extract_data'] = True
        self.save()
        print('\nData Extraction Complete\n--------------------')
    @Logger('Creating Trial List')
    def create_trial_list(self):
        '''Create lists of trials based on digital inputs and outputs and store
        to hdf5 store
        Can only be run after data extraction
        '''
        if self.rec_info.get('dig_in'):
            in_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_in_mapping,
                self.sampling_rate,
                'in')
            self.dig_in_trials = in_list
        else:
            print('No digital input data found')
        if self.rec_info.get('dig_out'):
            out_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_out_mapping,
                self.sampling_rate,
                'out')
            self.dig_out_trials = out_list
        else:
            print('No digital output data found')
        self.process_status['create_trial_list'] = True
        self.save()
    @Logger('Marking Dead Channels')
    def mark_dead_channels(self, dead_channels=None, shell=False):
        '''Plots a small piece of the raw traces and an outlier metric to help
        identify dead channels. Once the user marks channels as dead, a 'dead'
        column is added to the electrode mapping
        Parameters
        ----------
        dead_channels : list of int, optional
            if this is specified then nothing is plotted, those channels are
            simply marked as dead
        shell : bool, optional
        '''
        print('Marking dead channels\n----------')
        em = self.electrode_mapping.copy()
        if dead_channels is None:
            userIO.tell_user('Making traces figure for dead channel detection...',
                             shell=True)
            save_file = os.path.join(self.root_dir, 'Electrode_Traces.png')
            fig, ax = datplt.plot_traces_and_outliers(self.h5_file, save_file=save_file)
            if not shell:
                # Better to open figure outside of python since its a lot of
                # data on figure and matplotlib is slow
                subprocess.call(['xdg-open', save_file])
            else:
                userIO.tell_user('Saved figure of traces to %s for reference'
                                 % save_file, shell=shell)
            choice = userIO.select_from_list('Select dead channels:',
                                             em.Electrode.to_list(),
                                             'Dead Channel Selection',
                                             multi_select=True,
                                             shell=shell)
            dead_channels = list(map(int, choice))
        print('Marking electrodes %s as dead.\n'
              'They will be excluded from common average referencing.'
              % dead_channels)
        em['dead'] = False
        em.loc[dead_channels, 'dead'] = True
        self.electrode_mapping = em
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        self.process_status['mark_dead_channels'] = True
        self.save()
        return dead_channels
    @Logger('Common Average Referencing')
    def common_average_reference(self):
        '''Removes the common average of each CAR group from the signals of its
        member electrodes. Requires CAR_electrodes and electrode_mapping to be
        set; electrodes marked as dead are excluded from the average.
        '''
        if not hasattr(self, 'CAR_electrodes'):
            raise ValueError('CAR_electrodes not set')
        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('electrode_mapping not set')
        car_electrodes = self.CAR_electrodes
        num_groups = len(car_electrodes)
        em = self.electrode_mapping.copy()
        if 'dead' in em.columns:
            dead_electrodes = em.Electrode[em.dead].to_list()
        else:
            dead_electrodes = []
        # Gather Common Average Reference Groups
        print('CAR Groups\n')
        headers = ['Group %i' % x for x in range(num_groups)]
        print(pt.print_list_table(car_electrodes, headers))
        # Reference each group
        for i, x in enumerate(car_electrodes):
            tmp = list(set(x) - set(dead_electrodes))
            dio.h5io.common_avg_reference(self.h5_file, tmp, i)
        # Compress and repack file
        dio.h5io.compress_and_repack(self.h5_file)
        self.process_status['common_average_reference'] = True
        self.save()
    @Logger('Running Spike Detection')
    def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None):
        '''Run spike detection on each electrode. Prepares for clustering with
        BlechClust. Works for both single recording clustering or
        multi-recording clustering
        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        multi_process : bool (optional)
            True (default) to run spike detection on electrodes in parallel
        n_cores : int (optional)
            number of cores to use for parallel processing. default is max-1.
        '''
        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')
        print('\nRunning Spike Detection\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))
        # Create folders for saving things within recording dir
        data_dir = self.root_dir
        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()
        pbar = tqdm(total = len(electrodes))
        results = [(None, None, None)] * (max(electrodes)+1)
        def update_pbar(ans):
            if isinstance(ans, tuple) and ans[0] is not None:
                results[ans[0]] = ans
            else:
                print('Unexpected error during spike detection on an electrode')
            pbar.update()
        spike_detectors = [clust.SpikeDetection(data_dir, x,
                                                self.clustering_params)
                           for x in electrodes]
        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1
            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for sd in spike_detectors:
                pool.apply_async(sd.run, callback=update_pbar)
            pool.close()
            pool.join()
        else:
            for sd in spike_detectors:
                res = sd.run()
                update_pbar(res)
        pbar.close()
        print('Electrode    Result    Cutoff (s)')
        cutoffs = {}
        clust_res = {}
        clustered = []
        for x, y, z in results:
            if x is None:
                continue
            clustered.append(x)
            print('  {:<13}{:<10}{}'.format(x, y, z))
            cutoffs[x] = z
            clust_res[x] = y
        print('1 - Success\n0 - No data or no spikes\n-1 - Error')
        em = self.electrode_mapping.copy()
        em['cutoff_time'] = em['Electrode'].map(cutoffs)
        em['clustering_result'] = em['Electrode'].map(clust_res)
        self.electrode_mapping = em.copy()
        self.process_status['spike_detection'] = True
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Spike Detection Complete\n------------------')
        return results
    @Logger('Running Blech Clust')
    def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False):
        '''Write clustering parameters to file and run clustering on each
        electrode, in parallel via multiprocessing when multi_process is True
        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        multi_process : bool (optional)
            True (default) to cluster electrodes in parallel
        n_cores : int (optional)
            number of cores to use for parallel processing. default is max-1.
        umap : bool, False (default)
            set to True to cluster on UMAP-transformed metrics
            (clust.UMAP_METRICS) instead of the default transform
        '''
        if self.process_status['spike_detection'] == False:
            raise FileNotFoundError('Must run spike detection before clustering.')
        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')
        print('\nRunning Blech Clust\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))
        # Get electrodes, throw out 'dead' electrodes
        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()
        pbar = tqdm(total = len(electrodes))
        def update_pbar(ans):
            pbar.update()
        errors = []
        def error_call(e):
            errors.append(e)
        if not umap:
            clust_objs = [clust.BlechClust(self.root_dir, x, params=self.clustering_params)
                          for x in electrodes]
        else:
            clust_objs = [clust.BlechClust(self.root_dir, x,
                                           params=self.clustering_params,
                                           data_transform=clust.UMAP_METRICS,
                                           n_pc=5)
                          for x in electrodes]
        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1
            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for x in clust_objs:
                pool.apply_async(x.run, callback=update_pbar, error_callback=error_call)
            pool.close()
            pool.join()
        else:
            for x in clust_objs:
                res = x.run()
                update_pbar(res)
        pbar.close()
        self.process_status['spike_clustering'] = True
        self.process_status['cleanup_clustering'] = False
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Clustering Complete\n------------------')
        if len(errors) > 0:
            print('Errors encountered:')
            print(errors)
    @Logger('Cleaning up clustering memory logs. Removing raw data and setting '
            'up hdf5 for unit sorting')
    def cleanup_clustering(self):
        '''Consolidates memory monitor files, removes raw and referenced data
        and sets up the hdf5 store for sorted unit data
        '''
        if self.process_status['cleanup_clustering']:
            return
        h5_file = dio.h5io.cleanup_clustering(self.root_dir)
        self.h5_file = h5_file
        self.process_status['cleanup_clustering'] = True
        self.save()
    def sort_spikes(self, electrode=None, shell=False):
        if electrode is None:
            electrode = userIO.get_user_input('Electrode #: ', shell=shell)
            if electrode is None or not electrode.isnumeric():
                return
            electrode = int(electrode)
        if not self.process_status['spike_clustering']:
            raise ValueError('Must run spike clustering first.')
        if not self.process_status['cleanup_clustering']:
            self.cleanup_clustering()
        sorter = clust.SpikeSorter(self.root_dir, electrode=electrode, shell=shell)
        if not shell:
            root, sorting_GUI = ssg.launch_sorter_GUI(sorter)
            return root, sorting_GUI
        else:
            # TODO: Make shell UI
            # TODO: Make sort by table
            print('No shell UI yet')
            return
        self.process_status['sort_units'] = True
    @Logger('Calculating Units Similarity')
    def units_similarity(self, similarity_cutoff=50, shell=False):
        if 'SSH_CONNECTION' in os.environ:
            shell= True
        metrics_dir = os.path.join(self.root_dir, 'sorted_unit_metrics')
        if not os.path.isdir(metrics_dir):
            raise ValueError('No sorted unit metrics found. Must sort units before calculating similarity')
        violation_file = os.path.join(metrics_dir,
                                      'units_similarity_violations.txt')
        violations, sim = ss.calc_units_similarity(self.h5_file,
                                                   self.sampling_rate,
                                                   similarity_cutoff,
                                                   violation_file)
        if len(violations) == 0:
            userIO.tell_user('No similarity violations found!', shell=shell)
            self.process_status['units_similarity'] = True
            return violations, sim
        out_str = ['Units Similarity Violations Found:']
        out_str.append('Unit_1    Unit_2    Similarity')
        for x,y in violations:
            u1 = dio.h5io.parse_unit_number(x)
            u2 = dio.h5io.parse_unit_number(y)
            out_str.append('   {:<10}{:<10}{}\n'.format(x, y, sim[u1][u2]))
        out_str.append('Delete units with dataset.delete_unit(N)')
        out_str = '\n'.join(out_str)
        userIO.tell_user(out_str, shell=shell)
        self.process_status['units_similarity'] = True
        self.save()
        return violations, sim
    @Logger('Deleting Unit')
    def delete_unit(self, unit_num, confirm=False, shell=False):
        if isinstance(unit_num, str):
            unit_num = dio.h5io.parse_unit_number(unit_num)
        if unit_num is None:
            print('No unit deleted')
            return
        if not confirm:
            q = userIO.ask_user('Are you sure you want to delete unit%03i?' % unit_num,
                                choices = ['No','Yes'], shell=shell)
        else:
            q = 1
        if q == 0:
            print('No unit deleted')
            return
        else:
            tmp = dio.h5io.delete_unit(self.root_dir, unit_num)
            if tmp is False:
                userIO.tell_user('Unit %i not found in dataset. No unit deleted'
                                 % unit_num, shell=shell)
            else:
                userIO.tell_user('Unit %i successfully deleted.' % unit_num,
                                 shell=shell)
        self.save()
    @Logger('Making Unit Arrays')
    def make_unit_arrays(self):
        '''Make spike arrays for each unit and store in hdf5 store
        '''
        params = self.spike_array_params
        print('Generating unit arrays with parameters:\n----------')
        print(pt.print_dict(params, tabs=1))
        ss.make_spike_arrays(self.h5_file, params)
        self.process_status['make_unit_arrays'] = True
        self.save()
    @Logger('Making Unit Plots')
    def make_unit_plots(self):
        '''Make waveform plots for each sorted unit
        '''
        unit_table = self.get_unit_table()
        save_dir = os.path.join(self.root_dir, 'unit_waveforms_plots')
        if os.path.isdir(save_dir):
            shutil.rmtree(save_dir)
        os.mkdir(save_dir)
        for i, row in unit_table.iterrows():
            datplt.make_unit_plots(self.root_dir, row['unit_name'], save_dir=save_dir)
        self.process_status['make_unit_plots'] = True
        self.save()
    @Logger('Making PSTH Arrays')
    def make_psth_arrays(self):
        '''Make smoothed firing rate traces for each unit/trial and store in
        hdf5 store
        '''
        params = self.psth_params
        dig_ins = self.dig_in_mapping.query('spike_array == True')
        for idx, row in dig_ins.iterrows():
            spike_analysis.make_psths_for_tastant(self.h5_file,
                                                  params['window_size'],
                                                  params['window_step'],
                                                  row['channel'])
        self.process_status['make_psth_arrays'] = True
        self.save()
    @Logger('Calculating Palatability/Identity Metrics')
    def palatability_calculate(self, shell=False):
        pal_analysis.palatability_identity_calculations(self.root_dir,
                                                        params=self.pal_id_params)
        self.process_status['palatability_calculate'] = True
        self.save()
    @Logger('Plotting Palatability/Identity Metrics')
    def palatability_plot(self, shell=False):
        pal_plt.plot_palatability_identity([self.root_dir], shell=shell)
        self.process_status['palatability_plot'] = True
        self.save()
    @Logger('Removing low-spiking units')
    def cleanup_lowSpiking_units(self, min_spikes=100):
        unit_table = self.get_unit_table()
        remove = []
        spike_count = []
        for unit in unit_table['unit_num']:
            waves, descrip, fs = dio.h5io.get_unit_waveforms(self.root_dir, unit)
            if waves.shape[0] < min_spikes:
                spike_count.append(waves.shape[0])
                remove.append(unit)
        for unit, count in zip(reversed(remove), reversed(spike_count)):
            print('Removing unit %i. Only %i spikes.' % (unit, count))
            userIO.tell_user('Removing unit %i. Only %i spikes.'
                             % (unit, count), shell=True)
            self.delete_unit(unit, confirm=True, shell=True)
        userIO.tell_user('Removed %i units for having less than %i spikes.'
                         % (len(remove), min_spikes), shell=True)
    def get_unit_table(self):
        '''Returns a pandas dataframe with sorted unit information
        Returns
        --------
        pandas.DataFrame with columns:
            unit_name, unit_num, electrode, single_unit,
            regular_spiking, fast_spiking
        '''
        unit_table = dio.h5io.get_unit_table(self.root_dir)
        return unit_table
    def circus_clust_run(self, shell=False):
        circ.prep_for_circus(self.root_dir, self.electrode_mapping,
                             self.data_name, self.sampling_rate)
        circ.start_the_show()
    def pre_process_for_clustering(self, shell=False, dead_channels=None):
        status = self.process_status
        if not status['initialize parameters']:
            self.initParams(shell=shell)
        if not status['extract_data']:
            self.extract_data(shell=True)
        if not status['create_trial_list']:
            self.create_trial_list()
        if not status['mark_dead_channels'] and dead_channels != False:
            self.mark_dead_channels(dead_channels=dead_channels, shell=shell)
        if not status['common_average_reference']:
            self.common_average_reference()
        if not status['spike_detection']:
            self.detect_spikes()
    def extract_and_circus_cluster(self, dead_channels=None, shell=True):
        print('Extracting Data...')
        self.extract_data()
        print('Marking dead channels...')
        self.mark_dead_channels(dead_channels, shell=shell)
        print('Common average referencing...')
        self.common_average_reference()
        print('Initiating circus clustering...')
        circus = circ.circus_clust(self.root_dir, self.data_name,
                                   self.sampling_rate, self.electrode_mapping)
        print('Preparing for circus...')
        circus.prep_for_circus()
        print('Starting circus clustering...')
        circus.start_the_show()
        print('Plotting cluster waveforms...')
        circus.plot_cluster_waveforms()
    def post_sorting(self):
        self.make_unit_plots()
        self.make_unit_arrays()
        self.units_similarity(shell=True)
        self.make_psth_arrays()
def port_in_dataset(rec_dir=None, shell=False):
    '''Import an existing dataset into this framework
    '''
    if rec_dir is None:
        rec_dir = userIO.get_filedirs('Select recording directory', shell=shell)
        if rec_dir is None:
            return None
    dat = dataset(rec_dir, shell=shell)
    # Check files that will be overwritten: log_file, save_file
    if os.path.isfile(dat.save_file):
        prompt = '%s already exists. Continuing will overwrite this. Continue?' % dat.save_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None
    # if os.path.isfile(dat.h5_file):
    #     prompt = '%s already exists. Continuinlg will overwrite this. Continue?' % dat.h5_file
    #     q = userIO.ask_user(prompt, shell=shell)
    #     if q == 0:
    #         print('Aborted')
    #         return None
    if os.path.isfile(dat.log_file):
        prompt = '%s already exists. Continuing will append to this. Continue?' % dat.log_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None
    with open(dat.log_file, 'a') as f:
        print('\n==========\nPorting dataset into blechpy format\n==========\n', file=f)
        print(dat, file=f)
    status = dat.process_status
    dat.initParams(shell=shell)
    user_status = status.copy()
    user_status = userIO.fill_dict(user_status,
                                   'Which processes have already been '
                                   'done to the data?', shell=shell)
    status.update(user_status)
    # if h5 exists data must have been extracted
    if not os.path.isfile(dat.h5_file) or status['extract_data'] == False:
        dat.save()
        return dat
    # write electrode map and digital input & output maps to h5
    dio.h5io.write_electrode_map_to_h5(dat.h5_file, dat.electrode_mapping)
    if dat.rec_info.get('dig_in') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_in_mapping, 'in')
    if dat.rec_info.get('dig_out') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_out_mapping, 'out')
    node_list = dio.h5io.get_node_list(dat.h5_file)
    if (status['create_trial_list'] == False) and ('digital_in' in node_list):
        dat.create_trial_list()
    dat.save()
    return dat
def validate_data_integrity(rec_dir, verbose=False):
    '''incomplete
    '''
    # TODO: Finish this
    print('Raw Data Validation\n' + '-'*19)
    test_names = ['file_type', 'recording_info', 'files', 'dropped_packets', 'data_length']
    number_names = ['sample_rate', 'dropped_packets', 'missing_files', 'recording_length']
    tests = dict.fromkeys(test_names, 'NOT TESTED')
    numbers = dict.fromkeys(number_names, -1)
    file_type = dio.rawIO.get_recording_filetype(rec_dir)
    if file_type is None:
        tests['file_type'] = 'UNSUPPORTED'
    else:
        tests['file_type'] = 'PASS'
    # Check info.rhd integrity
    info_file = os.path.join(rec_dir, 'info.rhd')
    try:
        rec_info = dio.rawIO.read_rec_info(rec_dir, shell=True)
        with open(info_file, 'rb') as f:
            info = dio.load_intan_rhd_format.read_header(f)
        tests['recording_info'] = 'PASS'
    except FileNotFoundError:
        tests['recording_info'] = 'MISSING'
    except Exception as e:
        info_size = os.path.getsize(os.path.join(rec_dir, 'info.rhd'))
        if info_size == 0:
            tests['recording_info'] = 'EMPTY'
        else:
            tests['recording_info'] = 'FAIL'
        print(pt.print_dict(tests, tabs=1))
        return tests, numbers
    counts = {x: info[x] for x in info.keys() if 'num' in x}
    numbers.update(counts)
    fs = info['sample_rate']
    # Check all files needed are present
    files_expected = ['time.dat']
    if file_type == 'one file per signal type':
        files_expected.append('amplifier.dat')
        if rec_info.get('dig_in') is not None:
            files_expected.append('digitalin.dat')
        if rec_info.get('dig_out') is not None:
            files_expected.append('digitalout.dat')
        if info['num_auxilary_input_channels'] > 0:
            files_expected.append('auxiliary.dat')
    elif file_type == 'one file per channel':
        for x in info['amplifier_channels']:
            files_expected.append('amp-' + x['native_channel_name'] + '.dat')
        for x in info['board_dig_in_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])
        for x in info['board_dig_out_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])
        for x in info['aux_input_channels']:
            files_expected.append('aux-%s.dat' % x['native_channel_name'])
    missing_files = []
    file_list = os.listdir(rec_dir)
    for x in files_expected:
        if x not in file_list:
            missing_files.append(x)
    if len(missing_files) == 0:
        tests['files'] = 'PASS'
    else:
        tests['files'] = 'MISSING'
        numbers['missing_files'] = missing_files
    # Check time data for dropped packets
    time = dio.rawIO.read_time_dat(rec_dir, sampling_rate=1)  # get raw timestamps
    numbers['n_samples'] = len(time)
    numbers['recording_length'] = float(time[-1])/fs
    expected_time = np.arange(time[0], time[-1]+1, 1)
    missing_timestamps = np.setdiff1d(expected_time, time)
    missing_times = np.array([float(x)/fs for x in missing_timestamps])
    if len(missing_timestamps) == 0:
        tests['dropped_packets'] = 'PASS'
    else:
        tests['dropped_packets'] = '%i' % len(missing_timestamps)
        numbers['dropped_packets'] = missing_times
    # Check recording length of each trace
    tests['data_traces'] = 'FAIL'
    if file_type == 'one file per signal type':
        try:
            data = dio.rawIO.read_amplifier_dat(rec_dir)
            if data is None:
                tests['data_traces'] = 'UNREADABLE'
            elif data.shape[0] == numbers['n_samples']:
                tests['data_traces'] = 'PASS'
            else:
                tests['data_traces'] = 'CUTOFF'
                numbers['data_trace_length (s)'] = data.shape[0]/fs
        except:
            tests['data_traces'] = 'UNREADABLE'
    elif file_type == 'one file per channel':
        chan_info = pd.DataFrame(columns=['port', 'channel', 'n_samples'])
        lengths = []
        min_samples = numbers['n_samples']
        max_samples = numbers['n_samples']
        for x in info['amplifier_channels']:
            fn = os.path.join(rec_dir, 'amp-%s.dat' % x['native_channel_name'])
            if os.path.basename(fn) in missing_files:
                continue
            data = dio.rawIO.read_one_channel_file(fn)
            lengths.append((x['native_channel_name'], data.shape[0]))
            if data.shape[0] < min_samples:
                min_samples = data.shape[0]
            if data.shape[0] > max_samples:
                max_samples = data.shape[0]
        if min_samples == max_samples:
            tests['data_traces'] = 'PASS'
        else:
            tests['data_traces'] = 'CUTOFF'
        numbers['max_recording_length (s)'] = max_samples/fs
        numbers['min_recording_length (s)'] = min_samples/fs
Functions
- def port_in_dataset(rec_dir=None, shell=False)
- 
Import an existing dataset into this framework
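A minimal usage sketch (the recording path below is a placeholder, not from the source): port_in_dataset wraps an existing recording directory in a dataset object and asks which pipeline steps have already been completed.

from blechpy.datastructures.dataset import port_in_dataset

rec_dir = '/path/to/recording_dir'   # hypothetical recording directory
dat = port_in_dataset(rec_dir, shell=True)
if dat is not None:                  # None means the user aborted at a prompt
    print(dat.process_status)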
- def validate_data_integrity(rec_dir, verbose=False)
- 
incomplete: runs basic integrity checks on a raw Intan recording directory (file type, info.rhd header, expected files, dropped packets, data trace lengths). Work in progress.

Expand source code

def validate_data_integrity(rec_dir, verbose=False):
    '''incomplete
    '''
    # TODO: Finish this
    # verbose is currently unused
    import numpy as np  # needed for arange/setdiff1d below

    print('Raw Data Validation\n' + '-'*19)
    test_names = ['file_type', 'recording_info', 'files',
                  'dropped_packets', 'data_length']
    number_names = ['sample_rate', 'dropped_packets', 'missing_files',
                    'recording_length']
    tests = dict.fromkeys(test_names, 'NOT TESTED')
    numbers = dict.fromkeys(number_names, -1)

    file_type = dio.rawIO.get_recording_filetype(rec_dir)
    if file_type is None:
        tests['file_type'] = 'UNSUPPORTED'
    else:
        tests['file_type'] = 'PASS'

    # Check info.rhd integrity
    info_file = os.path.join(rec_dir, 'info.rhd')
    try:
        rec_info = dio.rawIO.read_rec_info(rec_dir, shell=True)
        with open(info_file, 'rb') as f:
            info = dio.load_intan_rhd_format.read_header(f)

        tests['recording_info'] = 'PASS'
    except FileNotFoundError:
        tests['recording_info'] = 'MISSING'
    except Exception as e:
        info_size = os.path.getsize(info_file)
        if info_size == 0:
            tests['recording_info'] = 'EMPTY'
        else:
            tests['recording_info'] = 'FAIL'

        print(pt.print_dict(tests, tabs=1))
        return tests, numbers

    counts = {x: info[x] for x in info.keys() if 'num' in x}
    numbers.update(counts)
    fs = info['sample_rate']

    # Check all files needed are present
    files_expected = ['time.dat']
    if file_type == 'one file per signal type':
        files_expected.append('amplifier.dat')
        if rec_info.get('dig_in') is not None:
            files_expected.append('digitalin.dat')

        if rec_info.get('dig_out') is not None:
            files_expected.append('digitalout.dat')

        if info['num_auxilary_input_channels'] > 0:
            files_expected.append('auxiliary.dat')

    elif file_type == 'one file per channel':
        for x in info['amplifier_channels']:
            files_expected.append('amp-' + x['native_channel_name'] + '.dat')

        for x in info['board_dig_in_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['board_dig_out_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['aux_input_channels']:
            files_expected.append('aux-%s.dat' % x['native_channel_name'])

    missing_files = []
    file_list = os.listdir(rec_dir)
    for x in files_expected:
        if x not in file_list:
            missing_files.append(x)

    if len(missing_files) == 0:
        tests['files'] = 'PASS'
    else:
        tests['files'] = 'MISSING'
        numbers['missing_files'] = missing_files

    # Check time data for dropped packets
    time = dio.rawIO.read_time_dat(rec_dir, sampling_rate=1)  # get raw timestamps
    numbers['n_samples'] = len(time)
    numbers['recording_length'] = float(time[-1])/fs
    expected_time = np.arange(time[0], time[-1]+1, 1)
    missing_timestamps = np.setdiff1d(expected_time, time)
    missing_times = np.array([float(x)/fs for x in missing_timestamps])
    if len(missing_timestamps) == 0:
        tests['dropped_packets'] = 'PASS'
    else:
        tests['dropped_packets'] = '%i' % len(missing_timestamps)
        numbers['dropped_packets'] = missing_times

    # Check recording length of each trace
    tests['data_traces'] = 'FAIL'
    if file_type == 'one file per signal type':
        try:
            data = dio.rawIO.read_amplifier_dat(rec_dir)
            if data is None:
                tests['data_traces'] = 'UNREADABLE'
            elif data.shape[0] == numbers['n_samples']:
                tests['data_traces'] = 'PASS'
            else:
                tests['data_traces'] = 'CUTOFF'
                numbers['data_trace_length (s)'] = data.shape[0]/fs
        except Exception:
            tests['data_traces'] = 'UNREADABLE'

    elif file_type == 'one file per channel':
        chan_info = pd.DataFrame(columns=['port', 'channel', 'n_samples'])  # not yet populated
        lengths = []
        min_samples = numbers['n_samples']
        max_samples = numbers['n_samples']
        for x in info['amplifier_channels']:
            fn = os.path.join(rec_dir, 'amp-%s.dat' % x['native_channel_name'])
            if os.path.basename(fn) in missing_files:
                continue

            data = dio.rawIO.read_one_channel_file(fn)
            lengths.append((x['native_channel_name'], data.shape[0]))
            if data.shape[0] < min_samples:
                min_samples = data.shape[0]

            if data.shape[0] > max_samples:
                max_samples = data.shape[0]

        if min_samples == max_samples:
            tests['data_traces'] = 'PASS'
        else:
            tests['data_traces'] = 'CUTOFF'
            numbers['max_recording_length (s)'] = max_samples/fs
            numbers['min_recording_length (s)'] = min_samples/fs

    # validation is still incomplete; return what has been gathered so far
    return tests, numbers
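A minimal usage sketch (the path below is hypothetical; the final summary printing is still on the TODO list inside the function, so the results are printed explicitly here):

from blechpy.datastructures.dataset import validate_data_integrity
from blechpy.utils import print_tools as pt

tests, numbers = validate_data_integrity('/path/to/rec_dir')  # hypothetical recording directory
print(pt.print_dict(tests, tabs=1))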
Classes
- class dataset (file_dir=None, data_name=None, shell=False)
- 
Stores information related to an intan recording directory, allows executing basic processing and analysis scripts, and stores parameters data for those analyses

Parameters
- file_dir : str (optional)
  absolute path to a recording directory, if left empty a filechooser will popup

Initialize dataset object from file_dir, grabs basename from name of directory and initializes basic analysis parameters

Parameters
- file_dir : str (optional), file directory for intan recording data
Throws
- ValueError : if file_dir is not provided and no directory is chosen when prompted
- NotADirectoryError : if file_dir does not exist

Ancestors
- blechpy.datastructures.objects.data_object

Class variables
- var PROCESSING_STEPS
Methods

- def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False)
- 
Write clustering parameters to file and run blech_process on each electrode using GNU parallel

Parameters
- data_quality : {'clean', 'noisy', None (default)}
  set if you want to change the data quality parameters for cutoff and spike detection before running clustering. These parameters are automatically set as "clean" during initial parameter setup
- accept_params : bool, False (default)
  set to True in order to skip popup confirmation of parameters when running
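A usage sketch (the recording path is hypothetical; pre_process_for_clustering runs any outstanding earlier steps, including spike detection, before clustering):

from blechpy.datastructures.dataset import dataset

dat = dataset('/path/to/rec_dir', shell=True)  # hypothetical recording directory
dat.pre_process_for_clustering(shell=True)     # init params, extract, CAR, spike detection
dat.blech_clust_run(data_quality='noisy')      # re-cluster with the 'noisy' parameter preset
# dat.blech_clust_run(umap=True)               # or cluster on UMAP-transformed metrics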
- def circus_clust_run(self, shell=False)
- 
Expand source code

def circus_clust_run(self, shell=False):
    circ.prep_for_circus(self.root_dir, self.electrode_mapping,
                         self.data_name, self.sampling_rate)
    circ.start_the_show()
- def cleanup_clustering(self)
- 
Consolidates memory monitor files, removes raw and referenced data and sets up the hdf5 store for sorted unit data
- def cleanup_lowSpiking_units(self, min_spikes=100)
- 
Expand source code

@Logger('Removing low-spiking units')
def cleanup_lowSpiking_units(self, min_spikes=100):
    unit_table = self.get_unit_table()
    remove = []
    spike_count = []
    for unit in unit_table['unit_num']:
        waves, descrip, fs = dio.h5io.get_unit_waveforms(self.root_dir, unit)
        if waves.shape[0] < min_spikes:
            spike_count.append(waves.shape[0])
            remove.append(unit)

    for unit, count in zip(reversed(remove), reversed(spike_count)):
        print('Removing unit %i. Only %i spikes.' % (unit, count))
        userIO.tell_user('Removing unit %i. Only %i spikes.'
                         % (unit, count), shell=True)
        self.delete_unit(unit, confirm=True, shell=True)

    userIO.tell_user('Removed %i units for having less than %i spikes.'
                     % (len(remove), min_spikes), shell=True)
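For example (a sketch, assuming dat is a dataset object with sorted units):

dat.cleanup_lowSpiking_units(min_spikes=100)  # delete sorted units with fewer than 100 spikes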
- def common_average_reference(self)
- 
Define electrode groups and remove common average from signals

Parameters
- num_groups : int (optional)
  number of CAR groups, if not provided there's a prompt
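A short sketch, assuming dat already has CAR groups assigned by initParams() and dead channels marked:

print(dat.CAR_electrodes)       # electrode groups that will be referenced
dat.common_average_reference()  # dead channels are excluded automatically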
- def create_trial_list(self)
- 
Create lists of trials based on digital inputs and outputs and store to hdf5 store. Can only be run after data extraction.
- def delete_unit(self, unit_num, confirm=False, shell=False)
- 
Expand source code

@Logger('Deleting Unit')
def delete_unit(self, unit_num, confirm=False, shell=False):
    if isinstance(unit_num, str):
        unit_num = dio.h5io.parse_unit_number(unit_num)

    if unit_num is None:
        print('No unit deleted')
        return

    if not confirm:
        q = userIO.ask_user('Are you sure you want to delete unit%03i?'
                            % unit_num, choices=['No', 'Yes'], shell=shell)
    else:
        q = 1

    if q == 0:
        print('No unit deleted')
        return
    else:
        tmp = dio.h5io.delete_unit(self.root_dir, unit_num)
        if tmp is False:
            userIO.tell_user('Unit %i not found in dataset. No unit deleted'
                             % unit_num, shell=shell)
        else:
            userIO.tell_user('Unit %i successfully deleted.' % unit_num,
                             shell=shell)

    self.save()
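Units can be passed by name or by number, for example (unit numbers hypothetical):

dat.delete_unit('unit003')        # prompts for confirmation before deleting
dat.delete_unit(3, confirm=True)  # deletes without prompting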
- def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None)
- 
Run spike detection on each electrode. Prepares for clustering with BlechClust. Works for both single recording clustering or multi-recording clustering

Parameters
- data_quality : {'clean', 'noisy', None (default)}
  set if you want to change the data quality parameters for cutoff and spike detection before running clustering. These parameters are automatically set as "clean" during initial parameter setup
- n_cores : int (optional)
  number of cores to use for parallel processing. default is max-1.
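For example (a sketch, assuming dat is a dataset that has already been common average referenced):

dat.detect_spikes(data_quality='clean', multi_process=True, n_cores=4)  # n_cores defaults to cpu_count() - 1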
- def edit_clustering_params(self, shell=False)
- 
Allows user interface for editing clustering parameters

Parameters
- shell : bool (optional)
  True if you want command-line interface, False for GUI (default)
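For example, to edit the clustering parameters from a terminal session (the other edit_*_params methods follow the same pattern):

dat.edit_clustering_params(shell=True)  # prompts in the terminal instead of opening a GUI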
- def edit_pal_id_params(self, shell=False)
- 
Allows user interface for editing palatability/identity parameters

Parameters
- shell : bool (optional)
  True if you want command-line interface, False for GUI (default)
- def edit_psth_params(self, shell=False)
- 
Allows user interface for editing psth parameters

Parameters
- shell : bool (optional)
  True if you want command-line interface, False for GUI (default)
- def edit_spike_array_params(self, shell=False)
- 
Edit spike array parameters and adjust dig_in_mapping accordingly

Parameters
- shell : bool, whether to use CLI or GUI
- def extract_and_circus_cluster(self, dead_channels=None, shell=True)
- 
Expand source code

def extract_and_circus_cluster(self, dead_channels=None, shell=True):
    print('Extracting Data...')
    self.extract_data()
    print('Marking dead channels...')
    self.mark_dead_channels(dead_channels, shell=shell)
    print('Common average referencing...')
    self.common_average_reference()
    print('Initiating circus clustering...')
    circus = circ.circus_clust(self.root_dir, self.data_name,
                               self.sampling_rate, self.electrode_mapping)
    print('Preparing for circus...')
    circus.prep_for_circus()
    print('Starting circus clustering...')
    circus.start_the_show()
    print('Plotting cluster waveforms...')
    circus.plot_cluster_waveforms()
- def extract_data(self, filename=None, shell=False)
- 
Create hdf5 store for data and read in Intan .dat files. Also create subfolders for processing outputs

Parameters
- data_quality : {'clean', 'noisy'} (optional)
  Specifies quality of data for default clustering parameters associated. Should generally first process with clean (default) parameters and then try noisy after running blech_clust and checking if too many electrodes are cut off too early
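For example (a sketch; the custom filename below is hypothetical and optional):

dat.extract_data()  # writes raw data into <data_name>.h5 in the recording directory
# dat.extract_data(filename='/path/to/other_store.h5')  # or extract into a specific h5 file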
- def get_unit_table(self)
- 
Returns a pandas dataframe with sorted unit information

Returns
- pandas.DataFrame with columns:
  unit_name, unit_num, electrode, single_unit, regular_spiking, fast_spiking
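For example, to list the single units (a sketch using the documented columns):

units = dat.get_unit_table()
single_units = units.query('single_unit == True')
print(single_units[['unit_name', 'electrode']])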
- def initParams(self, data_quality='clean', emg_port=None, emg_channels=None, car_keyword=None, car_group_areas=None, shell=False, dig_in_names=None, dig_out_names=None, accept_params=False)
- 
Initializes basic default analysis parameters and allows customization of parameters

Parameters (all optional)
- data_quality : {'clean', 'noisy'}
  keyword defining which default set of parameters to use to detect headstage disconnection during clustering. default is 'clean'. Best practice is to run blech_clust as 'clean' and re-run as 'noisy' if too many early cutoffs occur
- emg_port : str
  Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default) will query user. False indicates no EMG port and not to query user
- emg_channels : list of int
  channel or channels of EMGs on port specified. default is None
- car_keyword : str
  Specifies default common average reference groups. defaults are found in CAR_defaults.json. Currently 'bilateral32' is the only keyword available. If left as None (default) user will be queried to select common average reference groups
- shell : bool
  False (default) for GUI. True for command-line interface
- dig_in_names : list of str
  Names of digital inputs. Must match number of digital inputs used in recording. None (default) queries user to name each dig_in
- dig_out_names : list of str
  Names of digital outputs. Must match number of digital outputs in recording. None (default) queries user to name each dig_out
- accept_params : bool
  True automatically accepts default parameters where possible, decreasing user queries. False (default) will query user to confirm or edit parameters for clustering, spike array and psth creation and palatability/identity calculations
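A usage sketch (the path and digital input names are hypothetical; emg_port=False skips the EMG prompt):

from blechpy.datastructures.dataset import dataset

dat = dataset('/path/to/rec_dir')  # hypothetical recording directory
dat.initParams(data_quality='clean',
               emg_port=False,                             # no EMG in this recording
               car_keyword='bilateral32',                  # preset CAR grouping
               dig_in_names=['Water', 'NaCl', 'Quinine'],  # must match the number of dig_ins
               shell=True,
               accept_params=True)                         # accept defaults without confirmation prompts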
- def make_psth_arrays(self)

Make smoothed firing rate traces for each unit/trial and store in hdf5 store

Expand source code
@Logger('Making PSTH Arrays')
def make_psth_arrays(self):
    '''Make smoothed firing rate traces for each unit/trial and store in
    hdf5 store
    '''
    params = self.psth_params
    dig_ins = self.dig_in_mapping.query('spike_array == True')
    for idx, row in dig_ins.iterrows():
        spike_analysis.make_psths_for_tastant(self.h5_file,
                                              params['window_size'],
                                              params['window_step'],
                                              row['channel'])

    self.process_status['make_psth_arrays'] = True
    self.save()
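Usage sketch (assumes the `dat` instance from the initParams example above; the window values are purely illustrative and the 'window_size'/'window_step' keys are the ones used in the source above):

# Inspect the PSTH parameters loaded during initParams
print(dat.psth_params)

# Illustrative tweak before building the arrays; units follow whatever
# the shipped psth_params defaults use
dat.psth_params['window_size'] = 250
dat.psth_params['window_step'] = 25

dat.make_psth_arrays()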
- def make_unit_arrays(self)

Make spike arrays for each unit and store in hdf5 store

Expand source code
@Logger('Making Unit Arrays')
def make_unit_arrays(self):
    '''Make spike arrays for each unit and store in hdf5 store
    '''
    params = self.spike_array_params
    print('Generating unit arrays with parameters:\n----------')
    print(pt.print_dict(params, tabs=1))
    ss.make_spike_arrays(self.h5_file, params)
    self.process_status['make_unit_arrays'] = True
    self.save()
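Usage sketch (assumes `dat` from the initParams example and that units have already been sorted; edit_spike_array_params is the same helper initParams calls when confirming parameters):

# Optionally revisit the spike array parameters from the command line
dat.edit_spike_array_params(shell=True)

# Build the per-unit spike arrays in the h5 file
dat.make_unit_arrays()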
- def make_unit_plots(self)

Make waveform plots for each sorted unit

Expand source code
@Logger('Making Unit Plots')
def make_unit_plots(self):
    '''Make waveform plots for each sorted unit
    '''
    unit_table = self.get_unit_table()
    save_dir = os.path.join(self.root_dir, 'unit_waveforms_plots')
    if os.path.isdir(save_dir):
        shutil.rmtree(save_dir)

    os.mkdir(save_dir)
    for i, row in unit_table.iterrows():
        datplt.make_unit_plots(self.root_dir, row['unit_name'], save_dir=save_dir)

    self.process_status['make_unit_plots'] = True
    self.save()
- def mark_dead_channels(self, dead_channels=None, shell=False)

Plots a small piece of the raw traces and a summary metric to help identify dead
channels. Once the user marks channels as dead, a new column is added to the
electrode mapping.

Parameters
----------
dead_channels : list of int, optional
    if this is specified then nothing is plotted, those channels are simply
    marked as dead
shell : bool, optional

Expand source code
@Logger('Marking Dead Channels')
def mark_dead_channels(self, dead_channels=None, shell=False):
    '''Plots small piece of raw traces and a metric to help identify dead
    channels. Once user marks channels as dead a new column is added to
    electrode mapping

    Parameters
    ----------
    dead_channels : list of int, optional
        if this is specified then nothing is plotted, those channels are
        simply marked as dead
    shell : bool, optional
    '''
    print('Marking dead channels\n----------')
    em = self.electrode_mapping.copy()
    if dead_channels is None:
        userIO.tell_user('Making traces figure for dead channel detection...',
                         shell=True)
        save_file = os.path.join(self.root_dir, 'Electrode_Traces.png')
        fig, ax = datplt.plot_traces_and_outliers(self.h5_file,
                                                  save_file=save_file)
        if not shell:
            # Better to open figure outside of python since it's a lot of
            # data on figure and matplotlib is slow
            subprocess.call(['xdg-open', save_file])
        else:
            userIO.tell_user('Saved figure of traces to %s for reference'
                             % save_file, shell=shell)

        choice = userIO.select_from_list('Select dead channels:',
                                         em.Electrode.to_list(),
                                         'Dead Channel Selection',
                                         multi_select=True,
                                         shell=shell)
        dead_channels = list(map(int, choice))

    print('Marking electrodes %s as dead.\n'
          'They will be excluded from common average referencing.'
          % dead_channels)
    em['dead'] = False
    em.loc[dead_channels, 'dead'] = True
    self.electrode_mapping = em
    if os.path.isfile(self.h5_file):
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

    self.process_status['mark_dead_channels'] = True
    self.save()
    return dead_channels
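Usage sketch showing either of the two modes (assumes `dat` from the initParams example; the electrode numbers are purely illustrative):

# Interactive mode: plots traces/outliers and asks which electrodes are dead
dead = dat.mark_dead_channels(shell=True)

# Non-interactive mode: skip plotting and mark known-bad electrodes directly
dead = dat.mark_dead_channels(dead_channels=[3, 17])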
- def palatability_calculate(self, shell=False)

Expand source code
@Logger('Calculating Palatability/Identity Metrics')
def palatability_calculate(self, shell=False):
    pal_analysis.palatability_identity_calculations(self.root_dir,
                                                    params=self.pal_id_params)
    self.process_status['palatability_calculate'] = True
    self.save()
- def palatability_plot(self, shell=False)

Expand source code
@Logger('Plotting Palatability/Identity Metrics')
def palatability_plot(self, shell=False):
    pal_plt.plot_palatability_identity([self.root_dir], shell=shell)
    self.process_status['palatability_plot'] = True
    self.save()
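Usage sketch chaining the two palatability steps (assumes `dat` from the initParams example and that spike and PSTH arrays already exist):

# Compute palatability/identity metrics with the pal_id_params set during
# initParams, then generate the corresponding plots
dat.palatability_calculate(shell=True)
dat.palatability_plot(shell=True)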
- def post_sorting(self)

Expand source code
def post_sorting(self):
    self.make_unit_plots()
    self.make_unit_arrays()
    self.units_similarity(shell=True)
    self.make_psth_arrays()
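post_sorting simply chains the four post-sorting steps documented on this page; the sketch below (assuming `dat` from the initParams example) is the equivalent manual sequence:

# Equivalent to dat.post_sorting()
dat.make_unit_plots()
dat.make_unit_arrays()
dat.units_similarity(shell=True)
dat.make_psth_arrays()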
- def pre_process_for_clustering(self, shell=False, dead_channels=None)

Expand source code
def pre_process_for_clustering(self, shell=False, dead_channels=None):
    status = self.process_status
    if not status['initialize parameters']:
        self.initParams(shell=shell)

    if not status['extract_data']:
        self.extract_data(shell=True)

    if not status['create_trial_list']:
        self.create_trial_list()

    if not status['mark_dead_channels'] and dead_channels != False:
        self.mark_dead_channels(dead_channels=dead_channels, shell=shell)

    if not status['common_average_reference']:
        self.common_average_reference()

    if not status['spike_detection']:
        self.detect_spikes()
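Usage sketch running the remaining pre-clustering steps non-interactively (assumes `dat` from the initParams example; the dead channel numbers are illustrative, and passing dead_channels=False skips dead-channel marking entirely):

# Runs any outstanding steps: initParams, extract_data, create_trial_list,
# mark_dead_channels, common_average_reference and spike detection
dat.pre_process_for_clustering(shell=True, dead_channels=[3, 17])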
- def set_electrode_areas(self, areas)

Sets the electrode area for each CAR group.

Parameters
----------
areas : list of str
    number of elements must match number of CAR groups

Throws
------
ValueError

Expand source code
@Logger('Re-labelling CAR group areas')
def set_electrode_areas(self, areas):
    '''sets the electrode area for each CAR group.

    Parameters
    ----------
    areas : list of str
        number of elements must match number of CAR groups

    Throws
    ------
    ValueError
    '''
    em = self.electrode_mapping.copy()
    if len(em['CAR_group'].unique()) != len(areas):
        raise ValueError('Number of items in areas must match number of CAR groups')

    em['areas'] = em['CAR_group'].apply(lambda x: areas[int(x)])
    self.electrode_mapping = em.copy()
    dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
    self.save()
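Usage sketch (assumes `dat` from the initParams example and two CAR groups, e.g. from the 'bilateral32' keyword; the area labels are illustrative):

# One label per CAR group; raises ValueError if the counts do not match
dat.set_electrode_areas(['right GC', 'left GC'])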
- def sort_spikes(self, electrode=None, shell=False)

Expand source code
def sort_spikes(self, electrode=None, shell=False):
    if electrode is None:
        electrode = userIO.get_user_input('Electrode #: ', shell=shell)
        if electrode is None or not electrode.isnumeric():
            return

        electrode = int(electrode)

    if not self.process_status['spike_clustering']:
        raise ValueError('Must run spike clustering first.')

    if not self.process_status['cleanup_clustering']:
        self.cleanup_clustering()

    sorter = clust.SpikeSorter(self.root_dir, electrode=electrode, shell=shell)
    if not shell:
        root, sorting_GUI = ssg.launch_sorter_GUI(sorter)
        return root, sorting_GUI
    else:
        # TODO: Make shell UI
        # TODO: Make sort by table
        print('No shell UI yet')
        return

    self.process_status['sort_units'] = True
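Usage sketch launching the sorting GUI for one electrode after clustering has finished (assumes `dat` from the initParams example; the electrode number is illustrative, and the returned references come from the GUI branch in the source above):

# Requires spike_clustering to be complete; cleanup_clustering runs
# automatically if it has not been done yet
root, sorter_gui = dat.sort_spikes(electrode=5, shell=False)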
- def units_similarity(self, similarity_cutoff=50, shell=False)

Expand source code
@Logger('Calculating Units Similarity')
def units_similarity(self, similarity_cutoff=50, shell=False):
    if 'SSH_CONNECTION' in os.environ:
        shell = True

    metrics_dir = os.path.join(self.root_dir, 'sorted_unit_metrics')
    if not os.path.isdir(metrics_dir):
        raise ValueError('No sorted unit metrics found. Must sort units before calculating similarity')

    violation_file = os.path.join(metrics_dir,
                                  'units_similarity_violations.txt')
    violations, sim = ss.calc_units_similarity(self.h5_file,
                                               self.sampling_rate,
                                               similarity_cutoff,
                                               violation_file)
    if len(violations) == 0:
        userIO.tell_user('No similarity violations found!', shell=shell)
        self.process_status['units_similarity'] = True
        return violations, sim

    out_str = ['Units Similarity Violations Found:']
    out_str.append('Unit_1 Unit_2 Similarity')
    for x, y in violations:
        u1 = dio.h5io.parse_unit_number(x)
        u2 = dio.h5io.parse_unit_number(y)
        out_str.append(' {:<10}{:<10}{}\n'.format(x, y, sim[u1][u2]))

    out_str.append('Delete units with dataset.delete_unit(N)')
    out_str = '\n'.join(out_str)
    userIO.tell_user(out_str, shell=shell)
    self.process_status['units_similarity'] = True
    self.save()
    return violations, sim
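Usage sketch checking for similarity violations and removing a duplicate unit (assumes `dat` from the initParams example; delete_unit is the call referenced in the violation message above, and the unit number is purely illustrative):

violations, sim = dat.units_similarity(similarity_cutoff=50, shell=True)

if violations:
    # Each violation is a pair of unit names; drop one of an over-similar pair
    dat.delete_unit(3)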