Module blechpy.datastructures.dataset

Expand source code
import pandas as pd
import datetime as dt
import pickle
import os
import shutil
import sys
import multiprocessing
import subprocess
from tqdm import tqdm
from copy import deepcopy
from blechpy.utils import print_tools as pt, write_tools as wt, userIO
from blechpy.utils.decorators import Logger
from blechpy.analysis import palatability_analysis as pal_analysis
from blechpy.analysis import spike_sorting as ss, spike_analysis, circus_interface as circ
from blechpy.analysis import blech_clustering as clust
from blechpy.plotting import palatability_plot as pal_plt, data_plot as datplt
from blechpy import dio
from blechpy.datastructures.objects import data_object
from blechpy.utils import spike_sorting_GUI as ssg


class dataset(data_object):
    '''Stores information related to an intan recording directory, allows
    executing basic processing and analysis scripts, and stores parameters data
    for those analyses

    Parameters
    ----------
    file_dir : str (optional)
        absolute path to a recording directory, if left empty a filechooser
        will popup
    '''
    PROCESSING_STEPS = ['initialize parameters',
                        'extract_data', 'create_trial_list',
                        'mark_dead_channels',
                        'common_average_reference', 'spike_detection',
                        'spike_clustering', 'cleanup_clustering',
                        'sort_units', 'make_unit_plots',
                        'units_similarity', 'make_unit_arrays',
                        'make_psth_arrays', 'plot_psths',
                        'palatability_calculate', 'palatability_plot',
                        'overlay_psth']

    def __init__(self, file_dir=None, data_name=None, shell=False):
        '''Initialize dataset object from file_dir, grabs basename from name of
        directory and initializes basic analysis parameters

        Parameters
        ----------
        file_dir : str (optional), file directory for intan recording data

        Throws
        ------
        ValueError
            if file_dir is not provided and no directory is chosen
            when prompted
        NotADirectoryError : if file_dir does not exist
        '''
        super().__init__('dataset', file_dir, data_name=data_name, shell=shell)
        h5_file = dio.h5io.get_h5_filename(self.root_dir)
        if h5_file is None:
            h5_file = os.path.join(self.root_dir, '%s.h5' % self.data_name)

        self.h5_file = h5_file

        self.dataset_creation_date = dt.datetime.today()

        # Outline standard processing pipeline and status check
        self.processing_steps = dataset.PROCESSING_STEPS.copy()
        self.process_status = dict.fromkeys(self.processing_steps, False)

    def _change_root(self, new_root=None):
        old_root = self.root_dir
        new_root = super()._change_root(new_root)
        self.h5_file = self.h5_file.replace(old_root, new_root)
        return new_root

    @Logger('Initializing Parameters')
    def initParams(self, data_quality='clean', emg_port=None,
                   emg_channels=None, car_keyword=None,
                   car_group_areas=None,
                   shell=False, dig_in_names=None,
                   dig_out_names=None, accept_params=False):
        '''
        Initalizes basic default analysis parameters and allows customization
        of parameters

        Parameters (all optional)
        -------------------------
        data_quality : {'clean', 'noisy'}
            keyword defining which default set of parameters to use to detect
            headstage disconnection during clustering
            default is 'clean'. Best practice is to run blech_clust as 'clean'
            and re-run as 'noisy' if too many early cutoffs occurr
        emg_port : str
            Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default)
            will query user. False indicates no EMG port and not to query user
        emg_channels : list of int
            channel or channels of EMGs on port specified
            default is None
        car_keyword : str
            Specifes default common average reference groups
            defaults are found in CAR_defaults.json
            Currently 'bilateral32' is only keyword available
            If left as None (default) user will be queries to select common
            average reference groups
        shell : bool
            False (default) for GUI. True for command-line interface
        dig_in_names : list of str
            Names of digital inputs. Must match number of digital inputs used
            in recording.
            None (default) queries user to name each dig_in
        dig_out_names : list of str
            Names of digital outputs. Must match number of digital outputs in
            recording.
            None (default) queries user to name each dig_out
        accept_params : bool
            True automatically accepts default parameters where possible,
            decreasing user queries
            False (default) will query user to confirm or edit parameters for
            clustering, spike array and psth creation and palatability/identity
            calculations
        '''
        # Get parameters from info.rhd
        file_dir = self.root_dir
        rec_info = dio.rawIO.read_rec_info(file_dir)
        ports = rec_info.pop('ports')
        channels = rec_info.pop('channels')
        sampling_rate = rec_info['amplifier_sampling_rate']
        self.rec_info = rec_info
        self.sampling_rate = sampling_rate

        # Get default parameters from files
        clustering_params = dio.params.load_params('clustering_params', file_dir,
                                                   default_keyword=data_quality)
        spike_array_params = dio.params.load_params('spike_array_params', file_dir)
        psth_params = dio.params.load_params('psth_params', file_dir)
        pal_id_params = dio.params.load_params('pal_id_params', file_dir)
        spike_array_params['sampling_rate'] = sampling_rate
        clustering_params['file_dir'] = file_dir
        clustering_params['sampling_rate'] = sampling_rate

        # Setup digital input mapping
        if rec_info.get('dig_in'):
            self._setup_digital_mapping('in', dig_in_names, shell)
            dim = self.dig_in_mapping.copy()
            spike_array_params['laser_channels'] = dim.channel[dim['laser']].to_list()
            spike_array_params['dig_ins_to_use'] = dim.channel[dim['spike_array']].to_list()
        else:
            self.dig_in_mapping = None

        if rec_info.get('dig_out'):
            self._setup_digital_mapping('out', dig_out_names, shell)
            dom = self.dig_out_mapping.copy()
        else:
            self.dig_out_mapping = None

        # Setup electrode and emg mapping
        self._setup_channel_mapping(ports, channels, emg_port,
                                    emg_channels, shell=shell)

        # Set CAR groups
        self._set_CAR_groups(group_keyword=car_keyword, group_areas=car_group_areas, shell=shell)

        # Confirm parameters
        self.spike_array_params = spike_array_params
        if not accept_params:
            conf = userIO.confirm_parameter_dict
            clustering_params = conf(clustering_params,
                                     'Clustering Parameters', shell=shell)
            self.edit_spike_array_params(shell=shell)
            psth_params = conf(psth_params,
                               'PSTH Parameters', shell=shell)
            pal_id_params = conf(pal_id_params,
                                 'Palatability/Identity Parameters\n'
                                 'Valid unit_type is Single, Multi or All',
                                 shell=shell)

        # Store parameters
        self.clustering_params = clustering_params
        self.pal_id_params = pal_id_params
        self.psth_params = psth_params
        self._write_all_params_to_json()
        self.process_status['initialize parameters'] = True
        self.save()

    def _set_CAR_groups(self, group_keyword=None, shell=False, group_areas=None):
        '''Sets that electrode groups for common average referencing and
        defines which brain region electrodes eneded up in

        Parameters
        ----------
        group_keyword : str or int
            Keyword corresponding to a preset electrode grouping in CAR_params.json
            Or integer indicating number of CAR groups
        shell : bool
            True for command-line interface, False (default) for GUI
        '''
        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('Set electrode mapping before setting CAR groups')

        em = self.electrode_mapping.copy()

        car_param_file = os.path.join(self.root_dir, 'analysis_params',
                                      'CAR_params.json')
        if os.path.isfile(car_param_file):
            tmp = dio.params.load_params('CAR_params', self.root_dir)
            if tmp is not None:
                group_electrodes = tmp
            else:
                raise ValueError('CAR_params file exists in recording dir, but is empty')

        else:
            if group_keyword is None:
                group_keyword = userIO.get_user_input(
                    'Input keyword for CAR parameters or number of CAR groups',
                    shell=shell)

                if group_keyword is None:
                    ValueError('Must provide a keyword or number of groups')

            if group_keyword.isnumeric():
                num_groups = int(group_keyword)
                group_electrodes = dio.params.select_CAR_groups(num_groups, em,
                                                                shell=shell)
            else:
                group_electrodes = dio.params.load_params('CAR_params',
                                                          self.root_dir,
                                                          default_keyword=group_keyword)

        num_groups = len(group_electrodes)
        if group_areas is not None and len(group_areas) == num_groups:
            for i, x in enumerate(zip(group_electrodes, group_areas)):
                em.loc[x[0], 'area'] = x[1]
                em.loc[x[0], 'CAR_group'] = i

        else:
            group_names = ['Group %i' % i for i in range(num_groups)]
            area_dict = dict.fromkeys(group_names, '')
            area_dict = userIO.fill_dict(area_dict, 'Set Areas for CAR groups',
                                         shell=shell)
            for k, v in area_dict.items():
                i = int(k.replace('Group', ''))
                em.loc[group_electrodes[i], 'area'] = v
                em.loc[group_electrodes[i], 'CAR_group'] = i

        self.CAR_electrodes = group_electrodes
        self.electrode_mapping = em.copy()

    @Logger('Re-labelling CAR group areas')
    def set_electrode_areas(self, areas):
        '''sets the electrode area for each CAR group.

        Parameters
        ----------
        areas : list of str
            number of elements must match number of CAR groups

        Throws
        ------
        ValueError
        '''
        em = self.electrode_mapping.copy()
        if len(em['CAR_group'].unique()) != len(areas):
            raise ValueError('Number of items in areas must match number of CAR groups')

        em['areas'] = em['CAR_group'].apply(lambda x: areas[int(x)])
        self.electrode_mapping = em.copy()
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        self.save()

    def _setup_digital_mapping(self, dig_type, dig_in_names=None, shell=False):
        '''sets up dig_in_mapping dataframe  and queries user to fill in columns

        Parameters
        ----------
        dig_in_names : list of str (optional)
        shell : bool (optional)
            True for command-line interface
            False (default) for GUI
        '''
        rec_info = self.rec_info
        df = pd.DataFrame()
        df['channel'] = rec_info.get('dig_%s' % dig_type)
        n_dig_in = len(df)
        # Names
        if dig_in_names:
            df['name'] = dig_in_names
        else:
            df['name'] = ''

        # Parameters to query
        if dig_type == 'in':
            df['palatability_rank'] = 0
            df['laser'] = False
            df['spike_array'] = True

        df['exclude'] = False
        # Re-format for query
        idx = df.index
        df.index = ['dig_%s_%i' % (dig_type, x) for x in df.channel]
        dig_str = dig_type + 'put'
        # Query for user input
        prompt = ('Digital %s Parameters\nSet palatability ranks from 1 to %i'
                  '\nor blank to exclude from pal_id analysis') % (dig_str, len(df))
        tmp = userIO.fill_dict(df.to_dict(), prompt=prompt, shell=shell)
        # Reformat for storage
        df2 = pd.DataFrame.from_dict(tmp)
        df2 = df2.sort_values(by=['channel'])
        df2.index = idx
        if dig_type == 'in':
            df2['palatability_rank'] = df2['palatability_rank'].fillna(-1).astype('int')

        if dig_type == 'in':
            self.dig_in_mapping = df2.copy()
        else:
            self.dig_out_mapping = df2.copy()

        if os.path.isfile(self.h5_file):
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, dig_type)

    def _setup_channel_mapping(self, ports, channels, emg_port, emg_channels, shell=False):
        '''Creates electrode_mapping and emg_mapping DataFrames with columns:
        - Electrode
        - Port
        - Channel

        Parameters
        ----------
        ports : list of str, item corresponing to each channel
        channels : list of int, channels on each port
        emg_port : str
        emg_channels : list of int
        '''
        if emg_port is None:
            q = userIO.ask_user('Do you have an EMG?', shell=shell)
            if q==1:
                emg_port = userIO.select_from_list('Select EMG Port:',
                                                   ports, 'EMG Port',
                                                   shell=shell)
                emg_channels = userIO.select_from_list(
                    'Select EMG Channels:',
                    [y for x, y in
                     zip(ports, channels)
                     if x == emg_port],
                    title='EMG Channels',
                    multi_select=True, shell=shell)

        el_map, em_map = dio.params.flatten_channels(ports, channels,
                                                     emg_port, emg_channels)
        self.electrode_mapping = el_map
        self.emg_mapping = em_map
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

    def edit_spike_array_params(self, shell=False):
        '''Edit spike array parameters and adjust dig_in_mapping accordingly

        Parameters
        ----------
        shell : bool, whether to use CLI or GUI
        '''
        if not hasattr(self, 'dig_in_mapping'):
            self.spike_array_params = None
            return

        sa = deepcopy(self.spike_array_params)
        tmp = userIO.fill_dict(sa, 'Spike Array Parameters\n(Times in ms)',
                               shell=shell)
        if tmp is None:
            return

        dim = self.dig_in_mapping
        dim['spike_array'] = False
        if tmp['dig_ins_to_use'] != ['']:
            tmp['dig_ins_to_use'] = [int(x) for x in tmp['dig_ins_to_use']]
            dim.loc[[x in tmp['dig_ins_to_use'] for x in dim.channel],
                    'spike_array'] = True

        dim['laser_channels'] = False
        if tmp['laser_channels'] != ['']:
            tmp['laser_channels'] = [int(x) for x in tmp['laser_channels']]
            dim.loc[[x in tmp['laser_channels'] for x in dim.channel],
                    'laser'] = True

        self.spike_array_params = tmp.copy()
        wt.write_params_to_json('spike_array_params',
                                self.root_dir, tmp)

    def edit_clustering_params(self, shell=False):
        '''Allows user interface for editing clustering parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.clustering_params,
                               'Clustering Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.clustering_params = tmp
            wt.write_params_to_json('clustering_params', self.root_dir, tmp)

    def edit_psth_params(self, shell=False):
        '''Allows user interface for editing psth parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.psth_params,
                               'PSTH Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.psth_params = tmp
            wt.write_params_to_json('psth_params', self.root_dir, tmp)

    def edit_pal_id_params(self, shell=False):
        '''Allows user interface for editing palatability/identity parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.pal_id_params,
                               'Palatability/Identity Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.pal_id_params = tmp
            wt.write_params_to_json('pal_id_params', self.root_dir, tmp)

    def __str__(self):
        '''Put all information about dataset in string format

        Returns
        -------
        str : representation of dataset object
        '''
        out1 = super().__str__()
        out = [out1]
        out.append('\nObject creation date: '
                   + self.dataset_creation_date.strftime('%m/%d/%y'))

        if hasattr(self, 'raw_h5_file'):
            out.append('Deleted Raw h5 file: '+self.raw_h5_file)

        out.append('h5 File: '+self.h5_file)
        out.append('')

        out.append('--------------------')
        out.append('Processing Status')
        out.append('--------------------')
        out.append(pt.print_dict(self.process_status))
        out.append('')

        if not hasattr(self, 'rec_info'):
            return '\n'.join(out)

        info = self.rec_info

        out.append('--------------------')
        out.append('Recording Info')
        out.append('--------------------')
        out.append(pt.print_dict(self.rec_info))
        out.append('')

        out.append('--------------------')
        out.append('Electrodes')
        out.append('--------------------')
        out.append(pt.print_dataframe(self.electrode_mapping))
        out.append('')

        if hasattr(self, 'CAR_electrodes'):
            out.append('--------------------')
            out.append('CAR Groups')
            out.append('--------------------')
            headers = ['Group %i' % x for x in range(len(self.CAR_electrodes))]
            out.append(pt.print_list_table(self.CAR_electrodes, headers))
            out.append('')

        if not self.emg_mapping.empty:
            out.append('--------------------')
            out.append('EMG')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.emg_mapping))
            out.append('')

        if info.get('dig_in'):
            out.append('--------------------')
            out.append('Digital Input')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_in_mapping))
            out.append('')

        if info.get('dig_out'):
            out.append('--------------------')
            out.append('Digital Output')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_out_mapping))
            out.append('')

        out.append('--------------------')
        out.append('Clustering Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.clustering_params))
        out.append('')

        out.append('--------------------')
        out.append('Spike Array Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.spike_array_params))
        out.append('')

        out.append('--------------------')
        out.append('PSTH Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.psth_params))
        out.append('')

        out.append('--------------------')
        out.append('Palatability/Identity Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.pal_id_params))
        out.append('')

        return '\n'.join(out)

    @Logger('Writing parameters to JSON')
    def _write_all_params_to_json(self):
        '''Writes all parameters to json files in analysis_params folder in the
        recording directory
        '''
        print('Writing all parameters to json file in analysis_params folder...')
        clustering_params = self.clustering_params
        spike_array_params = self.spike_array_params
        psth_params = self.psth_params
        pal_id_params = self.pal_id_params
        CAR_params = self.CAR_electrodes

        rec_dir = self.root_dir
        wt.write_params_to_json('clustering_params', rec_dir, clustering_params)
        wt.write_params_to_json('spike_array_params', rec_dir, spike_array_params)
        wt.write_params_to_json('psth_params', rec_dir, psth_params)
        wt.write_params_to_json('pal_id_params', rec_dir, pal_id_params)
        wt.write_params_to_json('CAR_params', rec_dir, CAR_params)

    @Logger('Extracting Data')
    def extract_data(self, filename=None, shell=False):
        '''Create hdf5 store for data and read in Intan .dat files. Also create
        subfolders for processing outputs

        Parameters
        ----------
        data_quality: {'clean', 'noisy'} (optional)
            Specifies quality of data for default clustering parameters
            associated. Should generally first process with clean (default)
            parameters and then try noisy after running blech_clust and
            checking if too many electrodes as cutoff too early
        '''
        if self.rec_info['file_type'] is None:
            raise ValueError('Unsupported recording type. Cannot extract yet.')

        if filename is None:
            filename = self.h5_file

        print('\nExtract Intan Data\n--------------------')
        # Create h5 file
        tmp = dio.h5io.create_empty_data_h5(filename, shell)
        if tmp is None:
            return

        # Create arrays for raw data in hdf5 store
        dio.h5io.create_hdf_arrays(filename, self.rec_info,
                                   self.electrode_mapping, self.emg_mapping)

        # Read in data to arrays
        dio.h5io.read_files_into_arrays(filename,
                                        self.rec_info,
                                        self.electrode_mapping,
                                        self.emg_mapping)

        # Write electrode and digital input mapping into h5 file
        # TODO: write EMG and digital output mapping into h5 file
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        if self.dig_in_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'in')

        if self.dig_out_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'out')

        # update status
        self.h5_file = filename
        self.process_status['extract_data'] = True
        self.save()
        print('\nData Extraction Complete\n--------------------')

    @Logger('Creating Trial List')
    def create_trial_list(self):
        '''Create lists of trials based on digital inputs and outputs and store
        to hdf5 store
        Can only be run after data extraction
        '''
        if self.rec_info.get('dig_in'):
            in_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_in_mapping,
                self.sampling_rate,
                'in')
            self.dig_in_trials = in_list
        else:
            print('No digital input data found')

        if self.rec_info.get('dig_out'):
            out_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_out_mapping,
                self.sampling_rate,
                'out')
            self.dig_out_trials = out_list
        else:
            print('No digital output data found')

        self.process_status['create_trial_list'] = True
        self.save()

    @Logger('Marking Dead Channels')
    def mark_dead_channels(self, dead_channels=None, shell=False):
        '''Plots small piece of raw traces and a metric to help identify dead
        channels. Once user marks channels as dead a new column is added to
        electrode mapping

        Parameters
        ----------
        dead_channels : list of int, optional
            if this is specified then nothing is plotted, those channels are
            simply marked as dead
        shell : bool, optional
        '''
        print('Marking dead channels\n----------')
        em = self.electrode_mapping.copy()
        if dead_channels is None:
            userIO.tell_user('Making traces figure for dead channel detection...',
                             shell=True)
            save_file = os.path.join(self.root_dir, 'Electrode_Traces.png')
            fig, ax = datplt.plot_traces_and_outliers(self.h5_file, save_file=save_file)
            if not shell:
                # Better to open figure outside of python since its a lot of
                # data on figure and matplotlib is slow
                subprocess.call(['xdg-open', save_file])
            else:
                userIO.tell_user('Saved figure of traces to %s for reference'
                                 % save_file, shell=shell)

            choice = userIO.select_from_list('Select dead channels:',
                                             em.Electrode.to_list(),
                                             'Dead Channel Selection',
                                             multi_select=True,
                                             shell=shell)
            dead_channels = list(map(int, choice))

        print('Marking eletrodes %s as dead.\n'
              'They will be excluded from common average referencing.'
              % dead_channels)
        em['dead'] = False
        em.loc[dead_channels, 'dead'] = True
        self.electrode_mapping = em
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

        self.process_status['mark_dead_channels'] = True
        self.save()
        return dead_channels

    @Logger('Common Average Referencing')
    def common_average_reference(self):
        '''Define electrode groups and remove common average from  signals

        Parameters
        ----------
        num_groups : int (optional)
            number of CAR groups, if not provided
            there's a prompt
        '''
        if not hasattr(self, 'CAR_electrodes'):
            raise ValueError('CAR_electrodes not set')

        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('electrode_mapping not set')

        car_electrodes = self.CAR_electrodes
        num_groups = len(car_electrodes)
        em = self.electrode_mapping.copy()

        if 'dead' in em.columns:
            dead_electrodes = em.Electrode[em.dead].to_list()
        else:
            dead_electrodes = []

        # Gather Common Average Reference Groups
        print('CAR Groups\n')
        headers = ['Group %i' % x for x in range(num_groups)]
        print(pt.print_list_table(car_electrodes, headers))

        # Reference each group
        for i, x in enumerate(car_electrodes):
            tmp = list(set(x) - set(dead_electrodes))
            dio.h5io.common_avg_reference(self.h5_file, tmp, i)

        # Compress and repack file
        dio.h5io.compress_and_repack(self.h5_file)

        self.process_status['common_average_reference'] = True
        self.save()

    @Logger('Running Spike Detection')
    def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None):
        '''Run spike detection on each electrode. Prepares for clustering with
        BlechClust. Works for both single recording clustering or
        multi-recording clustering

        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        n_cores : int (optional)
            number of cores to use for parallel processing. default is max-1.
        '''
        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')

        print('\nRunning Spike Detection\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))

        # Create folders for saving things within recording dir
        data_dir = self.root_dir

        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()


        pbar = tqdm(total = len(electrodes))
        results = [(None, None, None)] * (max(electrodes)+1)
        def update_pbar(ans):
            if isinstance(ans, tuple) and ans[0] is not None:
                results[ans[0]] = ans
            else:
                print('Unexpected error when clustering an electrode')

            pbar.update()

        spike_detectors = [clust.SpikeDetection(data_dir, x,
                                                self.clustering_params)
                           for x in electrodes]
        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1

            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for sd in spike_detectors:
                pool.apply_async(sd.run, callback=update_pbar)

            pool.close()
            pool.join()
        else:
            for sd in spike_detectors:
                res = sd.run()
                update_pbar(res)

        pbar.close()

        print('Electrode    Result    Cutoff (s)')
        cutoffs = {}
        clust_res = {}
        clustered = []
        for x, y, z in results:
            if x is None:
                continue

            clustered.append(x)
            print('  {:<13}{:<10}{}'.format(x, y, z))
            cutoffs[x] = z
            clust_res[x] = y

        print('1 - Sucess\n0 - No data or no spikes\n-1 - Error')

        em = self.electrode_mapping.copy()
        em['cutoff_time'] = em['Electrode'].map(cutoffs)
        em['clustering_result'] = em['Electrode'].map(clust_res)
        self.electrode_mapping = em.copy()
        self.process_status['spike_detection'] = True
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Spike Detection Complete\n------------------')
        return results

    @Logger('Running Blech Clust')
    def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False):
        '''Write clustering parameters to file and
        Run blech_process on each electrode using GNU parallel

        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        accept_params : bool, False (default)
            set to True in order to skip popup confirmation of parameters when
            running
        '''
        if self.process_status['spike_detection'] == False:
            raise FileNotFoundError('Must run spike detection before clustering.')

        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')

        print('\nRunning Blech Clust\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))

        # Get electrodes, throw out 'dead' electrodes
        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()


        pbar = tqdm(total = len(electrodes))
        def update_pbar(ans):
            pbar.update()

        errors = []
        def error_call(e):
            errors.append(e)

        if not umap:
            clust_objs = [clust.BlechClust(self.root_dir, x, params=self.clustering_params)
                          for x in electrodes]
        else:
            clust_objs = [clust.BlechClust(self.root_dir, x,
                                           params=self.clustering_params,
                                           data_transform=clust.UMAP_METRICS,
                                           n_pc=5)
                          for x in electrodes]

        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1

            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for x in clust_objs:
                pool.apply_async(x.run, callback=update_pbar, error_callback=error_call)

            pool.close()
            pool.join()
        else:
            for x in clust_objs:
                res = x.run()
                update_pbar(res)

        pbar.close()

        self.process_status['spike_clustering'] = True
        self.process_status['cleanup_clustering'] = False
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Clustering Complete\n------------------')
        if len(errors) > 0:
            print('Errors encountered:')
            print(errors)

    @Logger('Cleaning up clustering memory logs. Removing raw data and setting'
            'up hdf5 for unit sorting')
    def cleanup_clustering(self):
        '''Consolidates memory monitor files, removes raw and referenced data
        and setups up hdf5 store for sorted units data
        '''
        if self.process_status['cleanup_clustering']:
            return

        h5_file = dio.h5io.cleanup_clustering(self.root_dir)
        self.h5_file = h5_file
        self.process_status['cleanup_clustering'] = True
        self.save()

    def sort_spikes(self, electrode=None, shell=False):
        if electrode is None:
            electrode = userIO.get_user_input('Electrode #: ', shell=shell)
            if electrode is None or not electrode.isnumeric():
                return

            electrode = int(electrode)

        if not self.process_status['spike_clustering']:
            raise ValueError('Must run spike clustering first.')

        if not self.process_status['cleanup_clustering']:
            self.cleanup_clustering()

        sorter = clust.SpikeSorter(self.root_dir, electrode=electrode, shell=shell)
        if not shell:
            root, sorting_GUI = ssg.launch_sorter_GUI(sorter)
            return root, sorting_GUI
        else:
            # TODO: Make shell UI
            # TODO: Make sort by table
            print('No shell UI yet')
            return

        self.process_status['sort_units'] = True

    @Logger('Calculating Units Similarity')
    def units_similarity(self, similarity_cutoff=50, shell=False):
        if 'SSH_CONNECTION' in os.environ:
            shell= True

        metrics_dir = os.path.join(self.root_dir, 'sorted_unit_metrics')
        if not os.path.isdir(metrics_dir):
            raise ValueError('No sorted unit metrics found. Must sort units before calculating similarity')

        violation_file = os.path.join(metrics_dir,
                                      'units_similarity_violations.txt')
        violations, sim = ss.calc_units_similarity(self.h5_file,
                                                   self.sampling_rate,
                                                   similarity_cutoff,
                                                   violation_file)
        if len(violations) == 0:
            userIO.tell_user('No similarity violations found!', shell=shell)
            self.process_status['units_similarity'] = True
            return violations, sim

        out_str = ['Units Similarity Violations Found:']
        out_str.append('Unit_1    Unit_2    Similarity')
        for x,y in violations:
            u1 = dio.h5io.parse_unit_number(x)
            u2 = dio.h5io.parse_unit_number(y)
            out_str.append('   {:<10}{:<10}{}\n'.format(x, y, sim[u1][u2]))

        out_str.append('Delete units with dataset.delete_unit(N)')
        out_str = '\n'.join(out_str)
        userIO.tell_user(out_str, shell=shell)
        self.process_status['units_similarity'] = True
        self.save()
        return violations, sim

    @Logger('Deleting Unit')
    def delete_unit(self, unit_num, confirm=False, shell=False):
        if isinstance(unit_num, str):
            unit_num = dio.h5io.parse_unit_number(unit_num)

        if unit_num is None:
            print('No unit deleted')
            return

        if not confirm:
            q = userIO.ask_user('Are you sure you want to delete unit%03i?' % unit_num,
                                choices = ['No','Yes'], shell=shell)
        else:
            q = 1

        if q == 0:
            print('No unit deleted')
            return
        else:
            tmp = dio.h5io.delete_unit(self.root_dir, unit_num)
            if tmp is False:
                userIO.tell_user('Unit %i not found in dataset. No unit deleted'
                                 % unit_num, shell=shell)
            else:
                userIO.tell_user('Unit %i sucessfully deleted.' % unit_num,
                                 shell=shell)

        self.save()

    @Logger('Making Unit Arrays')
    def make_unit_arrays(self):
        '''Make spike arrays for each unit and store in hdf5 store
        '''
        params = self.spike_array_params

        print('Generating unit arrays with parameters:\n----------')
        print(pt.print_dict(params, tabs=1))
        ss.make_spike_arrays(self.h5_file, params)
        self.process_status['make_unit_arrays'] = True
        self.save()

    @Logger('Making Unit Plots')
    def make_unit_plots(self):
        '''Make waveform plots for each sorted unit
        '''
        unit_table = self.get_unit_table()
        save_dir = os.path.join(self.root_dir, 'unit_waveforms_plots')
        if os.path.isdir(save_dir):
            shutil.rmtree(save_dir)

        os.mkdir(save_dir)
        for i, row in unit_table.iterrows():
            datplt.make_unit_plots(self.root_dir, row['unit_name'], save_dir=save_dir)

        self.process_status['make_unit_plots'] = True
        self.save()

    @Logger('Making PSTH Arrays')
    def make_psth_arrays(self):
        '''Make smoothed firing rate traces for each unit/trial and store in
        hdf5 store
        '''
        params = self.psth_params
        dig_ins = self.dig_in_mapping.query('spike_array == True')
        for idx, row in dig_ins.iterrows():
            spike_analysis.make_psths_for_tastant(self.h5_file,
                                                  params['window_size'],
                                                  params['window_step'],
                                                  row['channel'])

        self.process_status['make_psth_arrays'] = True
        self.save()

    @Logger('Calculating Palatability/Identity Metrics')
    def palatability_calculate(self, shell=False):
        pal_analysis.palatability_identity_calculations(self.root_dir,
                                                        params=self.pal_id_params)
        self.process_status['palatability_calculate'] = True
        self.save()

    @Logger('Plotting Palatability/Identity Metrics')
    def palatability_plot(self, shell=False):
        pal_plt.plot_palatability_identity([self.root_dir], shell=shell)
        self.process_status['palatability_plot'] = True
        self.save()

    @Logger('Removing low-spiking units')
    def cleanup_lowSpiking_units(self, min_spikes=100):
        unit_table = self.get_unit_table()
        remove = []
        spike_count = []
        for unit in unit_table['unit_num']:
            waves, descrip, fs = dio.h5io.get_unit_waveforms(self.root_dir, unit)
            if waves.shape[0] < min_spikes:
                spike_count.append(waves.shape[0])
                remove.append(unit)

        for unit, count in zip(reversed(remove), reversed(spike_count)):
            print('Removing unit %i. Only %i spikes.' % (unit, count))
            userIO.tell_user('Removing unit %i. Only %i spikes.'
                             % (unit, count), shell=True)
            self.delete_unit(unit, confirm=True, shell=True)

        userIO.tell_user('Removed %i units for having less than %i spikes.'
                         % (len(remove), min_spikes), shell=True)

    def get_unit_table(self):
        '''Returns a pandas dataframe with sorted unit information

        Returns
        --------
        pandas.DataFrame with columns:
            unit_name, unit_num, electrode, single_unit,
            regular_spiking, fast_spiking
        '''
        unit_table = dio.h5io.get_unit_table(self.root_dir)
        return unit_table

    def circus_clust_run(self, shell=False):
        circ.prep_for_circus(self.root_dir, self.electrode_mapping,
                             self.data_name, self.sampling_rate)
        circ.start_the_show()

    def pre_process_for_clustering(self, shell=False, dead_channels=None):
        status = self.process_status
        if not status['initialize parameters']:
            self.initParams(shell=shell)

        if not status['extract_data']:
            self.extract_data(shell=True)

        if not status['create_trial_list']:
            self.create_trial_list()

        if not status['mark_dead_channels'] and dead_channels != False:
            self.mark_dead_channels(dead_channels=dead_channels, shell=shell)

        if not status['common_average_reference']:
            self.common_average_reference()

        if not status['spike_detection']:
            self.detect_spikes()

    def extract_and_circus_cluster(self, dead_channels=None, shell=True):
        print('Extracting Data...')
        self.extract_data()
        print('Marking dead channels...')
        self.mark_dead_channels(dead_channels, shell=shell)
        print('Common average referencing...')
        self.common_average_reference()
        print('Initiating circus clustering...')
        circus = circ.circus_clust(self.root_dir, self.data_name,
                                   self.sampling_rate, self.electrode_mapping)
        print('Preparing for circus...')
        circus.prep_for_circus()
        print('Starting circus clustering...')
        circus.start_the_show()
        print('Plotting cluster waveforms...')
        circus.plot_cluster_waveforms()

    def post_sorting(self):
        self.make_unit_plots()
        self.make_unit_arrays()
        self.units_similarity(shell=True)
        self.make_psth_arrays()


def port_in_dataset(rec_dir=None, shell=False):
    '''Import an existing dataset into this framework
    '''
    if rec_dir is None:
        rec_dir = userIO.get_filedirs('Select recording directory', shell=shell)
        if rec_dir is None:
            return None

    dat = dataset(rec_dir, shell=shell)
    # Check files that will be overwritten: log_file, save_file
    if os.path.isfile(dat.save_file):
        prompt = '%s already exists. Continuing will overwrite this. Continue?' % dat.save_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None

    # if os.path.isfile(dat.h5_file):
    #     prompt = '%s already exists. Continuinlg will overwrite this. Continue?' % dat.h5_file
    #     q = userIO.ask_user(prompt, shell=shell)
    #     if q == 0:
    #         print('Aborted')
    #         return None

    if os.path.isfile(dat.log_file):
        prompt = '%s already exists. Continuing will append to this. Continue?' % dat.log_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None

    with open(dat.log_file, 'a') as f:
        print('\n==========\nPorting dataset into blechpy format\n==========\n', file=f)
        print(dat, file=f)

    status = dat.process_status
    dat.initParams(shell=shell)


    user_status = status.copy()
    user_status = userIO.fill_dict(user_status,
                                   'Which processes have already been '
                                   'done to the data?', shell=shell)

    status.update(user_status)
    # if h5 exists data must have been extracted

    if not os.path.isfile(dat.h5_file) or status['extract_data'] == False:
        dat.save()
        return dat

    # write eletrode map and digital input & output maps to hf5
    dio.h5io.write_electrode_map_to_h5(dat.h5_file, dat.electrode_mapping)
    if dat.rec_info.get('dig_in') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_in_mapping, 'in')

    if dat.rec_info.get('dig_out') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_out_mapping, 'out')


    node_list = dio.h5io.get_node_list(dat.h5_file)

    if (status['create_trial_list'] == False) and ('digital_in' in node_list):
        dat.create_trial_list()

    dat.save()
    return dat


def validate_data_integrity(rec_dir, verbose=False):
    '''incomplete
    '''
    # TODO: Finish this
    print('Raw Data Validation\n' + '-'*19)
    test_names = ['file_type', 'recording_info', 'files', 'dropped_packets', 'data_length']
    number_names = ['sample_rate', 'dropped_packets', 'missing_files', 'recording_length']
    tests = dict.fromkeys(test_names, 'NOT TESTED')
    numbers = dict.fromkeys(number_names, -1)
    file_type = dio.rawIO.get_recording_filetype(rec_dir)
    if file_type is None:
        file_type_check = 'UNSUPPORTED'
    else:
        tests['file_type'] = 'PASS'

    # Check info.rhd integrity
    info_file = os.path.join(rec_dir, 'info.rhd')
    try:
        rec_info = dio.rawIO.read_rec_info(rec_dir, shell=True)
        with open(info_file, 'rb') as f:
            info = dio.load_intan_rhd_format.read_header(f)

        tests['recording_info'] = 'PASS'
    except FileNotFoundError:
        test['recording_info'] = 'MISSING'
    except Exception as e:
        info_size = os.path.getsize(os.path.join(rec_dir, 'info.rhd'))
        if info_size == 0:
            tests['recording_info'] = 'EMPTY'
        else:
            tests['recording_info'] = 'FAIL'

        print(pt.print_dict(tests, tabs=1))
        return tests, numbers

    counts = {x : info(x) for x in info.keys() if 'num' in x}
    numbers.update(counts)
    fs = info['sample_rate']
    # Check all files needed are present
    files_expected = ['time.dat']
    if file_type == 'one file per signal type':
        files_expected.append('amplifier.dat')
        if rec_info.get('dig_in') is not None:
            files_expected.append('digitalin.dat')

        if rec_info.get('dig_out') is not None:
            files_expected.append('digitalout.dat')

        if info['num_auxilary_input_channels'] > 0:
            files_expected.append('auxiliary.dat')

    elif file_type == 'one file per channel':
        for x in info['amplifier_channels']:
            files_expected.append('amp-' + x['native_channel_name'] + '.dat')

        for x in info['board_dig_in_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['board_dig_out_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['aux_input_channels']:
            files_expected.append('aux-%s.dat' % x['native_channel_name'])


    missing_files = []
    file_list = os.listdir(rec_dir)
    for x in file_expected:
        if x not in file_list:
            missing_file.append(x)

    if len(missing_files) == 0:
        tests['files'] = 'PASS'
    else:
        tests['files'] = 'MISSING'
        numbers['missing_files'] = missing_files

    # Check time data for dropped packets
    time = dio.rawIO.read_time_dat(rec_dir, sampling_rate=1)  # get raw timestamps
    numbers['n_samples'] = len(time)
    numbers['recording_length'] = float(time[-1])/fs
    expected_time = np.arange(time[0], time[-1]+1, 1)
    missing_timestamps = np.setdiff1d(expected_time, time)
    missing_times = np.array([float(x)/fs for x in missing_timestamps])
    if len(missing_timestamps) == 0:
        tests['dropped_packets'] = 'PASS'
    else:
        tests['dropped_packets'] = '%i' % len(missing_timestamps)
        numbers['dropped_packets'] = missing_times

    # Check recording length of each trace
    tests['data_traces'] = 'FAIL'
    if file_type == 'one file per signal type':
        try:
            data = dio.rawIO.read_amplifier_dat(rec_dir)
            if data is None:
                tests['data_traces'] = 'UNREADABLE'
            elif data.shape[0] == numbers['n_samples']:
                tests['data_traces'] = 'PASS'
            else:
                tests['data_traces'] = 'CUTOFF'
                numbers['data_trace_length (s)'] = data.shape[0]/fs

        except:
            tests['data_traces'] = 'UNREADABLE'

    elif file_type == 'one file per channel':
        chan_info = pd.DataFrame(columns=['port', 'channel', 'n_samples'])
        lengths = []
        min_samples = numbers['n_samples']
        max_samples = number['n_samples']
        for x in info['amplifier_channels']:
            fn = os.path.join(rec_dir, 'amp-%s.dat' % x['native_channel_name'])
            if os.path.basename(fn) in missing_files:
                continue

            data = dio.rawIO.read_one_channel_file(fn)
            lengths.append((x['native_channel_name'], data.shape[0]))
            if data.shape[0] < min_samples:
                min_samples = data.shape[0]

            if data.shape[0] > max_samples:
                max_samples = data.shape[0]

        if min_samples == max_samples:
            tests['data_traces'] = 'PASS'

        else:
            test['data_traces'] = 'CUTOFF'

        numbers['max_recording_length (s)'] = max_samples/fs
        numbers['min_recording_length (s)'] = min_samples/fs

Functions

def port_in_dataset(rec_dir=None, shell=False)

Import an existing dataset into this framework

Expand source code
def port_in_dataset(rec_dir=None, shell=False):
    '''Import an existing dataset into this framework
    '''
    if rec_dir is None:
        rec_dir = userIO.get_filedirs('Select recording directory', shell=shell)
        if rec_dir is None:
            return None

    dat = dataset(rec_dir, shell=shell)
    # Check files that will be overwritten: log_file, save_file
    if os.path.isfile(dat.save_file):
        prompt = '%s already exists. Continuing will overwrite this. Continue?' % dat.save_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None

    # if os.path.isfile(dat.h5_file):
    #     prompt = '%s already exists. Continuinlg will overwrite this. Continue?' % dat.h5_file
    #     q = userIO.ask_user(prompt, shell=shell)
    #     if q == 0:
    #         print('Aborted')
    #         return None

    if os.path.isfile(dat.log_file):
        prompt = '%s already exists. Continuing will append to this. Continue?' % dat.log_file
        q = userIO.ask_user(prompt, shell=shell)
        if q == 0:
            print('Aborted')
            return None

    with open(dat.log_file, 'a') as f:
        print('\n==========\nPorting dataset into blechpy format\n==========\n', file=f)
        print(dat, file=f)

    status = dat.process_status
    dat.initParams(shell=shell)


    user_status = status.copy()
    user_status = userIO.fill_dict(user_status,
                                   'Which processes have already been '
                                   'done to the data?', shell=shell)

    status.update(user_status)
    # if h5 exists data must have been extracted

    if not os.path.isfile(dat.h5_file) or status['extract_data'] == False:
        dat.save()
        return dat

    # write eletrode map and digital input & output maps to hf5
    dio.h5io.write_electrode_map_to_h5(dat.h5_file, dat.electrode_mapping)
    if dat.rec_info.get('dig_in') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_in_mapping, 'in')

    if dat.rec_info.get('dig_out') is not None:
        dio.h5io.write_digital_map_to_h5(dat.h5_file, dat.dig_out_mapping, 'out')


    node_list = dio.h5io.get_node_list(dat.h5_file)

    if (status['create_trial_list'] == False) and ('digital_in' in node_list):
        dat.create_trial_list()

    dat.save()
    return dat
def validate_data_integrity(rec_dir, verbose=False)

incomplete

Expand source code
def validate_data_integrity(rec_dir, verbose=False):
    '''incomplete
    '''
    # TODO: Finish this
    print('Raw Data Validation\n' + '-'*19)
    test_names = ['file_type', 'recording_info', 'files', 'dropped_packets', 'data_length']
    number_names = ['sample_rate', 'dropped_packets', 'missing_files', 'recording_length']
    tests = dict.fromkeys(test_names, 'NOT TESTED')
    numbers = dict.fromkeys(number_names, -1)
    file_type = dio.rawIO.get_recording_filetype(rec_dir)
    if file_type is None:
        file_type_check = 'UNSUPPORTED'
    else:
        tests['file_type'] = 'PASS'

    # Check info.rhd integrity
    info_file = os.path.join(rec_dir, 'info.rhd')
    try:
        rec_info = dio.rawIO.read_rec_info(rec_dir, shell=True)
        with open(info_file, 'rb') as f:
            info = dio.load_intan_rhd_format.read_header(f)

        tests['recording_info'] = 'PASS'
    except FileNotFoundError:
        test['recording_info'] = 'MISSING'
    except Exception as e:
        info_size = os.path.getsize(os.path.join(rec_dir, 'info.rhd'))
        if info_size == 0:
            tests['recording_info'] = 'EMPTY'
        else:
            tests['recording_info'] = 'FAIL'

        print(pt.print_dict(tests, tabs=1))
        return tests, numbers

    counts = {x : info(x) for x in info.keys() if 'num' in x}
    numbers.update(counts)
    fs = info['sample_rate']
    # Check all files needed are present
    files_expected = ['time.dat']
    if file_type == 'one file per signal type':
        files_expected.append('amplifier.dat')
        if rec_info.get('dig_in') is not None:
            files_expected.append('digitalin.dat')

        if rec_info.get('dig_out') is not None:
            files_expected.append('digitalout.dat')

        if info['num_auxilary_input_channels'] > 0:
            files_expected.append('auxiliary.dat')

    elif file_type == 'one file per channel':
        for x in info['amplifier_channels']:
            files_expected.append('amp-' + x['native_channel_name'] + '.dat')

        for x in info['board_dig_in_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['board_dig_out_channels']:
            files_expected.append('board-%s.dat' % x['native_channel_name'])

        for x in info['aux_input_channels']:
            files_expected.append('aux-%s.dat' % x['native_channel_name'])


    missing_files = []
    file_list = os.listdir(rec_dir)
    for x in file_expected:
        if x not in file_list:
            missing_file.append(x)

    if len(missing_files) == 0:
        tests['files'] = 'PASS'
    else:
        tests['files'] = 'MISSING'
        numbers['missing_files'] = missing_files

    # Check time data for dropped packets
    time = dio.rawIO.read_time_dat(rec_dir, sampling_rate=1)  # get raw timestamps
    numbers['n_samples'] = len(time)
    numbers['recording_length'] = float(time[-1])/fs
    expected_time = np.arange(time[0], time[-1]+1, 1)
    missing_timestamps = np.setdiff1d(expected_time, time)
    missing_times = np.array([float(x)/fs for x in missing_timestamps])
    if len(missing_timestamps) == 0:
        tests['dropped_packets'] = 'PASS'
    else:
        tests['dropped_packets'] = '%i' % len(missing_timestamps)
        numbers['dropped_packets'] = missing_times

    # Check recording length of each trace
    tests['data_traces'] = 'FAIL'
    if file_type == 'one file per signal type':
        try:
            data = dio.rawIO.read_amplifier_dat(rec_dir)
            if data is None:
                tests['data_traces'] = 'UNREADABLE'
            elif data.shape[0] == numbers['n_samples']:
                tests['data_traces'] = 'PASS'
            else:
                tests['data_traces'] = 'CUTOFF'
                numbers['data_trace_length (s)'] = data.shape[0]/fs

        except:
            tests['data_traces'] = 'UNREADABLE'

    elif file_type == 'one file per channel':
        chan_info = pd.DataFrame(columns=['port', 'channel', 'n_samples'])
        lengths = []
        min_samples = numbers['n_samples']
        max_samples = number['n_samples']
        for x in info['amplifier_channels']:
            fn = os.path.join(rec_dir, 'amp-%s.dat' % x['native_channel_name'])
            if os.path.basename(fn) in missing_files:
                continue

            data = dio.rawIO.read_one_channel_file(fn)
            lengths.append((x['native_channel_name'], data.shape[0]))
            if data.shape[0] < min_samples:
                min_samples = data.shape[0]

            if data.shape[0] > max_samples:
                max_samples = data.shape[0]

        if min_samples == max_samples:
            tests['data_traces'] = 'PASS'

        else:
            test['data_traces'] = 'CUTOFF'

        numbers['max_recording_length (s)'] = max_samples/fs
        numbers['min_recording_length (s)'] = min_samples/fs

Classes

class dataset (file_dir=None, data_name=None, shell=False)

Stores information related to an intan recording directory, allows executing basic processing and analysis scripts, and stores parameters data for those analyses

Parameters

file_dir : str (optional)
absolute path to a recording directory, if left empty a filechooser will popup

Initialize dataset object from file_dir, grabs basename from name of directory and initializes basic analysis parameters

Parameters

file_dir : str (optional), file directory for intan recording data
 

Throws

ValueError if file_dir is not provided and no directory is chosen when prompted NotADirectoryError : if file_dir does not exist

Expand source code
class dataset(data_object):
    '''Stores information related to an intan recording directory, allows
    executing basic processing and analysis scripts, and stores parameters data
    for those analyses

    Parameters
    ----------
    file_dir : str (optional)
        absolute path to a recording directory, if left empty a filechooser
        will popup
    '''
    PROCESSING_STEPS = ['initialize parameters',
                        'extract_data', 'create_trial_list',
                        'mark_dead_channels',
                        'common_average_reference', 'spike_detection',
                        'spike_clustering', 'cleanup_clustering',
                        'sort_units', 'make_unit_plots',
                        'units_similarity', 'make_unit_arrays',
                        'make_psth_arrays', 'plot_psths',
                        'palatability_calculate', 'palatability_plot',
                        'overlay_psth']

    def __init__(self, file_dir=None, data_name=None, shell=False):
        '''Initialize dataset object from file_dir, grabs basename from name of
        directory and initializes basic analysis parameters

        Parameters
        ----------
        file_dir : str (optional), file directory for intan recording data

        Throws
        ------
        ValueError
            if file_dir is not provided and no directory is chosen
            when prompted
        NotADirectoryError : if file_dir does not exist
        '''
        super().__init__('dataset', file_dir, data_name=data_name, shell=shell)
        h5_file = dio.h5io.get_h5_filename(self.root_dir)
        if h5_file is None:
            h5_file = os.path.join(self.root_dir, '%s.h5' % self.data_name)

        self.h5_file = h5_file

        self.dataset_creation_date = dt.datetime.today()

        # Outline standard processing pipeline and status check
        self.processing_steps = dataset.PROCESSING_STEPS.copy()
        self.process_status = dict.fromkeys(self.processing_steps, False)

    def _change_root(self, new_root=None):
        old_root = self.root_dir
        new_root = super()._change_root(new_root)
        self.h5_file = self.h5_file.replace(old_root, new_root)
        return new_root

    @Logger('Initializing Parameters')
    def initParams(self, data_quality='clean', emg_port=None,
                   emg_channels=None, car_keyword=None,
                   car_group_areas=None,
                   shell=False, dig_in_names=None,
                   dig_out_names=None, accept_params=False):
        '''
        Initalizes basic default analysis parameters and allows customization
        of parameters

        Parameters (all optional)
        -------------------------
        data_quality : {'clean', 'noisy'}
            keyword defining which default set of parameters to use to detect
            headstage disconnection during clustering
            default is 'clean'. Best practice is to run blech_clust as 'clean'
            and re-run as 'noisy' if too many early cutoffs occurr
        emg_port : str
            Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default)
            will query user. False indicates no EMG port and not to query user
        emg_channels : list of int
            channel or channels of EMGs on port specified
            default is None
        car_keyword : str
            Specifes default common average reference groups
            defaults are found in CAR_defaults.json
            Currently 'bilateral32' is only keyword available
            If left as None (default) user will be queries to select common
            average reference groups
        shell : bool
            False (default) for GUI. True for command-line interface
        dig_in_names : list of str
            Names of digital inputs. Must match number of digital inputs used
            in recording.
            None (default) queries user to name each dig_in
        dig_out_names : list of str
            Names of digital outputs. Must match number of digital outputs in
            recording.
            None (default) queries user to name each dig_out
        accept_params : bool
            True automatically accepts default parameters where possible,
            decreasing user queries
            False (default) will query user to confirm or edit parameters for
            clustering, spike array and psth creation and palatability/identity
            calculations
        '''
        # Get parameters from info.rhd
        file_dir = self.root_dir
        rec_info = dio.rawIO.read_rec_info(file_dir)
        ports = rec_info.pop('ports')
        channels = rec_info.pop('channels')
        sampling_rate = rec_info['amplifier_sampling_rate']
        self.rec_info = rec_info
        self.sampling_rate = sampling_rate

        # Get default parameters from files
        clustering_params = dio.params.load_params('clustering_params', file_dir,
                                                   default_keyword=data_quality)
        spike_array_params = dio.params.load_params('spike_array_params', file_dir)
        psth_params = dio.params.load_params('psth_params', file_dir)
        pal_id_params = dio.params.load_params('pal_id_params', file_dir)
        spike_array_params['sampling_rate'] = sampling_rate
        clustering_params['file_dir'] = file_dir
        clustering_params['sampling_rate'] = sampling_rate

        # Setup digital input mapping
        if rec_info.get('dig_in'):
            self._setup_digital_mapping('in', dig_in_names, shell)
            dim = self.dig_in_mapping.copy()
            spike_array_params['laser_channels'] = dim.channel[dim['laser']].to_list()
            spike_array_params['dig_ins_to_use'] = dim.channel[dim['spike_array']].to_list()
        else:
            self.dig_in_mapping = None

        if rec_info.get('dig_out'):
            self._setup_digital_mapping('out', dig_out_names, shell)
            dom = self.dig_out_mapping.copy()
        else:
            self.dig_out_mapping = None

        # Setup electrode and emg mapping
        self._setup_channel_mapping(ports, channels, emg_port,
                                    emg_channels, shell=shell)

        # Set CAR groups
        self._set_CAR_groups(group_keyword=car_keyword, group_areas=car_group_areas, shell=shell)

        # Confirm parameters
        self.spike_array_params = spike_array_params
        if not accept_params:
            conf = userIO.confirm_parameter_dict
            clustering_params = conf(clustering_params,
                                     'Clustering Parameters', shell=shell)
            self.edit_spike_array_params(shell=shell)
            psth_params = conf(psth_params,
                               'PSTH Parameters', shell=shell)
            pal_id_params = conf(pal_id_params,
                                 'Palatability/Identity Parameters\n'
                                 'Valid unit_type is Single, Multi or All',
                                 shell=shell)

        # Store parameters
        self.clustering_params = clustering_params
        self.pal_id_params = pal_id_params
        self.psth_params = psth_params
        self._write_all_params_to_json()
        self.process_status['initialize parameters'] = True
        self.save()

    def _set_CAR_groups(self, group_keyword=None, shell=False, group_areas=None):
        '''Sets that electrode groups for common average referencing and
        defines which brain region electrodes eneded up in

        Parameters
        ----------
        group_keyword : str or int
            Keyword corresponding to a preset electrode grouping in CAR_params.json
            Or integer indicating number of CAR groups
        shell : bool
            True for command-line interface, False (default) for GUI
        '''
        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('Set electrode mapping before setting CAR groups')

        em = self.electrode_mapping.copy()

        car_param_file = os.path.join(self.root_dir, 'analysis_params',
                                      'CAR_params.json')
        if os.path.isfile(car_param_file):
            tmp = dio.params.load_params('CAR_params', self.root_dir)
            if tmp is not None:
                group_electrodes = tmp
            else:
                raise ValueError('CAR_params file exists in recording dir, but is empty')

        else:
            if group_keyword is None:
                group_keyword = userIO.get_user_input(
                    'Input keyword for CAR parameters or number of CAR groups',
                    shell=shell)

                if group_keyword is None:
                    ValueError('Must provide a keyword or number of groups')

            if group_keyword.isnumeric():
                num_groups = int(group_keyword)
                group_electrodes = dio.params.select_CAR_groups(num_groups, em,
                                                                shell=shell)
            else:
                group_electrodes = dio.params.load_params('CAR_params',
                                                          self.root_dir,
                                                          default_keyword=group_keyword)

        num_groups = len(group_electrodes)
        if group_areas is not None and len(group_areas) == num_groups:
            for i, x in enumerate(zip(group_electrodes, group_areas)):
                em.loc[x[0], 'area'] = x[1]
                em.loc[x[0], 'CAR_group'] = i

        else:
            group_names = ['Group %i' % i for i in range(num_groups)]
            area_dict = dict.fromkeys(group_names, '')
            area_dict = userIO.fill_dict(area_dict, 'Set Areas for CAR groups',
                                         shell=shell)
            for k, v in area_dict.items():
                i = int(k.replace('Group', ''))
                em.loc[group_electrodes[i], 'area'] = v
                em.loc[group_electrodes[i], 'CAR_group'] = i

        self.CAR_electrodes = group_electrodes
        self.electrode_mapping = em.copy()

    @Logger('Re-labelling CAR group areas')
    def set_electrode_areas(self, areas):
        '''sets the electrode area for each CAR group.

        Parameters
        ----------
        areas : list of str
            number of elements must match number of CAR groups

        Throws
        ------
        ValueError
        '''
        em = self.electrode_mapping.copy()
        if len(em['CAR_group'].unique()) != len(areas):
            raise ValueError('Number of items in areas must match number of CAR groups')

        em['areas'] = em['CAR_group'].apply(lambda x: areas[int(x)])
        self.electrode_mapping = em.copy()
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        self.save()

    def _setup_digital_mapping(self, dig_type, dig_in_names=None, shell=False):
        '''sets up dig_in_mapping dataframe  and queries user to fill in columns

        Parameters
        ----------
        dig_in_names : list of str (optional)
        shell : bool (optional)
            True for command-line interface
            False (default) for GUI
        '''
        rec_info = self.rec_info
        df = pd.DataFrame()
        df['channel'] = rec_info.get('dig_%s' % dig_type)
        n_dig_in = len(df)
        # Names
        if dig_in_names:
            df['name'] = dig_in_names
        else:
            df['name'] = ''

        # Parameters to query
        if dig_type == 'in':
            df['palatability_rank'] = 0
            df['laser'] = False
            df['spike_array'] = True

        df['exclude'] = False
        # Re-format for query
        idx = df.index
        df.index = ['dig_%s_%i' % (dig_type, x) for x in df.channel]
        dig_str = dig_type + 'put'
        # Query for user input
        prompt = ('Digital %s Parameters\nSet palatability ranks from 1 to %i'
                  '\nor blank to exclude from pal_id analysis') % (dig_str, len(df))
        tmp = userIO.fill_dict(df.to_dict(), prompt=prompt, shell=shell)
        # Reformat for storage
        df2 = pd.DataFrame.from_dict(tmp)
        df2 = df2.sort_values(by=['channel'])
        df2.index = idx
        if dig_type == 'in':
            df2['palatability_rank'] = df2['palatability_rank'].fillna(-1).astype('int')

        if dig_type == 'in':
            self.dig_in_mapping = df2.copy()
        else:
            self.dig_out_mapping = df2.copy()

        if os.path.isfile(self.h5_file):
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, dig_type)

    def _setup_channel_mapping(self, ports, channels, emg_port, emg_channels, shell=False):
        '''Creates electrode_mapping and emg_mapping DataFrames with columns:
        - Electrode
        - Port
        - Channel

        Parameters
        ----------
        ports : list of str, item corresponing to each channel
        channels : list of int, channels on each port
        emg_port : str
        emg_channels : list of int
        '''
        if emg_port is None:
            q = userIO.ask_user('Do you have an EMG?', shell=shell)
            if q==1:
                emg_port = userIO.select_from_list('Select EMG Port:',
                                                   ports, 'EMG Port',
                                                   shell=shell)
                emg_channels = userIO.select_from_list(
                    'Select EMG Channels:',
                    [y for x, y in
                     zip(ports, channels)
                     if x == emg_port],
                    title='EMG Channels',
                    multi_select=True, shell=shell)

        el_map, em_map = dio.params.flatten_channels(ports, channels,
                                                     emg_port, emg_channels)
        self.electrode_mapping = el_map
        self.emg_mapping = em_map
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

    def edit_spike_array_params(self, shell=False):
        '''Edit spike array parameters and adjust dig_in_mapping accordingly

        Parameters
        ----------
        shell : bool, whether to use CLI or GUI
        '''
        if not hasattr(self, 'dig_in_mapping'):
            self.spike_array_params = None
            return

        sa = deepcopy(self.spike_array_params)
        tmp = userIO.fill_dict(sa, 'Spike Array Parameters\n(Times in ms)',
                               shell=shell)
        if tmp is None:
            return

        dim = self.dig_in_mapping
        dim['spike_array'] = False
        if tmp['dig_ins_to_use'] != ['']:
            tmp['dig_ins_to_use'] = [int(x) for x in tmp['dig_ins_to_use']]
            dim.loc[[x in tmp['dig_ins_to_use'] for x in dim.channel],
                    'spike_array'] = True

        dim['laser_channels'] = False
        if tmp['laser_channels'] != ['']:
            tmp['laser_channels'] = [int(x) for x in tmp['laser_channels']]
            dim.loc[[x in tmp['laser_channels'] for x in dim.channel],
                    'laser'] = True

        self.spike_array_params = tmp.copy()
        wt.write_params_to_json('spike_array_params',
                                self.root_dir, tmp)

    def edit_clustering_params(self, shell=False):
        '''Allows user interface for editing clustering parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.clustering_params,
                               'Clustering Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.clustering_params = tmp
            wt.write_params_to_json('clustering_params', self.root_dir, tmp)

    def edit_psth_params(self, shell=False):
        '''Allows user interface for editing psth parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.psth_params,
                               'PSTH Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.psth_params = tmp
            wt.write_params_to_json('psth_params', self.root_dir, tmp)

    def edit_pal_id_params(self, shell=False):
        '''Allows user interface for editing palatability/identity parameters

        Parameters
        ----------
        shell : bool (optional)
            True if you want command-line interface, False for GUI (default)
        '''
        tmp = userIO.fill_dict(self.pal_id_params,
                               'Palatability/Identity Parameters\n(Times in ms)',
                               shell=shell)
        if tmp:
            self.pal_id_params = tmp
            wt.write_params_to_json('pal_id_params', self.root_dir, tmp)

    def __str__(self):
        '''Put all information about dataset in string format

        Returns
        -------
        str : representation of dataset object
        '''
        out1 = super().__str__()
        out = [out1]
        out.append('\nObject creation date: '
                   + self.dataset_creation_date.strftime('%m/%d/%y'))

        if hasattr(self, 'raw_h5_file'):
            out.append('Deleted Raw h5 file: '+self.raw_h5_file)

        out.append('h5 File: '+self.h5_file)
        out.append('')

        out.append('--------------------')
        out.append('Processing Status')
        out.append('--------------------')
        out.append(pt.print_dict(self.process_status))
        out.append('')

        if not hasattr(self, 'rec_info'):
            return '\n'.join(out)

        info = self.rec_info

        out.append('--------------------')
        out.append('Recording Info')
        out.append('--------------------')
        out.append(pt.print_dict(self.rec_info))
        out.append('')

        out.append('--------------------')
        out.append('Electrodes')
        out.append('--------------------')
        out.append(pt.print_dataframe(self.electrode_mapping))
        out.append('')

        if hasattr(self, 'CAR_electrodes'):
            out.append('--------------------')
            out.append('CAR Groups')
            out.append('--------------------')
            headers = ['Group %i' % x for x in range(len(self.CAR_electrodes))]
            out.append(pt.print_list_table(self.CAR_electrodes, headers))
            out.append('')

        if not self.emg_mapping.empty:
            out.append('--------------------')
            out.append('EMG')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.emg_mapping))
            out.append('')

        if info.get('dig_in'):
            out.append('--------------------')
            out.append('Digital Input')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_in_mapping))
            out.append('')

        if info.get('dig_out'):
            out.append('--------------------')
            out.append('Digital Output')
            out.append('--------------------')
            out.append(pt.print_dataframe(self.dig_out_mapping))
            out.append('')

        out.append('--------------------')
        out.append('Clustering Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.clustering_params))
        out.append('')

        out.append('--------------------')
        out.append('Spike Array Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.spike_array_params))
        out.append('')

        out.append('--------------------')
        out.append('PSTH Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.psth_params))
        out.append('')

        out.append('--------------------')
        out.append('Palatability/Identity Parameters')
        out.append('--------------------')
        out.append(pt.print_dict(self.pal_id_params))
        out.append('')

        return '\n'.join(out)

    @Logger('Writing parameters to JSON')
    def _write_all_params_to_json(self):
        '''Writes all parameters to json files in analysis_params folder in the
        recording directory
        '''
        print('Writing all parameters to json file in analysis_params folder...')
        clustering_params = self.clustering_params
        spike_array_params = self.spike_array_params
        psth_params = self.psth_params
        pal_id_params = self.pal_id_params
        CAR_params = self.CAR_electrodes

        rec_dir = self.root_dir
        wt.write_params_to_json('clustering_params', rec_dir, clustering_params)
        wt.write_params_to_json('spike_array_params', rec_dir, spike_array_params)
        wt.write_params_to_json('psth_params', rec_dir, psth_params)
        wt.write_params_to_json('pal_id_params', rec_dir, pal_id_params)
        wt.write_params_to_json('CAR_params', rec_dir, CAR_params)

    @Logger('Extracting Data')
    def extract_data(self, filename=None, shell=False):
        '''Create hdf5 store for data and read in Intan .dat files. Also create
        subfolders for processing outputs

        Parameters
        ----------
        data_quality: {'clean', 'noisy'} (optional)
            Specifies quality of data for default clustering parameters
            associated. Should generally first process with clean (default)
            parameters and then try noisy after running blech_clust and
            checking if too many electrodes as cutoff too early
        '''
        if self.rec_info['file_type'] is None:
            raise ValueError('Unsupported recording type. Cannot extract yet.')

        if filename is None:
            filename = self.h5_file

        print('\nExtract Intan Data\n--------------------')
        # Create h5 file
        tmp = dio.h5io.create_empty_data_h5(filename, shell)
        if tmp is None:
            return

        # Create arrays for raw data in hdf5 store
        dio.h5io.create_hdf_arrays(filename, self.rec_info,
                                   self.electrode_mapping, self.emg_mapping)

        # Read in data to arrays
        dio.h5io.read_files_into_arrays(filename,
                                        self.rec_info,
                                        self.electrode_mapping,
                                        self.emg_mapping)

        # Write electrode and digital input mapping into h5 file
        # TODO: write EMG and digital output mapping into h5 file
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
        if self.dig_in_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'in')

        if self.dig_out_mapping is not None:
            dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'out')

        # update status
        self.h5_file = filename
        self.process_status['extract_data'] = True
        self.save()
        print('\nData Extraction Complete\n--------------------')

    @Logger('Creating Trial List')
    def create_trial_list(self):
        '''Create lists of trials based on digital inputs and outputs and store
        to hdf5 store
        Can only be run after data extraction
        '''
        if self.rec_info.get('dig_in'):
            in_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_in_mapping,
                self.sampling_rate,
                'in')
            self.dig_in_trials = in_list
        else:
            print('No digital input data found')

        if self.rec_info.get('dig_out'):
            out_list = dio.h5io.create_trial_data_table(
                self.h5_file,
                self.dig_out_mapping,
                self.sampling_rate,
                'out')
            self.dig_out_trials = out_list
        else:
            print('No digital output data found')

        self.process_status['create_trial_list'] = True
        self.save()

    @Logger('Marking Dead Channels')
    def mark_dead_channels(self, dead_channels=None, shell=False):
        '''Plots small piece of raw traces and a metric to help identify dead
        channels. Once user marks channels as dead a new column is added to
        electrode mapping

        Parameters
        ----------
        dead_channels : list of int, optional
            if this is specified then nothing is plotted, those channels are
            simply marked as dead
        shell : bool, optional
        '''
        print('Marking dead channels\n----------')
        em = self.electrode_mapping.copy()
        if dead_channels is None:
            userIO.tell_user('Making traces figure for dead channel detection...',
                             shell=True)
            save_file = os.path.join(self.root_dir, 'Electrode_Traces.png')
            fig, ax = datplt.plot_traces_and_outliers(self.h5_file, save_file=save_file)
            if not shell:
                # Better to open figure outside of python since its a lot of
                # data on figure and matplotlib is slow
                subprocess.call(['xdg-open', save_file])
            else:
                userIO.tell_user('Saved figure of traces to %s for reference'
                                 % save_file, shell=shell)

            choice = userIO.select_from_list('Select dead channels:',
                                             em.Electrode.to_list(),
                                             'Dead Channel Selection',
                                             multi_select=True,
                                             shell=shell)
            dead_channels = list(map(int, choice))

        print('Marking eletrodes %s as dead.\n'
              'They will be excluded from common average referencing.'
              % dead_channels)
        em['dead'] = False
        em.loc[dead_channels, 'dead'] = True
        self.electrode_mapping = em
        if os.path.isfile(self.h5_file):
            dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

        self.process_status['mark_dead_channels'] = True
        self.save()
        return dead_channels

    @Logger('Common Average Referencing')
    def common_average_reference(self):
        '''Define electrode groups and remove common average from  signals

        Parameters
        ----------
        num_groups : int (optional)
            number of CAR groups, if not provided
            there's a prompt
        '''
        if not hasattr(self, 'CAR_electrodes'):
            raise ValueError('CAR_electrodes not set')

        if not hasattr(self, 'electrode_mapping'):
            raise ValueError('electrode_mapping not set')

        car_electrodes = self.CAR_electrodes
        num_groups = len(car_electrodes)
        em = self.electrode_mapping.copy()

        if 'dead' in em.columns:
            dead_electrodes = em.Electrode[em.dead].to_list()
        else:
            dead_electrodes = []

        # Gather Common Average Reference Groups
        print('CAR Groups\n')
        headers = ['Group %i' % x for x in range(num_groups)]
        print(pt.print_list_table(car_electrodes, headers))

        # Reference each group
        for i, x in enumerate(car_electrodes):
            tmp = list(set(x) - set(dead_electrodes))
            dio.h5io.common_avg_reference(self.h5_file, tmp, i)

        # Compress and repack file
        dio.h5io.compress_and_repack(self.h5_file)

        self.process_status['common_average_reference'] = True
        self.save()

    @Logger('Running Spike Detection')
    def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None):
        '''Run spike detection on each electrode. Prepares for clustering with
        BlechClust. Works for both single recording clustering or
        multi-recording clustering

        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        n_cores : int (optional)
            number of cores to use for parallel processing. default is max-1.
        '''
        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')

        print('\nRunning Spike Detection\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))

        # Create folders for saving things within recording dir
        data_dir = self.root_dir

        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()


        pbar = tqdm(total = len(electrodes))
        results = [(None, None, None)] * (max(electrodes)+1)
        def update_pbar(ans):
            if isinstance(ans, tuple) and ans[0] is not None:
                results[ans[0]] = ans
            else:
                print('Unexpected error when clustering an electrode')

            pbar.update()

        spike_detectors = [clust.SpikeDetection(data_dir, x,
                                                self.clustering_params)
                           for x in electrodes]
        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1

            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for sd in spike_detectors:
                pool.apply_async(sd.run, callback=update_pbar)

            pool.close()
            pool.join()
        else:
            for sd in spike_detectors:
                res = sd.run()
                update_pbar(res)

        pbar.close()

        print('Electrode    Result    Cutoff (s)')
        cutoffs = {}
        clust_res = {}
        clustered = []
        for x, y, z in results:
            if x is None:
                continue

            clustered.append(x)
            print('  {:<13}{:<10}{}'.format(x, y, z))
            cutoffs[x] = z
            clust_res[x] = y

        print('1 - Sucess\n0 - No data or no spikes\n-1 - Error')

        em = self.electrode_mapping.copy()
        em['cutoff_time'] = em['Electrode'].map(cutoffs)
        em['clustering_result'] = em['Electrode'].map(clust_res)
        self.electrode_mapping = em.copy()
        self.process_status['spike_detection'] = True
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Spike Detection Complete\n------------------')
        return results

    @Logger('Running Blech Clust')
    def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False):
        '''Write clustering parameters to file and
        Run blech_process on each electrode using GNU parallel

        Parameters
        ----------
        data_quality : {'clean', 'noisy', None (default)}
            set if you want to change the data quality parameters for cutoff
            and spike detection before running clustering. These parameters are
            automatically set as "clean" during initial parameter setup
        accept_params : bool, False (default)
            set to True in order to skip popup confirmation of parameters when
            running
        '''
        if self.process_status['spike_detection'] == False:
            raise FileNotFoundError('Must run spike detection before clustering.')

        if data_quality:
            tmp = dio.params.load_params('clustering_params', self.root_dir,
                                         default_keyword=data_quality)
            if tmp:
                self.clustering_params = tmp
                wt.write_params_to_json('clustering_params', self.root_dir, tmp)
            else:
                raise ValueError('%s is not a valid data_quality preset. Must '
                                 'be "clean" or "noisy" or None.')

        print('\nRunning Blech Clust\n-------------------')
        print('Parameters\n%s' % pt.print_dict(self.clustering_params))

        # Get electrodes, throw out 'dead' electrodes
        em = self.electrode_mapping
        if 'dead' in em.columns:
            electrodes = em.Electrode[em['dead'] == False].tolist()
        else:
            electrodes = em.Electrode.tolist()


        pbar = tqdm(total = len(electrodes))
        def update_pbar(ans):
            pbar.update()

        errors = []
        def error_call(e):
            errors.append(e)

        if not umap:
            clust_objs = [clust.BlechClust(self.root_dir, x, params=self.clustering_params)
                          for x in electrodes]
        else:
            clust_objs = [clust.BlechClust(self.root_dir, x,
                                           params=self.clustering_params,
                                           data_transform=clust.UMAP_METRICS,
                                           n_pc=5)
                          for x in electrodes]

        if multi_process:
            if n_cores is None or n_cores > multiprocessing.cpu_count():
                n_cores = multiprocessing.cpu_count() - 1

            pool = multiprocessing.get_context('spawn').Pool(n_cores)
            for x in clust_objs:
                pool.apply_async(x.run, callback=update_pbar, error_callback=error_call)

            pool.close()
            pool.join()
        else:
            for x in clust_objs:
                res = x.run()
                update_pbar(res)

        pbar.close()

        self.process_status['spike_clustering'] = True
        self.process_status['cleanup_clustering'] = False
        dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
        self.save()
        print('Clustering Complete\n------------------')
        if len(errors) > 0:
            print('Errors encountered:')
            print(errors)

    @Logger('Cleaning up clustering memory logs. Removing raw data and setting'
            'up hdf5 for unit sorting')
    def cleanup_clustering(self):
        '''Consolidates memory monitor files, removes raw and referenced data
        and setups up hdf5 store for sorted units data
        '''
        if self.process_status['cleanup_clustering']:
            return

        h5_file = dio.h5io.cleanup_clustering(self.root_dir)
        self.h5_file = h5_file
        self.process_status['cleanup_clustering'] = True
        self.save()

    def sort_spikes(self, electrode=None, shell=False):
        if electrode is None:
            electrode = userIO.get_user_input('Electrode #: ', shell=shell)
            if electrode is None or not electrode.isnumeric():
                return

            electrode = int(electrode)

        if not self.process_status['spike_clustering']:
            raise ValueError('Must run spike clustering first.')

        if not self.process_status['cleanup_clustering']:
            self.cleanup_clustering()

        sorter = clust.SpikeSorter(self.root_dir, electrode=electrode, shell=shell)
        if not shell:
            root, sorting_GUI = ssg.launch_sorter_GUI(sorter)
            return root, sorting_GUI
        else:
            # TODO: Make shell UI
            # TODO: Make sort by table
            print('No shell UI yet')
            return

        self.process_status['sort_units'] = True

    @Logger('Calculating Units Similarity')
    def units_similarity(self, similarity_cutoff=50, shell=False):
        if 'SSH_CONNECTION' in os.environ:
            shell= True

        metrics_dir = os.path.join(self.root_dir, 'sorted_unit_metrics')
        if not os.path.isdir(metrics_dir):
            raise ValueError('No sorted unit metrics found. Must sort units before calculating similarity')

        violation_file = os.path.join(metrics_dir,
                                      'units_similarity_violations.txt')
        violations, sim = ss.calc_units_similarity(self.h5_file,
                                                   self.sampling_rate,
                                                   similarity_cutoff,
                                                   violation_file)
        if len(violations) == 0:
            userIO.tell_user('No similarity violations found!', shell=shell)
            self.process_status['units_similarity'] = True
            return violations, sim

        out_str = ['Units Similarity Violations Found:']
        out_str.append('Unit_1    Unit_2    Similarity')
        for x,y in violations:
            u1 = dio.h5io.parse_unit_number(x)
            u2 = dio.h5io.parse_unit_number(y)
            out_str.append('   {:<10}{:<10}{}\n'.format(x, y, sim[u1][u2]))

        out_str.append('Delete units with dataset.delete_unit(N)')
        out_str = '\n'.join(out_str)
        userIO.tell_user(out_str, shell=shell)
        self.process_status['units_similarity'] = True
        self.save()
        return violations, sim

    @Logger('Deleting Unit')
    def delete_unit(self, unit_num, confirm=False, shell=False):
        if isinstance(unit_num, str):
            unit_num = dio.h5io.parse_unit_number(unit_num)

        if unit_num is None:
            print('No unit deleted')
            return

        if not confirm:
            q = userIO.ask_user('Are you sure you want to delete unit%03i?' % unit_num,
                                choices = ['No','Yes'], shell=shell)
        else:
            q = 1

        if q == 0:
            print('No unit deleted')
            return
        else:
            tmp = dio.h5io.delete_unit(self.root_dir, unit_num)
            if tmp is False:
                userIO.tell_user('Unit %i not found in dataset. No unit deleted'
                                 % unit_num, shell=shell)
            else:
                userIO.tell_user('Unit %i sucessfully deleted.' % unit_num,
                                 shell=shell)

        self.save()

    @Logger('Making Unit Arrays')
    def make_unit_arrays(self):
        '''Make spike arrays for each unit and store in hdf5 store
        '''
        params = self.spike_array_params

        print('Generating unit arrays with parameters:\n----------')
        print(pt.print_dict(params, tabs=1))
        ss.make_spike_arrays(self.h5_file, params)
        self.process_status['make_unit_arrays'] = True
        self.save()

    @Logger('Making Unit Plots')
    def make_unit_plots(self):
        '''Make waveform plots for each sorted unit
        '''
        unit_table = self.get_unit_table()
        save_dir = os.path.join(self.root_dir, 'unit_waveforms_plots')
        if os.path.isdir(save_dir):
            shutil.rmtree(save_dir)

        os.mkdir(save_dir)
        for i, row in unit_table.iterrows():
            datplt.make_unit_plots(self.root_dir, row['unit_name'], save_dir=save_dir)

        self.process_status['make_unit_plots'] = True
        self.save()

    @Logger('Making PSTH Arrays')
    def make_psth_arrays(self):
        '''Make smoothed firing rate traces for each unit/trial and store in
        hdf5 store
        '''
        params = self.psth_params
        dig_ins = self.dig_in_mapping.query('spike_array == True')
        for idx, row in dig_ins.iterrows():
            spike_analysis.make_psths_for_tastant(self.h5_file,
                                                  params['window_size'],
                                                  params['window_step'],
                                                  row['channel'])

        self.process_status['make_psth_arrays'] = True
        self.save()

    @Logger('Calculating Palatability/Identity Metrics')
    def palatability_calculate(self, shell=False):
        pal_analysis.palatability_identity_calculations(self.root_dir,
                                                        params=self.pal_id_params)
        self.process_status['palatability_calculate'] = True
        self.save()

    @Logger('Plotting Palatability/Identity Metrics')
    def palatability_plot(self, shell=False):
        pal_plt.plot_palatability_identity([self.root_dir], shell=shell)
        self.process_status['palatability_plot'] = True
        self.save()

    @Logger('Removing low-spiking units')
    def cleanup_lowSpiking_units(self, min_spikes=100):
        unit_table = self.get_unit_table()
        remove = []
        spike_count = []
        for unit in unit_table['unit_num']:
            waves, descrip, fs = dio.h5io.get_unit_waveforms(self.root_dir, unit)
            if waves.shape[0] < min_spikes:
                spike_count.append(waves.shape[0])
                remove.append(unit)

        for unit, count in zip(reversed(remove), reversed(spike_count)):
            print('Removing unit %i. Only %i spikes.' % (unit, count))
            userIO.tell_user('Removing unit %i. Only %i spikes.'
                             % (unit, count), shell=True)
            self.delete_unit(unit, confirm=True, shell=True)

        userIO.tell_user('Removed %i units for having less than %i spikes.'
                         % (len(remove), min_spikes), shell=True)

    def get_unit_table(self):
        '''Returns a pandas dataframe with sorted unit information

        Returns
        --------
        pandas.DataFrame with columns:
            unit_name, unit_num, electrode, single_unit,
            regular_spiking, fast_spiking
        '''
        unit_table = dio.h5io.get_unit_table(self.root_dir)
        return unit_table

    def circus_clust_run(self, shell=False):
        circ.prep_for_circus(self.root_dir, self.electrode_mapping,
                             self.data_name, self.sampling_rate)
        circ.start_the_show()

    def pre_process_for_clustering(self, shell=False, dead_channels=None):
        status = self.process_status
        if not status['initialize parameters']:
            self.initParams(shell=shell)

        if not status['extract_data']:
            self.extract_data(shell=True)

        if not status['create_trial_list']:
            self.create_trial_list()

        if not status['mark_dead_channels'] and dead_channels != False:
            self.mark_dead_channels(dead_channels=dead_channels, shell=shell)

        if not status['common_average_reference']:
            self.common_average_reference()

        if not status['spike_detection']:
            self.detect_spikes()

    def extract_and_circus_cluster(self, dead_channels=None, shell=True):
        print('Extracting Data...')
        self.extract_data()
        print('Marking dead channels...')
        self.mark_dead_channels(dead_channels, shell=shell)
        print('Common average referencing...')
        self.common_average_reference()
        print('Initiating circus clustering...')
        circus = circ.circus_clust(self.root_dir, self.data_name,
                                   self.sampling_rate, self.electrode_mapping)
        print('Preparing for circus...')
        circus.prep_for_circus()
        print('Starting circus clustering...')
        circus.start_the_show()
        print('Plotting cluster waveforms...')
        circus.plot_cluster_waveforms()

    def post_sorting(self):
        self.make_unit_plots()
        self.make_unit_arrays()
        self.units_similarity(shell=True)
        self.make_psth_arrays()

Ancestors

Class variables

var PROCESSING_STEPS

Methods

def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False)

Write clustering parameters to file and Run blech_process on each electrode using GNU parallel

Parameters

data_quality : {'clean', 'noisy', None (default)}
set if you want to change the data quality parameters for cutoff and spike detection before running clustering. These parameters are automatically set as "clean" during initial parameter setup
accept_params : bool, False (default)
set to True in order to skip popup confirmation of parameters when running
Expand source code
@Logger('Running Blech Clust')
def blech_clust_run(self, data_quality=None, multi_process=True, n_cores=None, umap=False):
    '''Write clustering parameters to file and
    Run blech_process on each electrode using GNU parallel

    Parameters
    ----------
    data_quality : {'clean', 'noisy', None (default)}
        set if you want to change the data quality parameters for cutoff
        and spike detection before running clustering. These parameters are
        automatically set as "clean" during initial parameter setup
    accept_params : bool, False (default)
        set to True in order to skip popup confirmation of parameters when
        running
    '''
    if self.process_status['spike_detection'] == False:
        raise FileNotFoundError('Must run spike detection before clustering.')

    if data_quality:
        tmp = dio.params.load_params('clustering_params', self.root_dir,
                                     default_keyword=data_quality)
        if tmp:
            self.clustering_params = tmp
            wt.write_params_to_json('clustering_params', self.root_dir, tmp)
        else:
            raise ValueError('%s is not a valid data_quality preset. Must '
                             'be "clean" or "noisy" or None.')

    print('\nRunning Blech Clust\n-------------------')
    print('Parameters\n%s' % pt.print_dict(self.clustering_params))

    # Get electrodes, throw out 'dead' electrodes
    em = self.electrode_mapping
    if 'dead' in em.columns:
        electrodes = em.Electrode[em['dead'] == False].tolist()
    else:
        electrodes = em.Electrode.tolist()


    pbar = tqdm(total = len(electrodes))
    def update_pbar(ans):
        pbar.update()

    errors = []
    def error_call(e):
        errors.append(e)

    if not umap:
        clust_objs = [clust.BlechClust(self.root_dir, x, params=self.clustering_params)
                      for x in electrodes]
    else:
        clust_objs = [clust.BlechClust(self.root_dir, x,
                                       params=self.clustering_params,
                                       data_transform=clust.UMAP_METRICS,
                                       n_pc=5)
                      for x in electrodes]

    if multi_process:
        if n_cores is None or n_cores > multiprocessing.cpu_count():
            n_cores = multiprocessing.cpu_count() - 1

        pool = multiprocessing.get_context('spawn').Pool(n_cores)
        for x in clust_objs:
            pool.apply_async(x.run, callback=update_pbar, error_callback=error_call)

        pool.close()
        pool.join()
    else:
        for x in clust_objs:
            res = x.run()
            update_pbar(res)

    pbar.close()

    self.process_status['spike_clustering'] = True
    self.process_status['cleanup_clustering'] = False
    dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
    self.save()
    print('Clustering Complete\n------------------')
    if len(errors) > 0:
        print('Errors encountered:')
        print(errors)
def circus_clust_run(self, shell=False)
Expand source code
def circus_clust_run(self, shell=False):
    circ.prep_for_circus(self.root_dir, self.electrode_mapping,
                         self.data_name, self.sampling_rate)
    circ.start_the_show()
def cleanup_clustering(self)

Consolidates memory monitor files, removes raw and referenced data and setups up hdf5 store for sorted units data

Expand source code
@Logger('Cleaning up clustering memory logs. Removing raw data and setting'
        'up hdf5 for unit sorting')
def cleanup_clustering(self):
    '''Consolidates memory monitor files, removes raw and referenced data
    and setups up hdf5 store for sorted units data
    '''
    if self.process_status['cleanup_clustering']:
        return

    h5_file = dio.h5io.cleanup_clustering(self.root_dir)
    self.h5_file = h5_file
    self.process_status['cleanup_clustering'] = True
    self.save()
def cleanup_lowSpiking_units(self, min_spikes=100)
Expand source code
@Logger('Removing low-spiking units')
def cleanup_lowSpiking_units(self, min_spikes=100):
    unit_table = self.get_unit_table()
    remove = []
    spike_count = []
    for unit in unit_table['unit_num']:
        waves, descrip, fs = dio.h5io.get_unit_waveforms(self.root_dir, unit)
        if waves.shape[0] < min_spikes:
            spike_count.append(waves.shape[0])
            remove.append(unit)

    for unit, count in zip(reversed(remove), reversed(spike_count)):
        print('Removing unit %i. Only %i spikes.' % (unit, count))
        userIO.tell_user('Removing unit %i. Only %i spikes.'
                         % (unit, count), shell=True)
        self.delete_unit(unit, confirm=True, shell=True)

    userIO.tell_user('Removed %i units for having less than %i spikes.'
                     % (len(remove), min_spikes), shell=True)
def common_average_reference(self)

Define electrode groups and remove common average from signals

Parameters

num_groups : int (optional)
number of CAR groups, if not provided there's a prompt
Expand source code
@Logger('Common Average Referencing')
def common_average_reference(self):
    '''Define electrode groups and remove common average from  signals

    Parameters
    ----------
    num_groups : int (optional)
        number of CAR groups, if not provided
        there's a prompt
    '''
    if not hasattr(self, 'CAR_electrodes'):
        raise ValueError('CAR_electrodes not set')

    if not hasattr(self, 'electrode_mapping'):
        raise ValueError('electrode_mapping not set')

    car_electrodes = self.CAR_electrodes
    num_groups = len(car_electrodes)
    em = self.electrode_mapping.copy()

    if 'dead' in em.columns:
        dead_electrodes = em.Electrode[em.dead].to_list()
    else:
        dead_electrodes = []

    # Gather Common Average Reference Groups
    print('CAR Groups\n')
    headers = ['Group %i' % x for x in range(num_groups)]
    print(pt.print_list_table(car_electrodes, headers))

    # Reference each group
    for i, x in enumerate(car_electrodes):
        tmp = list(set(x) - set(dead_electrodes))
        dio.h5io.common_avg_reference(self.h5_file, tmp, i)

    # Compress and repack file
    dio.h5io.compress_and_repack(self.h5_file)

    self.process_status['common_average_reference'] = True
    self.save()
def create_trial_list(self)

Create lists of trials based on digital inputs and outputs and store to hdf5 store Can only be run after data extraction

Expand source code
@Logger('Creating Trial List')
def create_trial_list(self):
    '''Create lists of trials based on digital inputs and outputs and store
    to hdf5 store
    Can only be run after data extraction
    '''
    if self.rec_info.get('dig_in'):
        in_list = dio.h5io.create_trial_data_table(
            self.h5_file,
            self.dig_in_mapping,
            self.sampling_rate,
            'in')
        self.dig_in_trials = in_list
    else:
        print('No digital input data found')

    if self.rec_info.get('dig_out'):
        out_list = dio.h5io.create_trial_data_table(
            self.h5_file,
            self.dig_out_mapping,
            self.sampling_rate,
            'out')
        self.dig_out_trials = out_list
    else:
        print('No digital output data found')

    self.process_status['create_trial_list'] = True
    self.save()
def delete_unit(self, unit_num, confirm=False, shell=False)
Expand source code
@Logger('Deleting Unit')
def delete_unit(self, unit_num, confirm=False, shell=False):
    if isinstance(unit_num, str):
        unit_num = dio.h5io.parse_unit_number(unit_num)

    if unit_num is None:
        print('No unit deleted')
        return

    if not confirm:
        q = userIO.ask_user('Are you sure you want to delete unit%03i?' % unit_num,
                            choices = ['No','Yes'], shell=shell)
    else:
        q = 1

    if q == 0:
        print('No unit deleted')
        return
    else:
        tmp = dio.h5io.delete_unit(self.root_dir, unit_num)
        if tmp is False:
            userIO.tell_user('Unit %i not found in dataset. No unit deleted'
                             % unit_num, shell=shell)
        else:
            userIO.tell_user('Unit %i sucessfully deleted.' % unit_num,
                             shell=shell)

    self.save()
def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None)

Run spike detection on each electrode. Prepares for clustering with BlechClust. Works for both single recording clustering or multi-recording clustering

Parameters

data_quality : {'clean', 'noisy', None (default)}
set if you want to change the data quality parameters for cutoff and spike detection before running clustering. These parameters are automatically set as "clean" during initial parameter setup
n_cores : int (optional)
number of cores to use for parallel processing. default is max-1.
Expand source code
@Logger('Running Spike Detection')
def detect_spikes(self, data_quality=None, multi_process=True, n_cores=None):
    '''Run spike detection on each electrode. Prepares for clustering with
    BlechClust. Works for both single recording clustering or
    multi-recording clustering

    Parameters
    ----------
    data_quality : {'clean', 'noisy', None (default)}
        set if you want to change the data quality parameters for cutoff
        and spike detection before running clustering. These parameters are
        automatically set as "clean" during initial parameter setup
    n_cores : int (optional)
        number of cores to use for parallel processing. default is max-1.
    '''
    if data_quality:
        tmp = dio.params.load_params('clustering_params', self.root_dir,
                                     default_keyword=data_quality)
        if tmp:
            self.clustering_params = tmp
            wt.write_params_to_json('clustering_params', self.root_dir, tmp)
        else:
            raise ValueError('%s is not a valid data_quality preset. Must '
                             'be "clean" or "noisy" or None.')

    print('\nRunning Spike Detection\n-------------------')
    print('Parameters\n%s' % pt.print_dict(self.clustering_params))

    # Create folders for saving things within recording dir
    data_dir = self.root_dir

    em = self.electrode_mapping
    if 'dead' in em.columns:
        electrodes = em.Electrode[em['dead'] == False].tolist()
    else:
        electrodes = em.Electrode.tolist()


    pbar = tqdm(total = len(electrodes))
    results = [(None, None, None)] * (max(electrodes)+1)
    def update_pbar(ans):
        if isinstance(ans, tuple) and ans[0] is not None:
            results[ans[0]] = ans
        else:
            print('Unexpected error when clustering an electrode')

        pbar.update()

    spike_detectors = [clust.SpikeDetection(data_dir, x,
                                            self.clustering_params)
                       for x in electrodes]
    if multi_process:
        if n_cores is None or n_cores > multiprocessing.cpu_count():
            n_cores = multiprocessing.cpu_count() - 1

        pool = multiprocessing.get_context('spawn').Pool(n_cores)
        for sd in spike_detectors:
            pool.apply_async(sd.run, callback=update_pbar)

        pool.close()
        pool.join()
    else:
        for sd in spike_detectors:
            res = sd.run()
            update_pbar(res)

    pbar.close()

    print('Electrode    Result    Cutoff (s)')
    cutoffs = {}
    clust_res = {}
    clustered = []
    for x, y, z in results:
        if x is None:
            continue

        clustered.append(x)
        print('  {:<13}{:<10}{}'.format(x, y, z))
        cutoffs[x] = z
        clust_res[x] = y

    print('1 - Sucess\n0 - No data or no spikes\n-1 - Error')

    em = self.electrode_mapping.copy()
    em['cutoff_time'] = em['Electrode'].map(cutoffs)
    em['clustering_result'] = em['Electrode'].map(clust_res)
    self.electrode_mapping = em.copy()
    self.process_status['spike_detection'] = True
    dio.h5io.write_electrode_map_to_h5(self.h5_file, em)
    self.save()
    print('Spike Detection Complete\n------------------')
    return results
def edit_clustering_params(self, shell=False)

Allows user interface for editing clustering parameters

Parameters

shell : bool (optional)
True if you want command-line interface, False for GUI (default)
Expand source code
def edit_clustering_params(self, shell=False):
    '''Allows user interface for editing clustering parameters

    Parameters
    ----------
    shell : bool (optional)
        True if you want command-line interface, False for GUI (default)
    '''
    tmp = userIO.fill_dict(self.clustering_params,
                           'Clustering Parameters\n(Times in ms)',
                           shell=shell)
    if tmp:
        self.clustering_params = tmp
        wt.write_params_to_json('clustering_params', self.root_dir, tmp)
def edit_pal_id_params(self, shell=False)

Allows user interface for editing palatability/identity parameters

Parameters

shell : bool (optional)
True if you want command-line interface, False for GUI (default)
Expand source code
def edit_pal_id_params(self, shell=False):
    '''Allows user interface for editing palatability/identity parameters

    Parameters
    ----------
    shell : bool (optional)
        True if you want command-line interface, False for GUI (default)
    '''
    tmp = userIO.fill_dict(self.pal_id_params,
                           'Palatability/Identity Parameters\n(Times in ms)',
                           shell=shell)
    if tmp:
        self.pal_id_params = tmp
        wt.write_params_to_json('pal_id_params', self.root_dir, tmp)
def edit_psth_params(self, shell=False)

Allows user interface for editing psth parameters

Parameters

shell : bool (optional)
True if you want command-line interface, False for GUI (default)
Expand source code
def edit_psth_params(self, shell=False):
    '''Allows user interface for editing psth parameters

    Parameters
    ----------
    shell : bool (optional)
        True if you want command-line interface, False for GUI (default)
    '''
    tmp = userIO.fill_dict(self.psth_params,
                           'PSTH Parameters\n(Times in ms)',
                           shell=shell)
    if tmp:
        self.psth_params = tmp
        wt.write_params_to_json('psth_params', self.root_dir, tmp)
def edit_spike_array_params(self, shell=False)

Edit spike array parameters and adjust dig_in_mapping accordingly

Parameters

shell : bool, whether to use CLI or GUI
 
Expand source code
def edit_spike_array_params(self, shell=False):
    '''Edit spike array parameters and adjust dig_in_mapping accordingly

    Parameters
    ----------
    shell : bool, whether to use CLI or GUI
    '''
    if not hasattr(self, 'dig_in_mapping'):
        self.spike_array_params = None
        return

    sa = deepcopy(self.spike_array_params)
    tmp = userIO.fill_dict(sa, 'Spike Array Parameters\n(Times in ms)',
                           shell=shell)
    if tmp is None:
        return

    dim = self.dig_in_mapping
    dim['spike_array'] = False
    if tmp['dig_ins_to_use'] != ['']:
        tmp['dig_ins_to_use'] = [int(x) for x in tmp['dig_ins_to_use']]
        dim.loc[[x in tmp['dig_ins_to_use'] for x in dim.channel],
                'spike_array'] = True

    dim['laser_channels'] = False
    if tmp['laser_channels'] != ['']:
        tmp['laser_channels'] = [int(x) for x in tmp['laser_channels']]
        dim.loc[[x in tmp['laser_channels'] for x in dim.channel],
                'laser'] = True

    self.spike_array_params = tmp.copy()
    wt.write_params_to_json('spike_array_params',
                            self.root_dir, tmp)
def extract_and_circus_cluster(self, dead_channels=None, shell=True)
Expand source code
def extract_and_circus_cluster(self, dead_channels=None, shell=True):
    print('Extracting Data...')
    self.extract_data()
    print('Marking dead channels...')
    self.mark_dead_channels(dead_channels, shell=shell)
    print('Common average referencing...')
    self.common_average_reference()
    print('Initiating circus clustering...')
    circus = circ.circus_clust(self.root_dir, self.data_name,
                               self.sampling_rate, self.electrode_mapping)
    print('Preparing for circus...')
    circus.prep_for_circus()
    print('Starting circus clustering...')
    circus.start_the_show()
    print('Plotting cluster waveforms...')
    circus.plot_cluster_waveforms()
def extract_data(self, filename=None, shell=False)

Create hdf5 store for data and read in Intan .dat files. Also create subfolders for processing outputs

Parameters

data_quality : {'clean', 'noisy'} (optional)
Specifies quality of data for default clustering parameters associated. Should generally first process with clean (default) parameters and then try noisy after running blech_clust and checking if too many electrodes as cutoff too early
Expand source code
@Logger('Extracting Data')
def extract_data(self, filename=None, shell=False):
    '''Create hdf5 store for data and read in Intan .dat files. Also create
    subfolders for processing outputs

    Parameters
    ----------
    data_quality: {'clean', 'noisy'} (optional)
        Specifies quality of data for default clustering parameters
        associated. Should generally first process with clean (default)
        parameters and then try noisy after running blech_clust and
        checking if too many electrodes as cutoff too early
    '''
    if self.rec_info['file_type'] is None:
        raise ValueError('Unsupported recording type. Cannot extract yet.')

    if filename is None:
        filename = self.h5_file

    print('\nExtract Intan Data\n--------------------')
    # Create h5 file
    tmp = dio.h5io.create_empty_data_h5(filename, shell)
    if tmp is None:
        return

    # Create arrays for raw data in hdf5 store
    dio.h5io.create_hdf_arrays(filename, self.rec_info,
                               self.electrode_mapping, self.emg_mapping)

    # Read in data to arrays
    dio.h5io.read_files_into_arrays(filename,
                                    self.rec_info,
                                    self.electrode_mapping,
                                    self.emg_mapping)

    # Write electrode and digital input mapping into h5 file
    # TODO: write EMG and digital output mapping into h5 file
    dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
    if self.dig_in_mapping is not None:
        dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'in')

    if self.dig_out_mapping is not None:
        dio.h5io.write_digital_map_to_h5(self.h5_file, self.dig_in_mapping, 'out')

    # update status
    self.h5_file = filename
    self.process_status['extract_data'] = True
    self.save()
    print('\nData Extraction Complete\n--------------------')
def get_unit_table(self)

Returns a pandas dataframe with sorted unit information

Returns

pandas.DataFrame with columns:
unit_name, unit_num, electrode, single_unit, regular_spiking, fast_spiking
Expand source code
def get_unit_table(self):
    '''Returns a pandas dataframe with sorted unit information

    Returns
    --------
    pandas.DataFrame with columns:
        unit_name, unit_num, electrode, single_unit,
        regular_spiking, fast_spiking
    '''
    unit_table = dio.h5io.get_unit_table(self.root_dir)
    return unit_table
def initParams(self, data_quality='clean', emg_port=None, emg_channels=None, car_keyword=None, car_group_areas=None, shell=False, dig_in_names=None, dig_out_names=None, accept_params=False)

Initalizes basic default analysis parameters and allows customization of parameters

Parameters (all optional)

data_quality : {'clean', 'noisy'} keyword defining which default set of parameters to use to detect headstage disconnection during clustering default is 'clean'. Best practice is to run blech_clust as 'clean' and re-run as 'noisy' if too many early cutoffs occurr emg_port : str Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default) will query user. False indicates no EMG port and not to query user emg_channels : list of int channel or channels of EMGs on port specified default is None car_keyword : str Specifes default common average reference groups defaults are found in CAR_defaults.json Currently 'bilateral32' is only keyword available If left as None (default) user will be queries to select common average reference groups shell : bool False (default) for GUI. True for command-line interface dig_in_names : list of str Names of digital inputs. Must match number of digital inputs used in recording. None (default) queries user to name each dig_in dig_out_names : list of str Names of digital outputs. Must match number of digital outputs in recording. None (default) queries user to name each dig_out accept_params : bool True automatically accepts default parameters where possible, decreasing user queries False (default) will query user to confirm or edit parameters for clustering, spike array and psth creation and palatability/identity calculations

Expand source code
@Logger('Initializing Parameters')
def initParams(self, data_quality='clean', emg_port=None,
               emg_channels=None, car_keyword=None,
               car_group_areas=None,
               shell=False, dig_in_names=None,
               dig_out_names=None, accept_params=False):
    '''
    Initalizes basic default analysis parameters and allows customization
    of parameters

    Parameters (all optional)
    -------------------------
    data_quality : {'clean', 'noisy'}
        keyword defining which default set of parameters to use to detect
        headstage disconnection during clustering
        default is 'clean'. Best practice is to run blech_clust as 'clean'
        and re-run as 'noisy' if too many early cutoffs occurr
    emg_port : str
        Port ('A', 'B', 'C') of EMG, if there was an EMG. None (default)
        will query user. False indicates no EMG port and not to query user
    emg_channels : list of int
        channel or channels of EMGs on port specified
        default is None
    car_keyword : str
        Specifes default common average reference groups
        defaults are found in CAR_defaults.json
        Currently 'bilateral32' is only keyword available
        If left as None (default) user will be queries to select common
        average reference groups
    shell : bool
        False (default) for GUI. True for command-line interface
    dig_in_names : list of str
        Names of digital inputs. Must match number of digital inputs used
        in recording.
        None (default) queries user to name each dig_in
    dig_out_names : list of str
        Names of digital outputs. Must match number of digital outputs in
        recording.
        None (default) queries user to name each dig_out
    accept_params : bool
        True automatically accepts default parameters where possible,
        decreasing user queries
        False (default) will query user to confirm or edit parameters for
        clustering, spike array and psth creation and palatability/identity
        calculations
    '''
    # Get parameters from info.rhd
    file_dir = self.root_dir
    rec_info = dio.rawIO.read_rec_info(file_dir)
    ports = rec_info.pop('ports')
    channels = rec_info.pop('channels')
    sampling_rate = rec_info['amplifier_sampling_rate']
    self.rec_info = rec_info
    self.sampling_rate = sampling_rate

    # Get default parameters from files
    clustering_params = dio.params.load_params('clustering_params', file_dir,
                                               default_keyword=data_quality)
    spike_array_params = dio.params.load_params('spike_array_params', file_dir)
    psth_params = dio.params.load_params('psth_params', file_dir)
    pal_id_params = dio.params.load_params('pal_id_params', file_dir)
    spike_array_params['sampling_rate'] = sampling_rate
    clustering_params['file_dir'] = file_dir
    clustering_params['sampling_rate'] = sampling_rate

    # Setup digital input mapping
    if rec_info.get('dig_in'):
        self._setup_digital_mapping('in', dig_in_names, shell)
        dim = self.dig_in_mapping.copy()
        spike_array_params['laser_channels'] = dim.channel[dim['laser']].to_list()
        spike_array_params['dig_ins_to_use'] = dim.channel[dim['spike_array']].to_list()
    else:
        self.dig_in_mapping = None

    if rec_info.get('dig_out'):
        self._setup_digital_mapping('out', dig_out_names, shell)
        dom = self.dig_out_mapping.copy()
    else:
        self.dig_out_mapping = None

    # Setup electrode and emg mapping
    self._setup_channel_mapping(ports, channels, emg_port,
                                emg_channels, shell=shell)

    # Set CAR groups
    self._set_CAR_groups(group_keyword=car_keyword, group_areas=car_group_areas, shell=shell)

    # Confirm parameters
    self.spike_array_params = spike_array_params
    if not accept_params:
        conf = userIO.confirm_parameter_dict
        clustering_params = conf(clustering_params,
                                 'Clustering Parameters', shell=shell)
        self.edit_spike_array_params(shell=shell)
        psth_params = conf(psth_params,
                           'PSTH Parameters', shell=shell)
        pal_id_params = conf(pal_id_params,
                             'Palatability/Identity Parameters\n'
                             'Valid unit_type is Single, Multi or All',
                             shell=shell)

    # Store parameters
    self.clustering_params = clustering_params
    self.pal_id_params = pal_id_params
    self.psth_params = psth_params
    self._write_all_params_to_json()
    self.process_status['initialize parameters'] = True
    self.save()
def make_psth_arrays(self)

Make smoothed firing rate traces for each unit/trial and store in hdf5 store

Expand source code
@Logger('Making PSTH Arrays')
def make_psth_arrays(self):
    '''Make smoothed firing rate traces for each unit/trial and store in
    hdf5 store
    '''
    params = self.psth_params
    dig_ins = self.dig_in_mapping.query('spike_array == True')
    for idx, row in dig_ins.iterrows():
        spike_analysis.make_psths_for_tastant(self.h5_file,
                                              params['window_size'],
                                              params['window_step'],
                                              row['channel'])

    self.process_status['make_psth_arrays'] = True
    self.save()
def make_unit_arrays(self)

Make spike arrays for each unit and store in hdf5 store

Expand source code
@Logger('Making Unit Arrays')
def make_unit_arrays(self):
    '''Make spike arrays for each unit and store in hdf5 store
    '''
    params = self.spike_array_params

    print('Generating unit arrays with parameters:\n----------')
    print(pt.print_dict(params, tabs=1))
    ss.make_spike_arrays(self.h5_file, params)
    self.process_status['make_unit_arrays'] = True
    self.save()
def make_unit_plots(self)

Make waveform plots for each sorted unit

Expand source code
@Logger('Making Unit Plots')
def make_unit_plots(self):
    '''Make waveform plots for each sorted unit
    '''
    unit_table = self.get_unit_table()
    save_dir = os.path.join(self.root_dir, 'unit_waveforms_plots')
    if os.path.isdir(save_dir):
        shutil.rmtree(save_dir)

    os.mkdir(save_dir)
    for i, row in unit_table.iterrows():
        datplt.make_unit_plots(self.root_dir, row['unit_name'], save_dir=save_dir)

    self.process_status['make_unit_plots'] = True
    self.save()
def mark_dead_channels(self, dead_channels=None, shell=False)

Plots small piece of raw traces and a metric to help identify dead channels. Once user marks channels as dead a new column is added to electrode mapping

Parameters

dead_channels : list of int, optional
if this is specified then nothing is plotted, those channels are simply marked as dead
shell : bool, optional
 
Expand source code
@Logger('Marking Dead Channels')
def mark_dead_channels(self, dead_channels=None, shell=False):
    '''Plots small piece of raw traces and a metric to help identify dead
    channels. Once user marks channels as dead a new column is added to
    electrode mapping

    Parameters
    ----------
    dead_channels : list of int, optional
        if this is specified then nothing is plotted, those channels are
        simply marked as dead
    shell : bool, optional
    '''
    print('Marking dead channels\n----------')
    em = self.electrode_mapping.copy()
    if dead_channels is None:
        userIO.tell_user('Making traces figure for dead channel detection...',
                         shell=True)
        save_file = os.path.join(self.root_dir, 'Electrode_Traces.png')
        fig, ax = datplt.plot_traces_and_outliers(self.h5_file, save_file=save_file)
        if not shell:
            # Better to open figure outside of python since its a lot of
            # data on figure and matplotlib is slow
            subprocess.call(['xdg-open', save_file])
        else:
            userIO.tell_user('Saved figure of traces to %s for reference'
                             % save_file, shell=shell)

        choice = userIO.select_from_list('Select dead channels:',
                                         em.Electrode.to_list(),
                                         'Dead Channel Selection',
                                         multi_select=True,
                                         shell=shell)
        dead_channels = list(map(int, choice))

    print('Marking eletrodes %s as dead.\n'
          'They will be excluded from common average referencing.'
          % dead_channels)
    em['dead'] = False
    em.loc[dead_channels, 'dead'] = True
    self.electrode_mapping = em
    if os.path.isfile(self.h5_file):
        dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)

    self.process_status['mark_dead_channels'] = True
    self.save()
    return dead_channels
def palatability_calculate(self, shell=False)
Expand source code
@Logger('Calculating Palatability/Identity Metrics')
def palatability_calculate(self, shell=False):
    pal_analysis.palatability_identity_calculations(self.root_dir,
                                                    params=self.pal_id_params)
    self.process_status['palatability_calculate'] = True
    self.save()
def palatability_plot(self, shell=False)
Expand source code
@Logger('Plotting Palatability/Identity Metrics')
def palatability_plot(self, shell=False):
    pal_plt.plot_palatability_identity([self.root_dir], shell=shell)
    self.process_status['palatability_plot'] = True
    self.save()
def post_sorting(self)
Expand source code
def post_sorting(self):
    self.make_unit_plots()
    self.make_unit_arrays()
    self.units_similarity(shell=True)
    self.make_psth_arrays()
def pre_process_for_clustering(self, shell=False, dead_channels=None)
Expand source code
def pre_process_for_clustering(self, shell=False, dead_channels=None):
    status = self.process_status
    if not status['initialize parameters']:
        self.initParams(shell=shell)

    if not status['extract_data']:
        self.extract_data(shell=True)

    if not status['create_trial_list']:
        self.create_trial_list()

    if not status['mark_dead_channels'] and dead_channels != False:
        self.mark_dead_channels(dead_channels=dead_channels, shell=shell)

    if not status['common_average_reference']:
        self.common_average_reference()

    if not status['spike_detection']:
        self.detect_spikes()
def set_electrode_areas(self, areas)

sets the electrode area for each CAR group.

Parameters

areas : list of str
number of elements must match number of CAR groups

Throws

ValueError

Expand source code
@Logger('Re-labelling CAR group areas')
def set_electrode_areas(self, areas):
    '''sets the electrode area for each CAR group.

    Parameters
    ----------
    areas : list of str
        number of elements must match number of CAR groups

    Throws
    ------
    ValueError
    '''
    em = self.electrode_mapping.copy()
    if len(em['CAR_group'].unique()) != len(areas):
        raise ValueError('Number of items in areas must match number of CAR groups')

    em['areas'] = em['CAR_group'].apply(lambda x: areas[int(x)])
    self.electrode_mapping = em.copy()
    dio.h5io.write_electrode_map_to_h5(self.h5_file, self.electrode_mapping)
    self.save()
def sort_spikes(self, electrode=None, shell=False)
Expand source code
def sort_spikes(self, electrode=None, shell=False):
    if electrode is None:
        electrode = userIO.get_user_input('Electrode #: ', shell=shell)
        if electrode is None or not electrode.isnumeric():
            return

        electrode = int(electrode)

    if not self.process_status['spike_clustering']:
        raise ValueError('Must run spike clustering first.')

    if not self.process_status['cleanup_clustering']:
        self.cleanup_clustering()

    sorter = clust.SpikeSorter(self.root_dir, electrode=electrode, shell=shell)
    if not shell:
        root, sorting_GUI = ssg.launch_sorter_GUI(sorter)
        return root, sorting_GUI
    else:
        # TODO: Make shell UI
        # TODO: Make sort by table
        print('No shell UI yet')
        return

    self.process_status['sort_units'] = True
def units_similarity(self, similarity_cutoff=50, shell=False)
Expand source code
@Logger('Calculating Units Similarity')
def units_similarity(self, similarity_cutoff=50, shell=False):
    if 'SSH_CONNECTION' in os.environ:
        shell= True

    metrics_dir = os.path.join(self.root_dir, 'sorted_unit_metrics')
    if not os.path.isdir(metrics_dir):
        raise ValueError('No sorted unit metrics found. Must sort units before calculating similarity')

    violation_file = os.path.join(metrics_dir,
                                  'units_similarity_violations.txt')
    violations, sim = ss.calc_units_similarity(self.h5_file,
                                               self.sampling_rate,
                                               similarity_cutoff,
                                               violation_file)
    if len(violations) == 0:
        userIO.tell_user('No similarity violations found!', shell=shell)
        self.process_status['units_similarity'] = True
        return violations, sim

    out_str = ['Units Similarity Violations Found:']
    out_str.append('Unit_1    Unit_2    Similarity')
    for x,y in violations:
        u1 = dio.h5io.parse_unit_number(x)
        u2 = dio.h5io.parse_unit_number(y)
        out_str.append('   {:<10}{:<10}{}\n'.format(x, y, sim[u1][u2]))

    out_str.append('Delete units with dataset.delete_unit(N)')
    out_str = '\n'.join(out_str)
    userIO.tell_user(out_str, shell=shell)
    self.process_status['units_similarity'] = True
    self.save()
    return violations, sim