'''Module blechpy.analysis.blech_clustering

Spike detection (SpikeDetection) and GMM-based spike clustering (BlechClust)
for blechpy recordings.
'''
import os
import shutil
import numpy as np
import pandas as pd
import itertools as it
import umap
import pywt
from statsmodels.stats.diagnostic import lilliefors
from copy import deepcopy
from scipy.spatial.distance import mahalanobis
from scipy import linalg
from scipy.signal import find_peaks
from scipy.stats import sem
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
from blechpy.utils import write_tools as wt, print_tools as pt, math_tools as mt, userIO
from blechpy.dio import h5io
from blechpy.analysis import clustering, spike_analysis as sas
from blechpy.plotting import data_plot as dplt
import datetime as dt


def detect_spikes(filt_el, spike_snapshot = [0.5, 1.0], fs = 30000.0):
    '''Detects spikes in the filtered electrode trace and return the waveforms
    and spike_times

    A sample is sub-threshold when it is <= ``mean - 5 * median(|x| / 0.6745)``.
    Each run of consecutive sub-threshold samples is treated as one spike,
    snapped at its minimum.

    Parameters
    ----------
    filt_el : np.array, 1-D
        filtered electrode trace
    spike_snapshot : list
        2-elements, [ms before spike minimum, ms after spike minimum]
        time around spike to snap as waveform
    fs : float, sampling rate in Hz

    Returns
    -------
    waves : np.array or None
        matrix of de-jittered, spike waveforms, upsampled by 10x, row for each
        spike (as produced by clustering.dejitter); None if no spike was found
    times : np.array or None
        array of spike times in samples; None if no spike was found
    threshold: float
        spike detection threshold. Always returned, even when no spikes are
        found, so callers can always unpack three values.
    '''
    # get indices of spike snapshot, expand by .1 ms in each direction
    # (presumably to give dejitter room to realign waveforms)
    snapshot = np.arange(-(spike_snapshot[0]+0.1)*fs/1000,
                         1+(spike_snapshot[1]+0.1)*fs/1000).astype('int64')
    m = np.mean(filt_el)
    th = 5.0*np.median(np.abs(filt_el)/0.6745)
    threshold = m - th
    pos = np.where(filt_el <= threshold)[0]
    if pos.size == 0:
        # Nothing crossed threshold; skip grouping/dejittering entirely
        return None, None, threshold

    consecutive = mt.group_consecutives(pos)

    waves = []
    times = []
    for idx in consecutive:
        minimum = idx[np.argmin(filt_el[idx])]
        spike_idx = minimum + snapshot
        # Drop spikes whose snapshot window would run off either end of the trace
        if spike_idx[0] >= 0 and spike_idx[-1] < len(filt_el):
            waves.append(filt_el[spike_idx])
            times.append(minimum)

    if len(waves) == 0:
        return None, None, threshold

    waves_dj, times_dj = clustering.dejitter(np.array(waves), np.array(times), spike_snapshot, fs)
    return waves_dj, times_dj, threshold


def implement_pca(scaled_slices):
    '''Fit a PCA (all components) to the rows of ``scaled_slices``.

    Parameters
    ----------
    scaled_slices : np.array
        matrix with one row per sample, e.g. energy-scaled waveforms

    Returns
    -------
    np.array
        the rows projected onto every principal component
    np.array
        fraction of the variance explained by each component
    '''
    decomposition = PCA()
    projected = decomposition.fit_transform(scaled_slices)
    variance_ratio = decomposition.explained_variance_ratio_
    return projected, variance_ratio


def implement_umap(waves, n_pc=3, n_neighbors=30, min_dist=0.0):
    '''Embed ``waves`` into ``n_pc`` dimensions with UMAP.

    Parameters
    ----------
    waves : np.array
        one row per spike waveform
    n_pc : int
        number of embedding dimensions (UMAP ``n_components``)
    n_neighbors : int
        UMAP neighbourhood size
    min_dist : float
        UMAP minimum distance between embedded points

    Returns
    -------
    np.array
        the embedding, one row per waveform
    '''
    settings = dict(n_components=n_pc,
                    n_neighbors=n_neighbors,
                    min_dist=min_dist)
    embedder = umap.UMAP(**settings)
    embedding = embedder.fit_transform(waves)
    return embedding


def implement_wavelet_transform(waves, n_pc=10):
    '''Select the ``n_pc`` least-Gaussian Haar wavelet coefficients.

    Each waveform (row) gets a multilevel Haar wavelet decomposition. Every
    coefficient column is tested for normality with the Lilliefors test, and
    the columns that deviate most from a normal distribution are kept.

    Parameters
    ----------
    waves : np.array
        one row per spike waveform
    n_pc : int (optional)
        number of coefficient columns to return

    Returns
    -------
    np.array
        (n_waves, min(n_pc, n_coefficients)) selected coefficients, ordered
        from least to most normal
    '''
    coeffs = pywt.wavedec(waves, 'haar', axis=1)
    all_coeffs = np.column_stack(coeffs)
    n_coeffs = all_coeffs.shape[1]
    k_stats = np.zeros((n_coeffs,))
    p_vals = np.ones((n_coeffs,))
    for i, c in enumerate(all_coeffs.T):
        k_stats[i], p_vals[i] = lilliefors(c, dist='norm')

    # Rank primarily by p-value (smallest = least normal). Lilliefors p-values
    # may saturate at the bounds of their lookup table, so several columns can
    # tie; break ties by the KS statistic (larger = less normal) so the
    # selection is deterministic. np.lexsort sorts by the LAST key first and
    # is stable, unlike the default np.argsort.
    idx = np.lexsort((-k_stats, p_vals))
    return all_coeffs[:, idx[:n_pc]]


def compute_waveform_metrics(waves, n_pc=3, umap=False):
    '''Make clustering data array with columns:
         - amplitude, energy, spike_slope, PC0, PC1, ..., PC{n_pc-1}

    Parameters
    ----------
    waves : np.array
        waveforms with a row for each spike waveform
    n_pc : int (optional)
        number of principal components (or UMAP dimensions) to include in
        data array
    umap : bool (optional)
        if True, the trailing columns come from a UMAP embedding of the raw
        waveforms; otherwise from PCA of energy-scaled waveforms. (This name
        shadows the ``umap`` module only inside this function.)

    Returns
    -------
    data : np.array
        feature matrix, one row per waveform: amplitude, energy, spike slope,
        then up to ``n_pc`` component columns
    data_columns : list of str
        column names for ``data``
    '''
    # Reuse the per-metric helpers so these features always match the ones
    # SpikeDetection saves to disk
    energy = get_waveform_energy(waves)
    data = np.column_stack((get_waveform_amplitudes(waves),
                            energy,
                            get_spike_slopes(waves)))

    if umap:
        # UMAP embeds the raw (unscaled) waveforms
        pc_waves = implement_umap(waves, n_pc=n_pc)
    else:
        # Scale waveforms to energy before running PCA
        scaled_waves = scale_waveforms(waves, energy=energy)
        pc_waves, _ = implement_pca(scaled_waves)

    data = np.hstack((data, pc_waves[:, :n_pc]))
    data_columns = ['amplitude', 'energy', 'spike_slope']
    data_columns.extend(['PC%i' % i for i in range(n_pc)])
    return data, data_columns


def get_waveform_amplitudes(waves):
    '''Return the amplitude of every waveform, i.e. its minimum sample.

    Parameters
    ----------
    waves : np.array
        2-D matrix of waveforms, one spike per row

    Returns
    -------
    np.array
        1-D array with one (most negative) value per row
    '''
    row_minima = np.amin(waves, axis=1)
    return row_minima


def get_waveform_energy(waves):
    '''Return the energy of every waveform.

    Energy is defined here as the Euclidean norm of the row divided by the
    number of samples per waveform.

    Parameters
    ----------
    waves : np.array
        2-D matrix of waveforms, one spike per row

    Returns
    -------
    np.array
        1-D array of energies, one per row
    '''
    root_sum_sq = np.sqrt(np.square(waves).sum(axis=1))
    n_samples = waves.shape[1]
    return root_sum_sq / n_samples


def get_spike_slopes(waves):
    '''Return the initial downward slope of each spike.

    For each waveform the slope runs from the last local peak before the
    minimum down to the minimum. If no peak precedes the minimum, the largest
    sample before the minimum is used instead.

    Parameters
    ----------
    waves : np.array
        2-D matrix of waveforms, one spike per row

    Returns
    -------
    np.array
        1-D float array of slopes (amplitude units per sample), one per row
    '''
    def _slope(wave):
        trough = np.argmin(wave)
        local_peaks = find_peaks(wave)[0]
        before_trough = local_peaks[local_peaks < trough]
        if before_trough.size:
            apex = before_trough.max()
        else:
            apex = np.argmax(wave[:trough])
        return (wave[trough] - wave[apex]) / (trough - apex)

    return np.array([_slope(w) for w in waves], dtype=float)


def get_ISI_and_violations(spike_times, fs, rec_map=None):
    '''Return inter-spike intervals (ms) and counts of 1 ms / 2 ms violations.

    Parameters
    ----------
    spike_times : numpy.array
        spike times in samples
    fs : float or dict
        sampling rate in Hz. When ``rec_map`` is given and recordings differ in
        sampling rate, pass a dict keyed by the rec ids found in ``rec_map``.
    rec_map : np.array (optional)
        recording id for every spike. When given, ISIs are computed separately
        within each recording (never across recordings) and concatenated in
        sorted rec-id order. When omitted, all spikes are treated as one
        recording.

    Returns
    -------
    np.array : ISIs in ms
    int : number of ISIs shorter than 1 ms
    int : number of ISIs shorter than 2 ms
    '''
    if rec_map is None:
        samples_per_ms = float(fs/1000.0)
        isi_ms = np.ediff1d(np.sort(spike_times)) / samples_per_ms
        return isi_ms, np.sum(isi_ms < 1.0), np.sum(isi_ms < 2.0)

    recordings = np.unique(rec_map)
    fs_lookup = fs if isinstance(fs, dict) else dict.fromkeys(recordings, fs)

    pieces = [np.array([])]
    under_1ms = 0
    under_2ms = 0
    for rec in recordings:
        members = np.where(rec_map == rec)[0]
        rec_isi, n1, n2 = get_ISI_and_violations(spike_times[members], fs_lookup[rec])
        under_1ms += n1
        under_2ms += n2
        pieces.append(rec_isi)

    return np.concatenate(pieces), under_1ms, under_2ms


def scale_waveforms(waves, energy=None):
    '''Scales each waveform to its own energy (divides each row by its energy)

    Parameters
    ----------
    waves : np.array, matrix of waveforms, with row for each spike
    energy : np.array (optional)
        array of waveform energies, saves computation time

    Returns
    -------
    np.array
        float64 array shaped like ``waves``

    Raises
    ------
    ValueError
        if ``energy`` does not have exactly one entry per waveform
    '''
    if energy is None:
        energy = get_waveform_energy(waves)
    elif len(energy) != waves.shape[0]:
        raise ValueError(('Energies must correspond to each waveform. '
                          'Different lengths are not allowed'))

    waves = np.asarray(waves)
    energy = np.asarray(energy)
    # Shape energy as (n_waves, 1, ...) so it broadcasts along each row
    per_wave = energy.reshape((-1,) + (1,) * (waves.ndim - 1))
    scaled_slices = np.zeros(waves.shape)  # keep the float64 output dtype
    scaled_slices[...] = waves / per_wave
    return scaled_slices


def get_mahalanobis_distances_to_cluster(data, model, clusters, target_cluster):
    '''computes mahalanobis distance from spikes in target_cluster to all clusters
    in GMM model

    Parameters
    ----------
    data : np.array, data used to train GMM
    model : fitted GMM model (reads ``means_`` and full ``covariances_``)
    clusters : np.array, maps data points to clusters
    target_cluster : int, cluster for which to compute distances

    Returns
    -------
    dict
        maps every cluster id in ``np.unique(abs(clusters))`` to a 1-D array
        of distances, one per spike labelled ``target_cluster``
    '''
    # abs() turns a -1 label (used by BlechClust.run for rejected spikes) into
    # 1, so no negative id is ever used to index the model parameters
    unique_clusters = np.unique(np.abs(clusters))
    out_distances = dict.fromkeys(unique_clusters)
    cluster_idx = np.where(clusters == target_cluster)[0]
    points = np.asarray(data)[cluster_idx, :]
    for other_cluster in unique_clusters:
        other_cluster_mean = model.means_[other_cluster, :]
        other_cluster_covar_I = linalg.inv(model.covariances_[other_cluster, :, :])
        delta = points - other_cluster_mean
        # Row-wise sqrt(d @ VI @ d): the same quantity
        # scipy.spatial.distance.mahalanobis computes per spike, but vectorized
        sq_dist = np.einsum('ij,jk,ik->i', delta, other_cluster_covar_I, delta)
        out_distances[other_cluster] = np.sqrt(sq_dist)

    return out_distances


def get_recording_cutoff(filt_el, sampling_rate, voltage_cutoff,
                         max_breach_rate, max_secs_above_cutoff,
                         max_mean_breach_rate_persec, **kwargs):
    '''Find the time (whole seconds) at which to truncate a recording because
    the headstage appears to have disconnected.

    The recording is flagged only if all three criteria reach their maximum:
    the overall rate of samples above ``voltage_cutoff`` (per second), the
    number of 1 s bins containing any such sample, and the mean breach count
    over those bins. The cutoff is then the first 1 s bin whose count exceeds
    ``max_mean_breach_rate_persec``. If no bin strictly exceeds it, the cutoff
    is the first bin that reaches it.

    Parameters
    ----------
    filt_el : np.array, 1-D
        filtered electrode trace
    sampling_rate : float
        sampling rate in Hz
    voltage_cutoff : float
        samples strictly above this value count as breaches
    max_breach_rate : float
        overall breaches-per-second threshold
    max_secs_above_cutoff : int
        threshold on the number of 1 s bins with at least one breach
    max_mean_breach_rate_persec : float
        threshold on the mean breach count of breaching bins
    **kwargs
        ignored, so a full parameter dict can be passed as ``**params``

    Returns
    -------
    int
        cutoff in seconds; the whole-second length of the trace when the
        recording is not flagged
    '''
    n_samples = len(filt_el)
    samples_per_bin = int(sampling_rate)
    n_breaches = np.count_nonzero(filt_el > voltage_cutoff)
    # Breaches per second over the whole trace (an empty trace has none)
    breach_rate = (float(n_breaches*samples_per_bin)/n_samples
                   if n_samples > 0 else 0.0)

    # truncate to nearest second and make 1 sec bins
    filt_el = filt_el[:samples_per_bin*int(n_samples/sampling_rate)]
    test_el = np.reshape(filt_el, (-1, samples_per_bin))
    breaches_per_sec = np.sum(test_el > voltage_cutoff, axis=1)
    breaching = breaches_per_sec[breaches_per_sec > 0]
    secs_above_cutoff = len(breaching)
    if secs_above_cutoff == 0:
        mean_breach_rate_persec = 0
    else:
        mean_breach_rate_persec = np.mean(breaching)

    recording_cutoff = int(len(filt_el)/sampling_rate)  # cutoff in seconds
    # And if they all exceed the cutoffs, assume that the headstage fell off mid-experiment
    if (breach_rate >= max_breach_rate and
        secs_above_cutoff >= max_secs_above_cutoff and
        mean_breach_rate_persec >= max_mean_breach_rate_persec):
        # Find the first 1 second epoch where the number of cutoff breaches is
        # higher than the maximum allowed mean breach rate
        over = np.where(breaches_per_sec > max_mean_breach_rate_persec)[0]
        if len(over) == 0:
            # The mean criterion was met with equality, so no bin is strictly
            # above it; fall back to the first bin that reaches it
            over = np.where(breaches_per_sec >= max_mean_breach_rate_persec)[0]
        if len(over) > 0:
            # cutoff is still in seconds since 1 sec bins
            recording_cutoff = over[0]

    return recording_cutoff


def UMAP_METRICS(waves, n_pc):
    '''Clustering features that use a UMAP embedding in place of PCA.

    Has the ``(waveforms, n_pc) -> (data, data_columns)`` shape that
    BlechClust calls its ``data_transform`` with.
    '''
    return compute_waveform_metrics(waves, n_pc=n_pc, umap=True)


class SpikeDetection(object):
    '''Interface to manage spike detection and data extraction in preparation
    for GMM clustering. Intended to help create and access the necessary
    files. The object detects which files already exist so they are not
    re-created unless overwrite is specified as True.

    Parameters
    ----------
    file_dir : str
        recording directory; output is written to
        ``<file_dir>/spike_detection/electrode_<electrode>``
    electrode : int
        electrode number
    params : dict (optional)
        dataset clustering parameters with 'data_params', 'bandpass_params',
        'spike_snapshot' and 'sampling_rate' entries. Passed params take
        precedence and are written to
        ``<file_dir>/analysis_params/spike_detection_params.json``; if None,
        the parameters are read from that file.
    overwrite : bool (optional)
        if True, delete this electrode's existing spike detection output first

    Raises
    ------
    FileNotFoundError
        if params is None and spike_detection_params.json does not exist
    '''

    def __init__(self, file_dir, electrode, params=None, overwrite=False):
        # Setup paths to files and directories needed
        self._file_dir = file_dir
        self._electrode = electrode
        self._out_dir = os.path.join(file_dir, 'spike_detection',
                                     'electrode_%i' % electrode)
        self._data_dir = os.path.join(self._out_dir, 'data')
        self._plot_dir = os.path.join(self._out_dir, 'plots')
        self._files = {'params': os.path.join(file_dir,'analysis_params', 'spike_detection_params.json'),
                       'spike_waveforms': os.path.join(self._data_dir, 'spike_waveforms.npy'),
                       'spike_times' : os.path.join(self._data_dir, 'spike_times.npy'),
                       'energy' : os.path.join(self._data_dir, 'energy.npy'),
                       'spike_amplitudes' : os.path.join(self._data_dir, 'spike_amplitudes.npy'),
                       'pca_waveforms' : os.path.join(self._data_dir, 'pca_waveforms.npy'),
                       'slopes' : os.path.join(self._data_dir, 'spike_slopes.npy'),
                       'recording_cutoff' : os.path.join(self._data_dir, 'cutoff_time.txt'),
                       'detection_threshold' : os.path.join(self._data_dir, 'detection_threshold.txt')}

        self._status = dict.fromkeys(self._files.keys(), False)
        self._referenced = True

        # Delete existing data if overwrite is True
        if overwrite and os.path.isdir(self._out_dir):
            shutil.rmtree(self._out_dir)

        # See what data already exists
        self._check_existing_files()

        # Make directories if needed
        if not os.path.isdir(self._out_dir):
            os.makedirs(self._out_dir)

        if not os.path.isdir(self._data_dir):
            os.makedirs(self._data_dir)

        if not os.path.isdir(self._plot_dir):
            os.makedirs(self._plot_dir)

        if not os.path.isdir(os.path.join(file_dir, 'analysis_params')):
            os.makedirs(os.path.join(file_dir, 'analysis_params'))

        # grab recording cutoff time if it already exists
        # cutoff should be in seconds
        self.recording_cutoff = None
        if os.path.isfile(self._files['recording_cutoff']):
            self._status['recording_cutoff'] = True
            with open(self._files['recording_cutoff'], 'r') as f:
                self.recording_cutoff = float(f.read())

        self.detection_threshold = None
        if os.path.isfile(self._files['detection_threshold']):
            self._status['detection_threshold'] = True
            with open(self._files['detection_threshold'], 'r') as f:
                self.detection_threshold = float(f.read())

        # Read in parameters
        # Parameters passed as an argument will overshadow parameters saved in file
        # Input parameters should be formatted as dataset.clustering_parameters
        if params is None and os.path.isfile(self._files['params']):
            self.params = wt.read_dict_from_json(self._files['params'])
        elif params is None:
            raise FileNotFoundError('params must be provided if spike_detection_params.json does not exist.')
        else:
            self.params = {}
            self.params['voltage_cutoff'] = params['data_params']['V_cutoff for disconnected headstage']
            self.params['max_breach_rate'] = params['data_params']['Max rate of cutoff breach per second']
            self.params['max_secs_above_cutoff'] = params['data_params']['Max allowed seconds with a breach']
            self.params['max_mean_breach_rate_persec'] = params['data_params']['Max allowed breaches per second']
            band_lower = params['bandpass_params']['Lower freq cutoff']
            band_upper = params['bandpass_params']['Upper freq cutoff']
            self.params['bandpass'] = [band_lower, band_upper]
            snapshot_pre = params['spike_snapshot']['Time before spike (ms)']
            snapshot_post = params['spike_snapshot']['Time after spike (ms)']
            self.params['spike_snapshot'] = [snapshot_pre, snapshot_post]
            self.params['sampling_rate'] = params['sampling_rate']
            # Write params to json file
            wt.write_dict_to_json(self.params, self._files['params'])
            self._status['params'] = True

    def _check_existing_files(self):
        '''Checks which files already exist and updates _status so as to avoid
        re-creation later
        '''
        for k, v in self._files.items():
            if os.path.isfile(v):
                self._status[k] = True
            else:
                self._status[k] = False

    def _load_array(self, key):
        '''Return the array saved at ``self._files[key]``, or None if that file
        does not exist.
        '''
        fn = self._files[key]
        if os.path.isfile(fn):
            return np.load(fn)

        return None

    def run(self):
        '''Run spike detection for this electrode, skipping any step whose
        output file already exists.

        The steps are:
          1. load the referenced trace, or the raw trace if there is none
          2. bandpass filter it
          3. compute and save the recording cutoff
          4. detect spikes
          5. save amplitudes, slopes, energies and the PCA of energy-scaled
             waveforms

        Returns
        -------
        tuple
            (electrode, 1 on success or 0 if the electrode was skipped,
            recording cutoff in seconds)

        Raises
        ------
        KeyError
            if neither referenced nor raw data exist for the electrode
        '''
        status = self._status
        file_dir = self._file_dir
        electrode = self._electrode
        params = self.params
        fs = params['sampling_rate']

        # Check if this even needs to be run
        if all(status.values()):
            return electrode, 1, self.recording_cutoff

        # Grab referenced electrode or raw if ref is not available
        ref_el = h5io.get_referenced_trace(file_dir, electrode)
        if ref_el is None:
            print('Could not find referenced data for electrode %i. Using raw.' % electrode)
            self._referenced = False
            ref_el = h5io.get_raw_trace(file_dir, electrode)
            if ref_el is None:
                raise KeyError('Neither referenced nor raw data found for electrode %i in %s' % (electrode, file_dir))

        # Filter electrode trace
        filt_el = clustering.get_filtered_electrode(ref_el, freq=params['bandpass'],
                                               sampling_rate = fs)
        del ref_el
        # Get recording cutoff
        if not status['recording_cutoff']:
            self.recording_cutoff = get_recording_cutoff(filt_el, **params)
            with open(self._files['recording_cutoff'], 'w') as f:
                f.write(str(self.recording_cutoff))

            status['recording_cutoff'] = True
            fn = os.path.join(self._plot_dir, 'cutoff_time.png')
            dplt.plot_recording_cutoff(filt_el, fs, self.recording_cutoff,
                                       out_file=fn)

        # Truncate electrode trace, deal with early cutoff (<60s)
        if self.recording_cutoff < 60:
            print('Immediate Cutoff for electrode %i...exiting' % electrode)
            return electrode, 0, self.recording_cutoff

        filt_el = filt_el[:int(self.recording_cutoff*fs)]

        if status['spike_waveforms'] and status['spike_times']:
            waves = np.load(self._files['spike_waveforms'])
            times = np.load(self._files['spike_times'])
        else:
            # Detect spikes and get dejittered times and waveforms
            # detect_spikes returns waveforms upsampled by 10x and times in units
            # of samples
            waves, times, threshold = detect_spikes(filt_el, params['spike_snapshot'], fs)
            self.detection_threshold = threshold
            if waves is None:
                print('No waveforms detected on electrode %i' % electrode)
                return electrode, 0, self.recording_cutoff

            # Save waveforms and times
            np.save(self._files['spike_waveforms'], waves)
            np.save(self._files['spike_times'], times)
            with open(self._files['detection_threshold'], 'w') as f:
                f.write(str(threshold))

            status['detection_threshold'] = True
            status['spike_waveforms'] = True
            status['spike_times'] = True

        # Get various metrics and scale waveforms
        if not status['spike_amplitudes']:
            amplitudes = get_waveform_amplitudes(waves)
            np.save(self._files['spike_amplitudes'], amplitudes)
            status['spike_amplitudes'] = True

        if not status['slopes']:
            slopes = get_spike_slopes(waves)
            np.save(self._files['slopes'], slopes)
            status['slopes'] = True

        if not status['energy']:
            energy = get_waveform_energy(waves)
            np.save(self._files['energy'], energy)
            status['energy'] = True
        else:
            # scale_waveforms recomputes energy when given None
            energy=None

        # get pca of scaled waveforms
        if not status['pca_waveforms']:
            scaled_waves = scale_waveforms(waves, energy=energy)
            pca_waves, explained_variance_ratio = implement_pca(scaled_waves)

            # Persist the PCA (before plotting) so get_pca_waveforms() and
            # get_clustering_metrics() can use it and this step counts as done
            np.save(self._files['pca_waveforms'], pca_waves)
            status['pca_waveforms'] = True

            # Plot explained variance
            fn = os.path.join(self._plot_dir, 'pca_variance.png')
            dplt.plot_explained_pca_variance(explained_variance_ratio,
                                             out_file = fn)

        return electrode, 1, self.recording_cutoff

    def get_spike_waveforms(self):
        '''Returns spike waveforms if they have been extracted, None otherwise
        Dejittered waveforms upsampled to 10 x sampling_rate

        Returns
        -------
        numpy.array
        '''
        return self._load_array('spike_waveforms')

    def get_spike_times(self):
        '''Returns spike times if they have been extracted, None otherwise
        In units of samples.

        Returns
        -------
        numpy.array
        '''
        return self._load_array('spike_times')

    def get_energy(self):
        '''Returns spike energies if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        return self._load_array('energy')

    def get_spike_amplitudes(self):
        '''Returns spike amplitudes if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        return self._load_array('spike_amplitudes')

    def get_spike_slopes(self):
        '''Returns spike slopes if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        return self._load_array('slopes')

    def get_pca_waveforms(self):
        '''Returns PCA of scaled spike waveforms if they have been computed,
        None otherwise
        Dejittered waveforms scaled to energy and transformed via PCA

        Returns
        -------
        numpy.array
        '''
        return self._load_array('pca_waveforms')

    def get_clustering_metrics(self, n_pc=3):
        '''Returns array of metrics to use for feature based clustering
        Row for each waveform with columns:
            - amplitude, energy, spike slope, then the first n_pc PCA columns

        Raises
        ------
        ValueError
            if any of the required metric files has not been created yet
        '''
        metrics = {'spike_amplitudes': self.get_spike_amplitudes(),
                   'energy': self.get_energy(),
                   'slopes': self.get_spike_slopes(),
                   'pca_waveforms': self.get_pca_waveforms()}
        missing = [k for k, v in metrics.items() if v is None]
        if missing:
            raise ValueError('Missing spike detection data (%s); run '
                             'SpikeDetection.run() first' % ', '.join(missing))

        out = np.vstack((metrics['spike_amplitudes'], metrics['energy'],
                         metrics['slopes'])).T
        out = np.hstack((out, metrics['pca_waveforms'][:, :n_pc]))
        return out

    def __str__(self):
        out = []
        out.append('SpikeDetection\n--------------')
        out.append('Recording Directory: %s' % self._file_dir)
        out.append('Electrode: %i' % self._electrode)
        out.append('Output Directory: %s' % self._out_dir)
        out.append('###################################\n')
        out.append('Status:')
        out.append(pt.print_dict(self._status))
        out.append('-------------------\n')
        out.append('Parameters:')
        out.append(pt.print_dict(self.params))
        out.append('-------------------\n')
        out.append('Data files:')
        out.append(pt.print_dict(self._files))
        return '\n'.join(out)


class BlechClust(object):
    def __init__(self, rec_dirs, electrode, out_dir=None, params=None,
                 overwrite=False, no_write=False, n_pc=3,
                 data_transform=compute_waveform_metrics):
        '''Set up GMM clustering of one electrode across one or more recordings.

        Recording directories should be ordered to make spike sorting easier
        later on.

        Parameters
        ----------
        rec_dirs : str or list of str
            recording directories (one trailing separator is stripped)
        electrode : int
            electrode number
        out_dir : str (optional)
            output directory. It defaults to
            ``<parent of rec_dirs[0]>/BlechClust/electrode_<n>`` for several
            recordings, or ``<rec_dir>/BlechClust/electrode_<n>`` for one.
        params : dict (optional)
            clustering parameters; required when no saved params file exists
        overwrite : bool
            delete the existing output directory before starting
        no_write : bool
            do not create a new rec_key file
        n_pc : int
            default number of components passed to ``data_transform``
        data_transform : callable
            called as ``data_transform(waveforms, n_pc)`` and expected to
            return ``(data, data_columns)``

        Raises
        ------
        ValueError
            if params are missing, the rec_key is missing while no_write is
            set, or spike detection has not been run on every recording
        '''
        if isinstance(rec_dirs, str):
            rec_dirs = [rec_dirs]

        rec_dirs = [x[:-1] if x.endswith(os.sep) else x for x in rec_dirs]
        self.rec_dirs = rec_dirs
        self.electrode = electrode
        self._data_transform = data_transform
        self._n_pc = n_pc
        if out_dir is None:
            if len(rec_dirs) > 1:
                top = os.path.dirname(rec_dirs[0])
                out_dir = os.path.join(top, 'BlechClust', 'electrode_%i' % electrode)
            else:
                out_dir = os.path.join(rec_dirs[0], 'BlechClust', 'electrode_%i' % electrode)

        # rmtree raises FileNotFoundError on a missing path, so only remove
        # output that actually exists
        if overwrite and os.path.isdir(out_dir):
            shutil.rmtree(out_dir)

        # Make directories
        self.out_dir = out_dir
        self._plot_dir = os.path.join(out_dir, 'plots')
        self._data_dir = os.path.join(out_dir, 'clustering_results')
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

        if not os.path.isdir(self._data_dir):
            os.mkdir(self._data_dir)

        if not os.path.isdir(self._plot_dir):
            os.mkdir(self._plot_dir)

        # Check files
        params_file = os.path.join(out_dir, 'BlechClust_params.json')
        map_file = os.path.join(self._data_dir, 'spike_id.npy')
        key_file = os.path.join(self._data_dir, 'rec_key.json')
        results_file = os.path.join(self._data_dir, 'clustering_results.json')
        self._files = {'params': params_file, 'spike_map': map_file,
                       'rec_key': key_file, 'clustering_results': results_file}
        self.params = params
        self._load_existsing_data()

        if self._rec_key is None and not no_write:
            # Create new rec key
            rec_key = {x:y for x,y in enumerate(self.rec_dirs)}
            self._rec_key = rec_key
            wt.write_dict_to_json(rec_key, self._files['rec_key'])
        elif self._rec_key is None:
            raise ValueError('Existing rec_key not found and no_write is enabled')

        # Check to see if spike detection is already completed on all recording directories
        spike_check = self._check_spike_detection()
        if not all(spike_check):
            # spike_check follows self.rec_dirs, which may have been re-sorted
            # by the rec_key; pair with it rather than the local rec_dirs
            invalid = [rd for rd, ok in zip(self.rec_dirs, spike_check) if not ok]
            error_str = '\n\t'.join(invalid)
            raise ValueError('Spike detection has not been run on:\n\t%s' % error_str)

    def _load_existsing_data(self):
        '''Load or create the clustering params, rec_key and prior results.

        Sets ``self.params``, ``self._rec_key``, ``self.results`` and
        ``self.clustered``. (The misspelled method name is kept because
        __init__ calls it by this name.)

        Raises
        ------
        ValueError
            if no params file exists and no params were passed, or if the
            saved rec_key cannot be matched to ``self.rec_dirs``
        '''
        params = self.params
        file_check = self._check_existing_files()

        # Check params files and create if new params are passed
        if file_check['params']:
            self.params = wt.read_dict_from_json(self._files['params'])
        elif params is None:
            raise ValueError(('Params file does not exists at %s. Must provide'
                              ' clustering parameters.') % self._files['params'])
        else:
            # Copy so that adding the flattened keys below does not mutate the
            # dict the caller passed in
            self.params = deepcopy(params)

        # Make new params or overwrite existing with passed params
        if params is not None:
            clust_params = params['clustering_params']
            self.params['max_clusters'] = clust_params['Max Number of Clusters']
            self.params['max_iterations'] = clust_params['Max Number of Iterations']
            self.params['threshold'] = clust_params['Convergence Criterion']
            self.params['num_restarts'] = clust_params['GMM random restarts']
            self.params['wf_amplitude_sd_cutoff'] = params['data_params']['Intra-cluster waveform amp SD cutoff']
            wt.write_dict_to_json(self.params, self._files['params'])

        # Deal with existing rec key
        if file_check['rec_key']:
            rec_key = wt.read_dict_from_json(self._files['rec_key'])
            rec_key = {int(x): y for x, y in rec_key.items()}
            # Correct rec key in case rec_dir roots have changed
            rec_key = self._reconcile_rec_key(rec_key, self.rec_dirs)
            inverted = {v: k for k, v in rec_key.items()}
            self.rec_dirs = sorted(self.rec_dirs, key=lambda i: inverted[i])
            self._rec_key = rec_key
        else:
            self._rec_key = None

        # Check is clustering has already been done, load results
        if file_check['clustering_results']:
            self.results = wt.read_pandas_from_table(self._files['clustering_results'])
            self.clustered = self._results_complete(self.results,
                                                    self.params['max_clusters'])
        else:
            self.results = None
            self.clustered = False

    @staticmethod
    def _reconcile_rec_key(rec_key, rec_dirs):
        '''Return a copy of ``rec_key`` with its paths replaced by ``rec_dirs``.

        Each rec dir is matched to one entry not yet claimed by another rec
        dir. Matching is tried in this order:
          1. an identical path
          2. an identical final folder name
          3. the folder name appearing anywhere in the stored path (the old
             behaviour, kept as a fallback)
        Claiming entries stops two rec dirs from overwriting the same one,
        e.g. 'rec1' substring-matching '.../rec10'.

        Raises
        ------
        ValueError
            if the number of entries differs or a rec dir cannot be matched
        '''
        if len(rec_key) != len(rec_dirs):
            raise ValueError('Rec key does not match rec dirs')

        updated = dict(rec_key)
        claimed = set()
        for rd in rec_dirs:
            rn = os.path.basename(rd)
            tests = (lambda path: path == rd,
                     lambda path: os.path.basename(os.path.normpath(path)) == rn,
                     lambda path: rn in path)
            match = None
            for test in tests:
                hits = [k for k, path in rec_key.items()
                        if k not in claimed and test(path)]
                if hits:
                    match = hits[0]
                    break

            if match is None:
                raise ValueError('%s not found in rec_key' % rn)

            claimed.add(match)
            updated[match] = rd

        return updated

    @staticmethod
    def _results_complete(results, max_clusters):
        '''Return True if ``results`` has a row for every cluster count from 2
        to ``max_clusters``.

        Membership is tested against the values of the 'clusters' column;
        ``x in results['clusters']`` would test the Series index instead.
        '''
        found = set(np.asarray(results['clusters']).tolist())
        return all(n in found for n in range(2, int(max_clusters) + 1))

    def _check_existing_files(self):
        '''Report which of the files in ``self._files`` exist on disk.

        Returns
        -------
        dict
            same keys as ``self._files``; True where the file exists
        '''
        return {key: os.path.isfile(path) for key, path in self._files.items()}

    def _check_spike_detection(self):
        '''Check to see if spike detection has been run on all recording directories

        Returns
        -------
        list of bool
            one entry per ``self.rec_dirs``. An entry is False when
            SpikeDetection(rec, electrode) raises FileNotFoundError, which
            happens when that recording has no saved spike detection params.

        NOTE(review): ``all(detector._status)`` iterates the dict's keys, which
        are non-empty strings, so it is always True. The check therefore only
        confirms that SpikeDetection can be constructed. Iterating
        ``.values()`` would be stricter, but first confirm that every status
        file (including 'pca_waveforms') is written by SpikeDetection.run.
        '''
        results = []
        for rec in self.rec_dirs:
            try:
                detector = SpikeDetection(rec, self.electrode)
            except FileNotFoundError:
                results.append(False)
                continue

            results.append(all(detector._status))

        return results

    def run(self, n_pc=None, overwrite=False):
        if self.clustered and not overwrite:
            return True

        if n_pc is None:
            n_pc = self._n_pc

        GMM = ClusterGMM(self.params['max_iterations'],
                         self.params['num_restarts'], self.params['threshold'])

        # Collect data from all recordings
        waveforms, spike_times, spike_map, fs, offsets = self.get_spike_data()

        # Save array to map spikes and predictions back to original recordings
        np.save(self._files['spike_map'], spike_map)

        data, data_columns = self._data_transform(waveforms, n_pc)
        amplitudes = get_waveform_amplitudes(waveforms)

        # Run GMM for each number of clusters from 2 to max_clusters
        tested_clusters = np.arange(2, self.params['max_clusters']+1)
        clust_results = pd.DataFrame(columns=['clusters','converged',
                                              'BIC','spikes_per_cluster'],
                                     index=tested_clusters)
        for n_clust in tested_clusters:
            data_dir = os.path.join(self._data_dir, '%i_clusters' % n_clust)
            plot_dir = os.path.join(self._plot_dir, '%i_clusters' % n_clust)
            wave_plot_dir = os.path.join(self._plot_dir, '%i_clusters_waveforms_ISIs' % n_clust)
            bic_file = os.path.join(data_dir, 'bic.npy')
            pred_file = os.path.join(data_dir, 'predictions.npy')

            if os.path.isfile(bic_file) and os.path.isfile(pred_file) and not overwrite:
                bic = np.load(bic_file)
                predictions = np.load(pred_file)
                spikes_per_clust = [len(np.where(predictions == c)[0])
                                    for c in np.unique(predictions)]
                clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]
                continue

            if not os.path.isdir(wave_plot_dir):
                os.makedirs(wave_plot_dir)

            if not os.path.isdir(data_dir):
                os.makedirs(data_dir)

            if not os.path.isdir(plot_dir):
                os.makedirs(plot_dir)

            model, predictions, bic = GMM.fit(data, n_clust)
            if model is None:
                clust_results.loc[n_clust] = [n_clust, bic, False, [0]]
                # Nothing converged
                continue

            # Go through each cluster and throw out any spikes too far from the
            # mean
            spikes_per_clust = []
            for c in range(n_clust):
                idx = np.where(predictions == c)[0]
                mean_amp = np.mean(amplitudes[idx])
                sd_amp = np.std(amplitudes[idx])
                cutoff_amp = mean_amp - (sd_amp * self.params['wf_amplitude_sd_cutoff'])
                rejected_idx = np.array([i for i in idx if amplitudes[i] <= cutoff_amp])
                if len(rejected_idx) > 0:
                    predictions[rejected_idx] = -1

                idx = np.where(predictions == c)[0]
                spikes_per_clust.append(len(idx))

                if len(idx) == 0:
                    continue

                # Plot waveforms and ISIs of cluster
                ISIs, violations_1ms, violations_2ms = get_ISI_and_violations(spike_times[idx], fs, spike_map[idx])
                cluster_waves = waveforms[idx]
                cluster_times = spike_times[idx]
                isi_fn = os.path.join(wave_plot_dir, 'Cluster%i_ISI.png' % c)
                wave_fn = os.path.join(wave_plot_dir, 'Cluster%i_waveforms.png' % c)
                title_str = ('Cluster%i\nviolations_1ms = %i, '
                             'violations_2ms = %i\n'
                             'Number of waveforms = %i' %
                             (c, violations_1ms, violations_2ms, len(idx)))
                dplt.plot_waveforms(cluster_waves, title=title_str, save_file=wave_fn)
                if len(ISIs) > 0:
                    dplt.plot_ISIs(ISIs, total_spikes=len(idx), save_file=isi_fn)


            clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]

            # Plot feature pairs
            feature_pairs = it.combinations(list(range(data.shape[1])), 2)
            for f1, f2 in feature_pairs:
                fn = '%sVS%s.png' % (data_columns[f1], data_columns[f2])
                fn = os.path.join(plot_dir, fn)
                dplt.plot_cluster_features(data[:, [f1,f2]], predictions,
                                           x_label = data_columns[f1],
                                           y_label = data_columns[f2],
                                           save_file = fn)

            # For each cluster plot mahanalobis distances to all other clusters
            for c in range(n_clust):
                distances = get_mahalanobis_distances_to_cluster(data,  model,
                                                                 predictions, c)
                fn = os.path.join(plot_dir, 'Mahalanobis_cluster%i.png' % c)
                title = ('Mahalanobis distance of Cluster %i from all other clusters' % c)
                dplt.plot_mahalanobis_to_cluster(distances, title=title, save_file=fn)

            # Save data
            np.save(bic_file, bic)
            np.save(pred_file, predictions)

        # Save results table
        self.results = clust_results
        wt.write_pandas_to_table(clust_results,
                                 self._files['clustering_results'],
                                 overwrite=True)
        self.clustered = True
        return True

    def get_spike_data(self):
        '''Collect spike waveforms and times from every recording on this
        electrode into single concatenated arrays.

        Recordings are visited in sorted order of their ``self._rec_key``
        index. Each recording is assigned a sample offset so that all
        recordings can be laid end to end. A gap of 3 seconds (``3*fs``
        samples) is inserted after each recording.

        Returns
        -------
        waveforms : np.ndarray
            waveforms from all recordings stacked row-wise
        spike_times : np.ndarray
            spike times in samples, relative to the start of their recording
        spike_map : np.ndarray
            rec_key index of the recording each spike came from
        fs : dict
            sampling rate for each rec_key index
        offsets : dict
            sample offset assigned to each rec_key index

        Raises
        ------
        ValueError
            if no recording has any detected spikes, or if the number of
            spikes no longer matches the saved spike_map file
        '''
        tmp_waves = []
        tmp_times = []
        tmp_id = []
        fs = dict.fromkeys(self._rec_key.keys())
        offsets = dict.fromkeys(self._rec_key.keys())
        offset = 0
        for i in sorted(self._rec_key.keys()):
            rec = self._rec_key[i]
            spike_detect = SpikeDetection(rec, self.electrode)
            t = spike_detect.get_spike_times()
            fs[i] = spike_detect.params['sampling_rate']
            offsets[i] = int(offset)
            # A recording without spikes still reserves the 3 s gap. An empty
            # time array is treated like None because max() of it would fail.
            if t is None or len(t) == 0:
                offset = offset + 3*fs[i]
                continue

            tmp_waves.append(spike_detect.get_spike_waveforms())
            tmp_times.append(t)
            tmp_id.append(np.ones((t.shape[0],))*i)
            offset = offset + np.max(t) + 3*fs[i]

        if len(tmp_waves) == 0:
            raise ValueError('No spikes detected on electrode %s in any '
                             'recording' % self.electrode)

        waveforms = np.vstack(tmp_waves)
        spike_times = np.hstack(tmp_times)
        spike_map = np.hstack(tmp_id)

        # Double check that spike_map matches up with existing spike_map
        if os.path.isfile(self._files['spike_map']):
            orig_map = np.load(self._files['spike_map'])
            if len(orig_map) != len(spike_map):
                raise ValueError('Spike detection has changed, please re-cluster with overwrite=True')

        return waveforms, spike_times, spike_map, fs, offsets

    def get_clusters(self, solution_num, cluster_nums):
        '''Build SpikeCluster objects for clusters of a saved solution.

        Parameters
        ----------
        solution_num : int
            number of clusters in the solution whose predictions are loaded
        cluster_nums : int or list of int
            cluster labels to extract; labels that have no spikes are skipped

        Returns
        -------
        list of SpikeCluster
        '''
        if not isinstance(cluster_nums, list):
            cluster_nums = [cluster_nums]

        waveforms, times, spike_map, fs, offsets = self.get_spike_data()
        predictions = self.get_predictions(solution_num)

        found = []
        for label in cluster_nums:
            members = np.where(predictions == label)[0]
            if len(members) == 0:
                continue

            found.append(SpikeCluster('Cluster_%i' % label,
                                      self.electrode,
                                      solution_num,
                                      label,
                                      1,
                                      waveforms[members],
                                      times[members],
                                      spike_map[members],
                                      self._rec_key.copy(),
                                      fs.copy(),
                                      offsets.copy(),
                                      manipulations=''))

        return found

    def get_predictions(self, n_clusters):
        '''Load the saved cluster labels for the ``n_clusters`` solution.

        Reads ``<_data_dir>/<n_clusters>_clusters/predictions.npy``. Returns
        None if that file does not exist.
        '''
        pred_file = os.path.join(self._data_dir, '%i_clusters' % n_clusters,
                                 'predictions.npy')
        if not os.path.isfile(pred_file):
            return None

        return np.load(pred_file)


class ClusterGMM(object):
    '''Gaussian mixture clustering with random restarts.

    Each restart fits a full-covariance ``GaussianMixture``. Among the
    converged restarts, the fit with the lowest BIC is kept.

    Parameters
    ----------
    n_iters : int
        maximum EM iterations per restart (``max_iter``)
    n_restarts : int
        number of restarts; restart ``i`` uses ``random_state=i``
    thresh : float
        convergence tolerance (``tol``)
    '''
    def __init__(self, n_iters, n_restarts, thresh):
        self.params = {'iterations': n_iters,
                       'restarts': n_restarts,
                       'thresh': thresh}

    def fit(self, data, n_clusters):
        '''Fit the mixture model and predict a cluster label for each row.

        Parameters
        ----------
        data : np.ndarray
            samples x features array to cluster
        n_clusters : int or None
            number of mixture components; if None, the value stored in
            ``self.params['clusters']`` from a previous call is reused

        Returns
        -------
        model, predictions, bic
            the best converged GaussianMixture, its predicted labels for
            ``data`` and its BIC. All three are None if no restart converged.
        '''
        min_bic = None
        best_model = None
        if n_clusters is not None:
            self.params['clusters'] = n_clusters

        for i in range(self.params['restarts']):
            model = GaussianMixture(n_components = self.params['clusters'],
                                    covariance_type = 'full',
                                    tol = self.params['thresh'],
                                    random_state = i,
                                    max_iter = self.params['iterations'])
            model.fit(data)
            if not model.converged_:
                continue

            new_bic = model.bic(data)
            if min_bic is None or new_bic < min_bic:
                best_model = model
                min_bic = new_bic

        # If nothing converged, report failure with None (callers check
        # `model is None`) rather than calling predict on None.
        if best_model is None:
            predictions = None
        else:
            predictions = best_model.predict(data)

        self._model = best_model
        self._predictions = predictions
        self._bic = min_bic
        return best_model, predictions, min_bic


class SpikeSorter(object):
    '''Interactive editor for the clusters produced by BlechClust.

    Holds a list of "active" clusters taken from one clustering solution.
    Provides operations to split, merge, discard and plot them, and to save
    them as units to the recordings' h5 files. The most recent save, split,
    merge or discard can be reverted once with :meth:`undo`.

    Parameters
    ----------
    rec_dirs : str or list of str
        recording directories that were clustered together
    electrode : int
        electrode number
    clustering_dir : str, optional
        directory holding the BlechClust results. Defaults to
        ``<rec_dir>/BlechClust/electrode_<electrode>`` for one recording, or
        to ``<parent of first rec_dir>/BlechClust/electrode_<electrode>``
        for several recordings.
    shell : bool
        passed as ``shell`` to userIO prompts

    Raises
    ------
    ValueError
        if the recordings have not been clustered yet
    '''
    def __init__(self, rec_dirs, electrode, clustering_dir=None, shell=False):
        if isinstance(rec_dirs, str):
            rec_dirs = [rec_dirs]

        # Drop one trailing separator so os.path.dirname returns the parent
        rec_dirs = [x[:-1] if x.endswith(os.sep) else x for x in rec_dirs]
        self.rec_dirs = rec_dirs
        self.electrode = electrode
        if clustering_dir is None:
            if len(rec_dirs) > 1:
                top = os.path.dirname(rec_dirs[0])
                clustering_dir = os.path.join(top, 'BlechClust', 'electrode_%i' % electrode)
            else:
                clustering_dir = os.path.join(rec_dirs[0], 'BlechClust', 'electrode_%i' % electrode)

        self.clustering_dir = clustering_dir
        try:
            clust = BlechClust(rec_dirs, electrode, out_dir = clustering_dir, no_write=True)
        except FileNotFoundError:
            clust = None

        if clust is None or not clust.clustered:
            raise ValueError('Recordings have not been clustered yet.')

        # Match recording directory ordering to clustering
        self.rec_dirs = clust.rec_dirs
        self.clustering = clust
        self._current_solution = None
        self._active = None  # list of clusters currently being edited
        self._last_saved = None  # rec index -> unit names written by last save
        self._previous = None
        self._shell = shell
        self._split_results = None  # sub-clusters kept by split_cluster(store_split=True)
        self._split_starter = None  # cluster that was split, pending set_split
        self._split_index = None  # its former index in _active
        self._last_umap_embedding = None
        self._last_action = None  # 'save', 'split', 'merge', 'discard' or None
        self._last_popped = None  # Dict of indices to clusters
        self._last_added = None  # List of indices

        # Read thresholds in the clustering's recording order (self.rec_dirs,
        # not the caller's order), so index 0 refers to the same recording
        # as the clustering does.
        thresh = []
        for rd in self.rec_dirs:
            sd = SpikeDetection(rd, electrode)
            thresh.append(sd.detection_threshold)

        self._detection_thresholds = thresh

    def undo(self):
        '''Revert the most recent save, split, merge or discard.

        Only one level of undo is kept. Does nothing if there is nothing to
        undo.
        '''
        if self._last_action is None:
            return

        if self._last_action == 'save':
            self.undo_last_save()
            return

        # Remove the clusters added by the last action. Go from the highest
        # index down so the remaining indices stay valid.
        for k in sorted(self._last_added, reverse=True):
            self._active.pop(k)

        # Re-insert the removed clusters at their original indices, lowest first
        for k in sorted(self._last_popped.keys()):
            self._active.insert(k, self._last_popped[k])

        # reset
        self._last_action = None
        self._last_popped = None
        self._last_added = None

    def set_active_clusters(self, solution_num):
        '''Load every cluster of the given solution as the active clusters.

        Raises
        ------
        ValueError
            if the solution has no clusters
        '''
        self._current_solution = solution_num
        cluster_nums = list(range(solution_num))
        clusters = self.clustering.get_clusters(solution_num, cluster_nums)
        if len(clusters) == 0:
            raise ValueError('Solution or clusters not found')

        self._active = clusters
        self._last_action = None
        self._last_popped = None
        self._last_added = None

    def save_clusters(self, target_clusters, single_unit, pyramidal, interneuron):
        '''Saves active clusters as cells, write them to the h5_files in the
        appropriate recording directories

        Each target cluster is split by recording and written as a new unit
        to every recording that contributed spikes. A ``cluster.info`` log is
        appended in ``<rec>/sorted_unit_metrics/<unit_name>``. The saved
        clusters are removed from the active list.

        Parameters
        ----------
        target_clusters: list of int
            indices of active clusters to save
        single_unit : list of bool
            one entry per element of target_clusters
        pyramidal : list of bool
            one entry per element of target_clusters
        interneuron : list of bool
            one entry per element of target_clusters

        Raises
        ------
        ValueError
            if a target index is out of range or list lengths do not match
        '''
        if self._active is None:
            return

        if any(i >= len(self._active) for i in target_clusters):
            raise ValueError('Target cluster is out of range.')

        n_clusters = len(target_clusters)
        if (len(single_unit) != n_clusters or len(pyramidal) != n_clusters or
            len(interneuron) != n_clusters):
            raise ValueError('Length of input lists must match number of '
                             'active clusters. Expected %i' % n_clusters)

        self._last_action = 'save'
        self._last_popped = {i: self._active[i] for i in target_clusters}
        self._last_added = []
        clusters = [self._active[i] for i in target_clusters]
        rec_key = self.clustering._rec_key
        self._last_saved = dict.fromkeys(rec_key.keys(), None)

        for clust, single, pyr, intr in zip(clusters, single_unit,
                                            pyramidal, interneuron):
            # Printable metadata for cluster.info: no arrays, rec_key or fs.
            # This is the same for every recording, so build it once.
            print_clust = {k: v for k, v in clust.items()
                           if not isinstance(v, np.ndarray)
                           and k not in ('rec_key', 'fs')}

            for i, rec in rec_key.items():
                idx = np.where(clust['spike_map'] == i)[0]
                if len(idx) == 0:
                    continue

                waves = clust['spike_waveforms'][idx]
                times = clust['spike_times'][idx]
                unit_name = h5io.add_new_unit(rec, self.electrode, waves,
                                              times, single, pyr, intr)
                if self._last_saved[i] is None:
                    self._last_saved[i] = [unit_name]
                else:
                    self._last_saved[i].append(unit_name)

                metrics_dir = os.path.join(rec,'sorted_unit_metrics', unit_name)
                if not os.path.isdir(metrics_dir):
                    os.makedirs(metrics_dir)

                # Write cluster info to file
                clust_info_file = os.path.join(metrics_dir, 'cluster.info')
                with open(clust_info_file, 'a+') as log:
                    print('%s sorted on %s'
                          % (unit_name,
                             dt.datetime.today().strftime('%m/%d/%y %H:%M')),
                          file=log)
                    print('Cluster info: \n----------', file=log)
                    print(pt.print_dict(print_clust), file=log)
                    print('Saved metrics to %s' % metrics_dir, file=log)
                    print('--------------\n', file=log)

        userIO.tell_user('Target clusters successfully saved to recording '
                         'directories.', shell=True)
        self._active = [self._active[i] for i in range(len(self._active))
                        if i not in target_clusters]

    def undo_last_save(self):
        '''Delete the units written by the last save_clusters call and put
        the saved clusters back into the active list.
        '''
        if self._last_saved is None:
            return

        rec_key = self.clustering._rec_key
        last_saved = self._last_saved
        for i, rec in rec_key.items():
            units = last_saved.get(i)
            # Recordings with no spikes in the saved clusters got no units
            if not units:
                continue

            for unit in sorted(units, reverse=True):
                h5io.delete_unit(rec, unit)

        for k in sorted(self._last_popped.keys()):
            self._active.insert(k, self._last_popped[k])

        self._last_saved = None
        self._last_popped = None
        self._last_added = None
        self._last_action = None

    def split_cluster(self, target_clust, n_iter, n_restart, thresh, n_clust,
                      store_split=False, umap=False):
        '''Split the target active cluster into sub-clusters using a GMM.

        Parameters
        ----------
        target_clust : int
            index of the active cluster to split
        n_iter, n_restart, thresh :
            maximum iterations, number of restarts and tolerance for ClusterGMM
        n_clust : int
            number of sub-clusters to fit
        store_split : bool
            if True, store and return the sub-clusters so the choice can be
            made later with set_split; otherwise prompt the user now
        umap : bool
            passed to compute_waveform_metrics to use UMAP features

        Returns
        -------
        list of SpikeCluster if store_split, else True

        Raises
        ------
        ValueError
            if target_clust is out of range or the GMM fit returns no model.
            On any error, the cluster is restored to the active list.
        '''
        if target_clust >= len(self._active):
            raise ValueError('Invalid target. Only %i active clusters' % len(self._active))

        cluster = self._active.pop(target_clust)
        self._split_starter = cluster
        self._split_index = target_clust

        try:
            GMM = ClusterGMM(n_iter, n_restart, thresh)
            waves = cluster['spike_waveforms']
            data, _ = compute_waveform_metrics(waves, umap=umap)
            model, predictions, _ = GMM.fit(data, n_clust)
            if model is None:
                raise ValueError('GMM did not converge. Try more iterations '
                                 'or restarts.')

            new_clusts = []
            for i in np.unique(predictions):
                idx = np.where(predictions == i)[0]
                edit_str = (cluster['manipulations'] + '\nSplit %s into %i '
                            'clusters. This is sub-cluster %i'
                            % (cluster['Cluster_Name'], n_clust, i))
                tmp_clust = SpikeCluster(cluster['Cluster_Name'] + '-%i' % i,
                                         cluster['electrode_num'],
                                         cluster['solution_num'],
                                         cluster['cluster_num'],
                                         cluster['cluster_id']*10+i,
                                         waves[idx],
                                         cluster['spike_times'][idx],
                                         cluster['spike_map'][idx],
                                         cluster['rec_key'].copy(),
                                         cluster['fs'].copy(),
                                         cluster['offsets'].copy(),
                                         manipulations=edit_str)
                new_clusts.append(tmp_clust)

            # Plot cluster and ask to choose which to keep
            figs = []
            for i, c in enumerate(new_clusts):
                _, viol_1ms, viol_2ms = get_ISI_and_violations(c['spike_times'], c['fs'], c['spike_map'])
                plot_title = ('Index: %i\n1ms violations: %i, 2ms violations: %i\n'
                              'Total Waveforms: %i'
                              % (i, viol_1ms, viol_2ms, len(c['spike_times'])))
                tmp_fig, _ = dplt.plot_waveforms(c['spike_waveforms'], title=plot_title)
                figs.append(tmp_fig)
                tmp_fig.show()

            f2 = dplt.plot_waveforms_pca([c['spike_waveforms'] for c in new_clusts])
            figs.append(f2)
            f2.show()
        except BaseException:
            # So cluster isn't lost with error (including KeyboardInterrupt)
            self._active.insert(target_clust, cluster)
            self._split_starter = None
            self._split_index = None
            raise

        if store_split:
            self._split_results = new_clusts
            return new_clusts
        else:
            self._split_starter = None
            self._split_index = None
            selection_list = ['all'] + ['%i' % i for i in range(len(new_clusts))]
            prompt = 'Select split clusters to keep\nCancel to reset.'
            ans = userIO.select_from_list(prompt, selection_list,
                                          multi_select=True, shell=self._shell)
            # NOTE(review): choosing 'all' restores the unsplit cluster rather
            # than keeping every sub-cluster -- confirm this is intended.
            if ans is None or 'all' in ans:
                print('Reset to before split')
                self._active.insert(target_clust, cluster)
            else:
                keepers = [new_clusts[int(i)] for i in ans]
                start_idx = len(self._active)
                self._last_added = list(range(start_idx, start_idx+len(keepers)))
                self._last_popped = {target_clust: cluster}
                self._last_action = 'split'
                self._active.extend(keepers)

            return True

    def set_split(self, choices):
        '''Finish a split stored by split_cluster(store_split=True).

        Parameters
        ----------
        choices : list of int
            indices into the stored sub-clusters to keep. An empty list
            restores the original cluster.

        Raises
        ------
        ValueError
            if no split is stored
        '''
        if self._split_starter is None:
            raise ValueError('No split stored.')

        if len(choices) == 0:
            self._active.insert(self._split_index, self._split_starter)
        else:
            keepers = [self._split_results[i] for i in choices]
            start_idx = len(self._active)
            self._last_added = list(range(start_idx, start_idx+len(keepers)))
            self._last_popped = {self._split_index: self._split_starter}
            self._last_action = 'split'
            self._active.extend(keepers)

        self._split_index = None
        self._split_results = None
        self._split_starter = None

    def merge_clusters(self, target_clusters):
        '''Merge the target active clusters into one new cluster.

        The originals are removed and the merged cluster is appended to the
        end of the active list. The merged cluster's spikes are ordered by
        recording index, then by spike time within each recording. The first
        target supplies the name and other metadata.

        Raises
        ------
        ValueError
            if any target index is out of range
        '''
        if any(i >= len(self._active) for i in target_clusters):
            raise ValueError('Target cluster is out of range.')

        # Nothing to merge; avoid appending an empty placeholder as a cluster
        if len(target_clusters) == 0:
            return

        new_clust = None
        self._last_popped = {}
        self._last_action = 'merge'
        self._last_added = []
        for c in target_clusters:
            clust = self._active[c]
            self._last_popped[c] = clust

            if new_clust is None:
                new_clust = deepcopy(clust)
                continue

            spike_map = np.hstack((new_clust['spike_map'], clust['spike_map']))
            spike_times = np.hstack((new_clust['spike_times'], clust['spike_times']))
            spike_waveforms = np.vstack((new_clust['spike_waveforms'],
                                         clust['spike_waveforms']))

            # Sort by recording (primary key), then by time within a recording
            order = np.lexsort((spike_times, spike_map))

            new_clust['spike_map'] = spike_map[order]
            new_clust['spike_times'] = spike_times[order]
            new_clust['spike_waveforms'] = spike_waveforms[order]
            new_clust['manipulations'] += '\nMerged with %s.' % clust['Cluster_Name']
            new_clust['Cluster_Name'] += '+' + clust['Cluster_Name'].replace('Cluster_','')

        self._active = [self._active[i] for i in range(len(self._active))
                        if i not in target_clusters]

        self._last_added = [len(self._active)]
        self._active.append(new_clust)

    def discard_clusters(self, target_clusters):
        '''Remove the target clusters from the active list (undoable).'''
        if isinstance(target_clusters, int):
            target_clusters = [target_clusters]

        if len(target_clusters) == 0:
            return

        # Look up the clusters before touching undo state, so an invalid
        # index (IndexError) leaves the previous undo information intact
        popped = {i: self._active[i] for i in target_clusters}
        self._last_action = 'discard'
        self._last_popped = popped
        self._last_added = []
        self._active = [self._active[i] for i in range(len(self._active))
                        if i not in target_clusters]

    def plot_clusters_waveforms(self, target_clusters):
        '''Show one waveform figure per target cluster, with ISI violations.'''
        if len(target_clusters) == 0:
            return

        for i in target_clusters:
            c = self._active[i]
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'], c['fs'], c['spike_map'])
            title = ('Index : %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (i, v1, v2, len(c['spike_waveforms'])))
            fig, ax = dplt.plot_waveforms(c['spike_waveforms'], title=title,
                                          threshold=self._detection_thresholds[0])
            fig.show()

    def split_by_rec(self, target_cluster):
        '''Replace an active cluster with one cluster per recording.

        Does nothing if the cluster's spikes all come from one recording.
        The new clusters are appended to the end of the active list.

        Parameters
        ----------
        target_cluster : int or single-element list of int
        '''
        if isinstance(target_cluster, list) and len(target_cluster) != 1:
            return
        elif isinstance(target_cluster, list):
            target_cluster = target_cluster[0]

        sm = self._active[target_cluster]['spike_map']
        recs = np.unique(sm)
        # Zero or one recording: nothing to split
        if len(recs) <= 1:
            return

        clust = self._active.pop(target_cluster)
        st = clust['spike_times']
        sw = clust['spike_waveforms']
        keepers = []
        for i in recs:
            idx = np.where(sm == i)[0]
            new_clust = deepcopy(clust)
            new_clust['spike_times'] = st[idx]
            new_clust['spike_waveforms'] = sw[idx, :]
            # spike_map values may be floats; keep cluster_id an integer
            new_clust['cluster_id'] = clust['cluster_id']*10 + int(i)
            new_clust['spike_map'] = sm[idx]
            # Append to the edit history instead of overwriting it
            new_clust['manipulations'] += '\nSplit by recording'
            keepers.append(new_clust)

        start_idx = len(self._active)
        self._last_added = list(range(start_idx, start_idx+len(keepers)))
        self._last_popped = {target_cluster: clust}
        self._last_action = 'split'
        self._active.extend(keepers)

    def plot_cluster_waveforms_by_rec(self, target_cluster):
        '''Show a waveform figure for each recording in the target cluster.'''
        if isinstance(target_cluster, list) and len(target_cluster) != 1:
            return
        elif isinstance(target_cluster, list):
            target_cluster = target_cluster[0]

        c = self._active[target_cluster]
        sm = c['spike_map']
        for i in np.unique(sm):
            idx = np.where(sm==i)[0]
            waves = c['spike_waveforms'][idx, :]
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'][idx], c['fs'][i])
            title = ('Index : %i, Rec: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (target_cluster, i, v1, v2, len(waves)))
            fig, ax = dplt.plot_waveforms(waves, title=title)
            fig.show()

    def plot_cluster_waveforms_over_time(self, target_cluster, interval):
        '''Show the target cluster's waveforms in consecutive time windows.

        Parameters
        ----------
        target_cluster : int or single-element list of int
        interval : float
            window length in seconds. Windows are taken on the
            offset-adjusted spike time vector, so they can span recordings.
        '''
        if isinstance(target_cluster, list) and len(target_cluster) != 1:
            return
        elif isinstance(target_cluster, list):
            target_cluster = target_cluster[0]

        c = self._active[target_cluster]
        spike_times = c.get_spike_time_vector('s')
        # Empty cluster: nothing to plot, and [0] below would fail
        if len(spike_times) == 0:
            return

        start_times = np.arange(spike_times[0], spike_times[-1]+1, interval)
        if len(start_times) > 10:
            userIO.tell_user('This would open more than 10 figures, choose a larger interval')
            return

        if len(start_times) == 0:
            return

        for i, start_time in enumerate(start_times):
            idx = np.where((spike_times >= start_time) & (spike_times < start_time+interval))[0]
            waves = c['spike_waveforms'][idx,:]
            # A window may contain several recordings, so pass the full fs
            # dict and the spike_map (as the other ISI plots do)
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'][idx], c['fs'],
                                                 c['spike_map'][idx])
            title = ('Index : %i, Interval: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (target_cluster, i, v1, v2, len(waves)))
            fig, ax = dplt.plot_waveforms(waves, title=title,
                                          threshold=self._detection_thresholds[0])
            fig.show()

    def plot_clusters_pca(self, target_clusters):
        '''Show the target clusters' waveforms projected onto PCA space.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig = dplt.plot_waveforms_pca(waves, cluster_ids=target_clusters)
        fig.show()

    def plot_clusters_umap(self, target_clusters):
        '''Show the target clusters' waveforms embedded with UMAP.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig = dplt.plot_waveforms_umap(waves, cluster_ids=target_clusters)
        fig.show()

    def plot_clusters_wavelets(self, target_clusters):
        '''Show the wavelet-transform plot of the target clusters.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig, ax = dplt.plot_waveforms_wavelet_tranform(waves,
                                                       cluster_ids=target_clusters,
                                                       n_pc=4)
        fig.show()

    def plot_clusters_raster(self, target_clusters):
        '''Show a raster of the target clusters.

        Spike times are shown in seconds with recording offsets applied, and
        vertical lines mark each recording's offset.
        '''
        if len(target_clusters) == 0:
            return

        clusters = [self._active[i] for i in target_clusters]
        spike_times = []
        spike_waves = []
        vlines = {}
        for c in clusters:
            # Adjust spike times by offset so recordings are not overlapping
            st = c.get_spike_time_vector(units='s')
            vlines = {i: c['offsets'][i] / c['fs'][i] for i in c['fs'].keys()}
            spike_times.append(st)
            spike_waves.append(c['spike_waveforms'])

        fig, ax = dplt.plot_spike_raster(spike_times, spike_waves, target_clusters)
        ax.set_xlabel('Time (s)')
        for x in vlines.values():
            ax.axvline(x, color='black', linewidth=2)

        fig.show()

    def plot_clusters_ISI(self, target_clusters):
        '''Show an inter-spike-interval histogram for each target cluster.'''
        if len(target_clusters) == 0:
            return

        for i in target_clusters:
            cluster = self._active[i]
            isi, v1, v2 = get_ISI_and_violations(cluster['spike_times'],
                                                 cluster['fs'],
                                                 cluster['spike_map'])
            fig, ax = dplt.plot_ISIs(isi, total_spikes=len(cluster['spike_times']))
            title= ax.get_title()
            title = 'Index: %i\n%s' % (i, title)
            ax.set_title(title)
            fig.show()

    def plot_clusters_acorr(self, target_clusters):
        '''Show an autocorrelogram (spike times in ms) for each target cluster.'''
        if len(target_clusters) == 0 or not all([x < len(self._active) for x in target_clusters]):
            return

        for i in target_clusters:
            cluster = self._active[i]
            acf, bin_centers, edges = sas.spike_time_acorr(cluster.get_spike_time_vector(units='ms'))
            fig, ax = dplt.plot_correlogram(acf, bin_centers, edges)
            title = 'Index: %i\nAutocorrelogram' % (i)
            ax.set_title(title)
            fig.show()

    def plot_clusters_xcorr(self, target_clusters):
        '''Show a cross-correlogram for every pair of target clusters.'''
        if len(target_clusters) == 0 or not all([x < len(self._active) for x in target_clusters]):
            return

        pairs = it.combinations(target_clusters, 2)
        for x, y in pairs:
            clust1 = self._active[x]
            clust2 = self._active[y]
            xcf, bin_centers, edges = sas.spike_time_xcorr(clust1.get_spike_time_vector(units='ms'),
                                                    clust2.get_spike_time_vector(units='ms'))
            fig, ax = dplt.plot_correlogram(xcf, bin_centers, edges)
            title = 'Cross-correlogram\n%i vs %i' % (x, y)
            ax.set_title(title)
            fig.show()

    def get_mean_waveform(self, target_cluster):
        '''Returns mean waveform of target_cluster in active clusters. Also
        returns St. Dev. of waveforms and the number of waveforms.
        '''
        cluster = self._active[target_cluster]
        return cluster.get_mean_waveform()

    def get_possible_solutions(self):
        '''Return the cluster counts whose clustering results row has no
        missing values and is marked converged.
        '''
        results = self.clustering.results.dropna()
        converged = list(results[results['converged']].index)
        return converged


class SpikeCluster(dict):
    '''Dictionary holding the spikes and metadata of one cluster.

    Keys: Cluster_Name, electrode_num, solution_num, cluster_num, cluster_id,
    spike_waveforms, spike_times, spike_map, rec_key, fs, offsets,
    manipulations.

    spike_times are in samples relative to their own recording. spike_map
    gives the rec_key index of each spike's recording. fs and offsets give
    the sampling rate and sample offset for each rec_key index.

    Raises
    ------
    ValueError
        if a recording in spike_map is missing from rec_key, fs or offsets,
        or if waves, times and spike_map have different lengths
    '''
    def __init__(self, name, electrode, solution, cluster, cluster_id, waves, times,
                 spike_map, rec_key, fs=None, offsets=None, manipulations=''):
        # Build fresh defaults per instance: dict defaults in the signature
        # would be one object shared (and mutated) by every cluster.
        if fs is None:
            fs = {0: 30000}

        if offsets is None:
            offsets = {0: 0}

        # Confirm spike_map, rec_key, fs and offsets are all in sync
        rec_nums = np.unique(spike_map)
        if (not all([x in rec_key.keys() for x in rec_nums]) or
            not all([x in fs.keys() for x in rec_nums]) or
            not all([x in offsets.keys() for x in rec_nums])):
            raise ValueError('rec_key, fs and offsets must have entries for '
                             'each unique element of spike_map')

        # Confirm same number of waves, times and map entries
        if waves.shape[0] != len(times) or len(times) != len(spike_map):
            raise ValueError('Must have same number of waves, times and map entries')

        super(SpikeCluster, self).__init__(Cluster_Name=name,
                                           electrode_num=electrode,
                                           solution_num=solution,
                                           cluster_num = cluster,
                                           cluster_id=cluster_id,
                                           spike_waveforms=waves,
                                           spike_times=times,
                                           spike_map=spike_map,
                                           rec_key=rec_key,
                                           fs=fs,
                                           offsets=offsets,
                                           manipulations=manipulations)

    def delete_spikes(self, idx, msg=None):
        '''Remove the spikes at positions ``idx`` from waveforms, times and map.

        If ``msg`` is given, it is appended to the manipulations history
        together with the number of removed spikes.
        '''
        self['spike_waveforms'] = np.delete(self['spike_waveforms'],
                                               idx, axis=0)
        self['spike_times'] = np.delete(self['spike_times'], idx)
        self['spike_map'] = np.delete(self['spike_map'], idx)
        print('deleted %i spikes.' % len(idx))
        if msg is not None:
            self['manipulations'] += '\n' + msg + '\n-Removed %i spikes' % len(idx)

    def get_spike_time_vector(self, units='samples'):
        '''Return vector of all spike times with offsets added if multiple
        recordings are present

        Parameters
        ----------
        units : {'samples' (default), 'ms', 's'}, units for spike times
            returned (case-insensitive)

        Returns
        -------
        numpy.ndarray

        Raises
        ------
        ValueError
            for any other units
        '''
        unit = units.lower()
        if unit == 'ms':
            times = np.array([(a + self['offsets'][b]) / (self['fs'][b] / 1000)
                              for a, b in
                              zip(self['spike_times'].astype('float64'), self['spike_map'])])
        elif unit == 's':
            times = np.array([(a + self['offsets'][b]) / self['fs'][b]
                              for a, b in
                              zip(self['spike_times'].astype('float64'), self['spike_map'])])
        elif unit == 'samples':
            times = np.array([(a + self['offsets'][b]) for a, b in
                              zip(self['spike_times'], self['spike_map'])])
        else:
            raise ValueError("units must be one of 'samples', 'ms' or 's'")

        return times

    def __eq__(self, other):
        '''Two clusters are equal if they hold the same offset-adjusted
        spike times (in samples), regardless of order.
        '''
        times1 = np.sort(self.get_spike_time_vector(units='samples'))
        times2 = np.sort(other.get_spike_time_vector(units='samples'))
        return np.array_equal(times1, times2)

    def get_mean_waveform(self):
        '''Returns mean waveform of cluster. Also
        returns St. Dev. of waveforms and number of waveforms

        Returns
        -------
        np.ndarray, np.ndarray, int
        mean waveform, st. dev of waveforms, number of waveforms 
        '''
        mean_wave = np.mean(self['spike_waveforms'], axis=0)
        std_wave = np.std(self['spike_waveforms'], axis=0)
        n_waves = self['spike_waveforms'].shape[0]
        return mean_wave, std_wave, n_waves

    # TODO: Finish this section
    def _dist(self, other=None):
        if other is None:
            other = self

        pass

    def _add(self, other):
        pass

    def _subtract(self, other):
        pass

    def _divide(self, N, method='pca'):
        pass

    def _is_subcluster(self, other):
        '''less than?'''
        pass

    def _is_supercluster(self, other):
        ''' greater than?'''
        pass

Functions

def UMAP_METRICS(waves, n_pc)
Expand source code
def UMAP_METRICS(waves, n_pc):
    '''Waveform metrics whose projected components come from UMAP.

    Thin wrapper around ``compute_waveform_metrics(waves, n_pc, umap=True)``;
    returns exactly what that call returns.
    '''
    metrics = compute_waveform_metrics(waves, n_pc=n_pc, umap=True)
    return metrics
def compute_waveform_metrics(waves, n_pc=3, umap=False)

Make clustering data array with columns: amplitude, energy, spike_slope, PC0, PC1, PC2, etc.

Parameters


waves : np.array
waveforms with a row for each spike waveform
n_pc : int (optional)
number of principal components to include in data array

Returns

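data : np.array
clustering data array with a row per waveform
data_columns : list of str
names of the columns of data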
 
Expand source code
def compute_waveform_metrics(waves, n_pc=3, umap=False):
    '''Make clustering data array with columns:
         - amplitude, energy, spike_slope, PC0, PC1, PC2, etc

    Parameters
    ----------
    waves : np.array
        waveforms with a row for each spike waveform
    n_pc : int (optional)
        number of projected components (PCA or UMAP) to include in data array
    umap : bool (optional)
        if True, project the raw waveforms with ``implement_umap``; otherwise
        run PCA on waveforms scaled to their own energy. Note that inside
        this function the parameter shadows the ``umap`` module name.

    Returns
    -------
    data : np.array
        matrix with a row per waveform and one column per name in
        ``data_columns``
    data_columns : list of str
        column names for ``data``
    '''
    # Reuse the standalone feature helpers so these columns always agree
    # with get_waveform_amplitudes / get_waveform_energy / get_spike_slopes.
    data = np.column_stack((get_waveform_amplitudes(waves),
                            get_waveform_energy(waves),
                            get_spike_slopes(waves)))

    if umap:
        pc_waves = implement_umap(waves, n_pc=n_pc)
    else:
        # Scale waveforms to energy before running PCA
        scaled_waves = scale_waveforms(waves, energy=data[:, 1])
        pc_waves, _ = implement_pca(scaled_waves)

    pc_waves = pc_waves[:, :n_pc]
    data = np.hstack((data, pc_waves))
    data_columns = ['amplitude', 'energy', 'spike_slope']
    # Name only the components actually produced: PCA can return fewer than
    # n_pc components when there are fewer waveforms or samples than n_pc.
    data_columns.extend(['PC%i' % i for i in range(pc_waves.shape[1])])
    return data, data_columns
def detect_spikes(filt_el, spike_snapshot=[0.5, 1.0], fs=30000.0)

Detects spikes in the filtered electrode trace and returns the waveforms, spike times and detection threshold

Parameters

filt_el : np.array, 1-D
filtered electrode trace
spike_snapshot : list
2-elements, [ms before spike minimum, ms after spike minimum] time around spike to snap as waveform
fs : float, sampling rate in Hz
 

Returns

waves : np.array
matrix of de-jittered, spike waveforms, upsampled by 10x, row for each spike
times : np.array
array of spike times in samples
threshold : float
spike detection threshold
Expand source code
def detect_spikes(filt_el, spike_snapshot=[0.5, 1.0], fs=30000.0):
    '''Detects spikes in the filtered electrode trace and returns the
    waveforms, spike times and detection threshold

    Parameters
    ----------
    filt_el : np.array, 1-D
        filtered electrode trace
    spike_snapshot : list
        2-elements, [ms before spike minimum, ms after spike minimum]
        time around spike to snap as waveform (only read, never mutated)
    fs : float, sampling rate in Hz

    Returns
    -------
    waves : np.array or None
        matrix of de-jittered spike waveforms from ``clustering.dejitter``
        (documented as upsampled by 10x), row for each spike; None if no
        spikes were found
    times : np.array or None
        array of spike times in samples; None if no spikes were found
    threshold : float
        spike detection threshold (mean - 5 * noise estimate); always
        returned so callers can unpack three values
    '''
    # get indices of spike snapshot, expand by .1 ms in each direction
    snapshot = np.arange(-(spike_snapshot[0]+0.1)*fs/1000,
                         1+(spike_snapshot[1]+0.1)*fs/1000).astype('int64')
    m = np.mean(filt_el)
    # median(|x|)/0.6745 is the usual robust noise-sigma estimate (assumes
    # roughly Gaussian background noise)
    th = 5.0*np.median(np.abs(filt_el)/0.6745)
    threshold = m - th
    pos = np.where(filt_el <= threshold)[0]
    consecutive = mt.group_consecutives(pos)

    waves = []
    times = []
    for idx in consecutive:
        # Snap each threshold crossing to its most negative sample
        minimum = idx[np.argmin(filt_el[idx])]
        spike_idx = minimum + snapshot
        # Skip spikes whose snapshot would run off either end of the trace
        if spike_idx[0] >= 0 and spike_idx[-1] < len(filt_el):
            waves.append(filt_el[spike_idx])
            times.append(minimum)

    if len(waves) == 0:
        # Keep the 3-tuple shape of the normal return
        return None, None, threshold

    waves_dj, times_dj = clustering.dejitter(np.array(waves), np.array(times), spike_snapshot, fs)
    return waves_dj, times_dj, threshold
def get_ISI_and_violations(spike_times, fs, rec_map=None)

returns array of ISIs in ms and # of 1ms and 2ms violations

Parameters

spike_times : numpy.array
array of spike times in samples
fs : float, sampling rate in Hz
 
rec_map : np.array (optional)
If not passed, all spike times are assumed to come from the same recording. If passed, spike times are split by recording and ISIs are computed per recording. If fs differs between recordings, fs should be a dict keyed by the rec ids in rec_map.

Returns

np.array : ISIs
 
int : 1ms violations
 
int : 2ms violations
 
Expand source code
def get_ISI_and_violations(spike_times, fs, rec_map=None):
    '''Compute inter-spike intervals in ms and count refractory violations.

    Parameters
    ----------
    spike_times : numpy.array
        spike times in samples
    fs : float or dict
        sampling rate in Hz; when rec_map is given and recordings differ in
        sampling rate, a dict keyed by the rec ids found in rec_map
    rec_map : np.array (optional)
        recording id of each spike. If omitted, all spikes are treated as one
        recording; otherwise ISIs are computed separately per recording and
        concatenated.

    Returns
    -------
    np.array : ISIs in ms
    int : number of ISIs shorter than 1 ms
    int : number of ISIs shorter than 2 ms
    '''
    if rec_map is None:
        samples_per_ms = float(fs/1000.0)
        isis = np.ediff1d(np.sort(spike_times))/samples_per_ms
        return isis, np.sum(isis < 1.0), np.sum(isis < 2.0)

    rec_ids = np.unique(rec_map)
    rates = fs if isinstance(fs, dict) else dict.fromkeys(rec_ids, fs)
    chunks = [np.array([])]
    n_under_1ms = 0
    n_under_2ms = 0
    for rec in rec_ids:
        members = np.where(rec_map == rec)[0]
        sub_isis, under_1, under_2 = get_ISI_and_violations(spike_times[members], rates[rec])
        chunks.append(sub_isis)
        n_under_1ms += under_1
        n_under_2ms += under_2

    return np.concatenate(chunks), n_under_1ms, n_under_2ms
def get_mahalanobis_distances_to_cluster(data, model, clusters, target_cluster)

computes mahalanobis distance from spikes in target_cluster to all clusters in GMM model

Parameters

data : np.array, data used to train GMM
 
model : fitted GMM model
 
clusters : np.array, maps data points to clusters
 
target_cluster : int, cluster for which to compute distances
 

Returns

dict : maps each cluster id to an np.array of distances of the spikes in target_cluster from that cluster
 
Expand source code
def get_mahalanobis_distances_to_cluster(data, model, clusters, target_cluster):
    '''computes mahalanobis distance from spikes in target_cluster to all clusters
    in GMM model

    Parameters
    ----------
    data : np.array, data used to train GMM (row per spike)
    model : fitted GMM model; its ``means_`` and per-component full
        ``covariances_`` (indexed as [cluster, :, :]) are used
    clusters : array-like, maps data points to clusters
    target_cluster : int, cluster for which to compute distances

    Returns
    -------
    dict
        maps each cluster id in ``abs(clusters)`` to an np.array holding the
        distance of every spike in ``target_cluster`` from that cluster
    '''
    clusters = np.asarray(clusters)
    # abs() folds the -1 "rejected spike" label into cluster 1 so it is
    # never used as a (negative) component index into the model.
    unique_clusters = np.unique(np.abs(clusters))
    out_distances = dict.fromkeys(unique_clusters)
    cluster_idx = np.where(clusters == target_cluster)[0]
    target_points = data[cluster_idx, :]
    for other_cluster in unique_clusters:
        other_cluster_mean = model.means_[other_cluster, :]
        other_cluster_covar_I = linalg.inv(model.covariances_[other_cluster, :, :])
        delta = target_points - other_cluster_mean
        # Row-wise sqrt(d^T VI d): same quantity as
        # scipy.spatial.distance.mahalanobis, computed for all spikes at once.
        out_distances[other_cluster] = np.sqrt(
            np.einsum('ij,jk,ik->i', delta, other_cluster_covar_I, delta))

    return out_distances
def get_recording_cutoff(filt_el, sampling_rate, voltage_cutoff, max_breach_rate, max_secs_above_cutoff, max_mean_breach_rate_persec, **kwargs)
Expand source code
def get_recording_cutoff(filt_el, sampling_rate, voltage_cutoff,
                         max_breach_rate, max_secs_above_cutoff,
                         max_mean_breach_rate_persec, **kwargs):
    '''Find the second after which a recording should be discarded.

    Samples above ``voltage_cutoff`` count as breaches. If the overall breach
    rate, the number of 1-second bins containing a breach and the mean number
    of breaches per breaching second ALL reach their limits, the recording is
    cut at the first 1-second bin whose breach count exceeds
    ``max_mean_breach_rate_persec``.

    Parameters
    ----------
    filt_el : np.array, 1-D
        filtered electrode trace
    sampling_rate : float
        sampling rate in Hz
    voltage_cutoff : float
        samples above this value are breaches
    max_breach_rate : float
        overall breaches per second at or above which the recording may be cut
    max_secs_above_cutoff : int
        number of breaching seconds at or above which the recording may be cut
    max_mean_breach_rate_persec : float
        mean breaches per breaching second at or above which the recording may
        be cut; also the per-second threshold used to locate the cutoff
    **kwargs
        ignored (presumably so a whole parameter dict can be splatted in)

    Returns
    -------
    int
        cutoff in seconds; the whole-second length of the recording if it is
        not cut (0 for an empty trace)
    '''
    if len(filt_el) == 0:
        return 0

    fs = int(sampling_rate)
    breach_rate = float(np.count_nonzero(filt_el > voltage_cutoff)*fs)/len(filt_el)

    # truncate to nearest second and make 1 sec bins
    truncated = filt_el[:fs*int(len(filt_el)/sampling_rate)]
    test_el = np.reshape(truncated, (-1, fs))
    breaches_per_sec = np.sum(test_el > voltage_cutoff, axis=1)
    breaching_secs = breaches_per_sec[breaches_per_sec > 0]
    secs_above_cutoff = len(breaching_secs)
    if secs_above_cutoff == 0:
        mean_breach_rate_persec = 0
    else:
        mean_breach_rate_persec = np.mean(breaching_secs)

    recording_cutoff = int(len(truncated)/sampling_rate)  # cutoff in seconds
    # If all limits are exceeded, assume the headstage fell off mid-experiment
    if (breach_rate >= max_breach_rate and
            secs_above_cutoff >= max_secs_above_cutoff and
            mean_breach_rate_persec >= max_mean_breach_rate_persec):
        over = np.flatnonzero(breaches_per_sec > max_mean_breach_rate_persec)
        if over.size == 0:
            # The mean can equal the limit with no bin strictly above it
            # (every breaching bin exactly at the limit); use those bins.
            over = np.flatnonzero(breaches_per_sec >= max_mean_breach_rate_persec)
        if over.size > 0:
            # cutoff is still in seconds since 1 sec bins
            recording_cutoff = int(over[0])

    return recording_cutoff
def get_spike_slopes(waves)

Returns array of spike slopes (initial downward slope of spike)

Parameters

waves : np.array, matrix of waveforms, with row for each spike
 

Returns

np.array
 
Expand source code
def get_spike_slopes(waves):
    '''Return the initial downward slope of each spike waveform.

    The slope runs from the last local peak before the waveform minimum (or,
    when there is no such peak, the largest sample before the minimum) down
    to the minimum, in amplitude units per sample.

    Parameters
    ----------
    waves : np.array, matrix of waveforms, with row for each spike

    Returns
    -------
    np.array
    '''
    def _slope(wave):
        trough = np.argmin(wave)
        peaks = find_peaks(wave)[0]
        earlier_peaks = peaks[peaks < trough]
        if earlier_peaks.size:
            start = earlier_peaks.max()
        else:
            start = np.argmax(wave[:trough])
        return (wave[trough] - wave[start])/(trough - start)

    slopes = np.zeros((waves.shape[0],))
    for row, wave in enumerate(waves):
        slopes[row] = _slope(wave)

    return slopes
def get_waveform_amplitudes(waves)

Returns array of waveform amplitudes

Parameters

waves : np.array, matrix of waveforms, with row for each spike
 

Returns

np.array
 
Expand source code
def get_waveform_amplitudes(waves):
    '''Return each waveform's amplitude, taken as its minimum sample value.

    Parameters
    ----------
    waves : np.array, matrix of waveforms, with row for each spike

    Returns
    -------
    np.array
    '''
    amplitudes = np.amin(waves, axis=1)
    return amplitudes
def get_waveform_energy(waves)

Returns array of waveform energies

Parameters

waves : np.array, matrix of waveforms, with row for each spike
 

Returns

np.array
 
Expand source code
def get_waveform_energy(waves):
    '''Return each waveform's energy: its Euclidean norm divided by its length.

    Parameters
    ----------
    waves : np.array, matrix of waveforms, with row for each spike

    Returns
    -------
    np.array
    '''
    n_samples = waves.shape[1]
    norms = np.sqrt(np.sum(waves**2, axis=1))
    return norms/n_samples
def implement_pca(scaled_slices)
Expand source code
def implement_pca(scaled_slices):
    '''Project rows onto all principal components with a fresh sklearn PCA.

    Returns
    -------
    np.array : transformed data, a row per input row
    np.array : explained variance ratio of each component
    '''
    decomposition = PCA()
    projected = decomposition.fit_transform(scaled_slices)
    return projected, decomposition.explained_variance_ratio_
def implement_umap(waves, n_pc=3, n_neighbors=30, min_dist=0.0)
Expand source code
def implement_umap(waves, n_pc=3, n_neighbors=30, min_dist=0.0):
    '''Embed waveforms in ``n_pc`` dimensions with UMAP and return the embedding.'''
    settings = dict(n_components=n_pc,
                    n_neighbors=n_neighbors,
                    min_dist=min_dist)
    embedding = umap.UMAP(**settings).fit_transform(waves)
    return embedding
def implement_wavelet_transform(waves, n_pc=10)
Expand source code
def implement_wavelet_transform(waves, n_pc=10):
    '''Return the ``n_pc`` Haar wavelet coefficients with the smallest
    Lilliefors normality-test p-values.

    The multilevel Haar decomposition of each waveform is stacked into one
    coefficient matrix (row per waveform). Each coefficient column is tested
    with ``lilliefors(..., dist='norm')``, and columns are returned in order
    of ascending p-value.
    '''
    coeff_matrix = np.column_stack(pywt.wavedec(waves, 'haar', axis=1))
    # Small p-value = strong evidence that the coefficient is not Gaussian
    p_values = np.array([lilliefors(column, dist='norm')[1]
                         for column in coeff_matrix.T], dtype=float)
    ranking = np.argsort(p_values)
    return coeff_matrix[:, ranking[:n_pc]]
def scale_waveforms(waves, energy=None)

Scales each waveform to its own energy

Parameters

waves : np.array, matrix of waveforms, with row for each spike
 
energy : np.array (optional)
array of waveform energies, saves computation time

Returns

np.array
 
Expand source code
def scale_waveforms(waves, energy=None):
    '''Scales each waveform to its own energy

    Parameters
    ----------
    waves : np.array, matrix of waveforms, with row for each spike
    energy : array-like (optional)
        one energy per waveform; computed with get_waveform_energy if omitted

    Returns
    -------
    np.array
        float64 matrix with each row of ``waves`` divided by its energy

    Raises
    ------
    ValueError
        if ``energy`` does not have one entry per waveform
    '''
    if energy is None:
        energy = get_waveform_energy(waves)
    elif len(energy) != waves.shape[0]:
        raise ValueError(('Energies must correspond to each waveforms. '
                          'Different lengths are not allowed'))

    # Broadcast one energy per row instead of looping over waveforms
    row_energy = np.asarray(energy).reshape(-1, 1)
    return np.asarray(waves/row_energy, dtype=np.float64)

Classes

class BlechClust (rec_dirs, electrode, out_dir=None, params=None, overwrite=False, no_write=False, n_pc=3, data_transform=<function compute_waveform_metrics>)

Recording directories should be ordered to make spike sorting easier later on

Expand source code
class BlechClust(object):
    '''Cluster detected spikes from one electrode, pooled across recordings.

    Spike waveforms from every recording directory are transformed into a
    feature matrix and fit with ``ClusterGMM`` for every cluster count from 2
    to ``params['max_clusters']``. Per-solution predictions, BIC values,
    plots and a summary table are written under ``out_dir``.
    '''
    def __init__(self, rec_dirs, electrode, out_dir=None, params=None,
                 overwrite=False, no_write=False, n_pc=3,
                 data_transform=compute_waveform_metrics):
        '''Recording directories should be ordered to make spike sorting easier later on

        Parameters
        ----------
        rec_dirs : str or list of str
            recording directories; spike detection must already be complete
            for ``electrode`` in each
        electrode : int
        out_dir : str (optional)
            defaults to ``<rec_dir>/BlechClust/electrode_<n>`` for a single
            recording, or ``<parent of first rec_dir>/BlechClust/electrode_<n>``
            for several
        params : dict (optional)
            dict with 'clustering_params' and 'data_params' sub-dicts; required
            when no params file exists yet in out_dir. Not modified.
        overwrite : bool
            delete any existing out_dir before starting
        no_write : bool
            do not create a new rec key file
        n_pc : int
            default number of components passed to ``data_transform``
        data_transform : callable
            ``f(waveforms, n_pc) -> (data, data_columns)``

        Raises
        ------
        ValueError
            if params are missing, the rec key does not match rec_dirs, no rec
            key exists while no_write is set, or spike detection is incomplete
        '''
        if isinstance(rec_dirs, str):
            rec_dirs = [rec_dirs]

        rec_dirs = [x[:-1] if x.endswith(os.sep) else x for x in rec_dirs]
        self.rec_dirs = rec_dirs
        self.electrode = electrode
        self._data_transform = data_transform
        self._n_pc = n_pc
        if out_dir is None:
            if len(rec_dirs) > 1:
                top = os.path.dirname(rec_dirs[0])
                out_dir = os.path.join(top, 'BlechClust', 'electrode_%i' % electrode)
            else:
                out_dir = os.path.join(rec_dirs[0], 'BlechClust', 'electrode_%i' % electrode)

        # rmtree raises on a missing directory, so only remove what exists
        if overwrite and os.path.isdir(out_dir):
            shutil.rmtree(out_dir)

        # Make directories
        self.out_dir = out_dir
        self._plot_dir = os.path.join(out_dir, 'plots')
        self._data_dir = os.path.join(out_dir, 'clustering_results')
        for directory in (out_dir, self._data_dir, self._plot_dir):
            os.makedirs(directory, exist_ok=True)

        # Check files
        params_file = os.path.join(out_dir, 'BlechClust_params.json')
        map_file = os.path.join(self._data_dir, 'spike_id.npy')
        key_file = os.path.join(self._data_dir, 'rec_key.json')
        results_file = os.path.join(self._data_dir, 'clustering_results.json')
        self._files = {'params': params_file, 'spike_map': map_file,
                       'rec_key': key_file, 'clustering_results': results_file}
        self.params = params
        self._load_existsing_data()

        if self._rec_key is None and not no_write:
            # Create new rec key
            rec_key = {x: y for x, y in enumerate(self.rec_dirs)}
            self._rec_key = rec_key
            wt.write_dict_to_json(rec_key, self._files['rec_key'])
        elif self._rec_key is None:
            # Without a rec key spikes cannot be mapped back to recordings
            raise ValueError('Existing rec_key not found and no_write is enabled')

        # Check to see if spike detection is already completed on all recording directories
        spike_check = self._check_spike_detection()
        if not all(spike_check):
            # spike_check follows self.rec_dirs, which may have been
            # re-ordered to match the rec key, so pair with that list.
            invalid = [rd for rd, ok in zip(self.rec_dirs, spike_check) if not ok]
            error_str = '\n\t'.join(invalid)
            raise ValueError('Spike detection has not been run on:\n\t%s' % error_str)

    def _load_existsing_data(self):
        '''Load params, rec key and clustering results already in out_dir.

        Sets ``self.params``, ``self._rec_key``, ``self.results`` and
        ``self.clustered``. Passed params are merged into (and saved to) the
        params file.
        '''
        params = self.params
        file_check = self._check_existing_files()

        # Check params files and create if new params are passed
        if file_check['params']:
            self.params = wt.read_dict_from_json(self._files['params'])
        elif params is not None:
            # Copy so the flat keys added below do not mutate the caller's dict
            self.params = dict(params)

        # Make new params or overwrite existing with passed params
        if params is None and not file_check['params']:
            raise ValueError(('Params file does not exists at %s. Must provide'
                              ' clustering parameters.') % self._files['params'])
        elif params is not None:
            self.params['max_clusters'] = params['clustering_params']['Max Number of Clusters']
            self.params['max_iterations'] = params['clustering_params']['Max Number of Iterations']
            self.params['threshold'] = params['clustering_params']['Convergence Criterion']
            self.params['num_restarts'] = params['clustering_params']['GMM random restarts']
            self.params['wf_amplitude_sd_cutoff'] = params['data_params']['Intra-cluster waveform amp SD cutoff']
            wt.write_dict_to_json(self.params, self._files['params'])

        # Deal with existing rec key
        if file_check['rec_key']:
            rec_dirs = self.rec_dirs
            rec_key = wt.read_dict_from_json(self._files['rec_key'])
            rec_key = {int(x): y for x, y in rec_key.items()}
            if len(rec_key) != len(rec_dirs):
                raise ValueError('Rec key does not match rec dirs')

            # Correct rec key in case rec_dir roots have changed
            for rd in rec_dirs:
                rn = os.path.basename(rd)
                # Prefer an exact directory-name match: a plain substring test
                # would pair e.g. 'rec1' with 'rec10'. Fall back to the
                # substring test for keys stored in another form.
                kd = [(x, y) for x, y in rec_key.items()
                      if os.path.basename(y.rstrip(os.sep)) == rn]
                if len(kd) == 0:
                    kd = [(x, y) for x, y in rec_key.items() if rn in y]
                if len(kd) == 0:
                    raise ValueError('%s not found in rec_key' % rn)

                key, stored_dir = kd[0]
                if stored_dir != rd:
                    rec_key[key] = rd

            inverted = {v: k for k, v in rec_key.items()}
            self.rec_dirs = sorted(self.rec_dirs, key=lambda i: inverted[i])
            self._rec_key = rec_key
        else:
            self._rec_key = None

        # Check is clustering has already been done, load results
        if file_check['clustering_results']:
            self.results = wt.read_pandas_from_table(self._files['clustering_results'])
            expected_results = np.arange(2, self.params['max_clusters'] + 1)
            # ``x in Series`` tests the index, so compare with the column values
            found = set(np.asarray(self.results['clusters']).tolist())
            self.clustered = all(x in found for x in expected_results.tolist())
        else:
            self.results = None
            self.clustered = False

    def _check_existing_files(self):
        '''Map each key of ``self._files`` to whether that file exists.'''
        return {k: os.path.isfile(v) for k, v in self._files.items()}

    def _check_spike_detection(self):
        '''Check to see if spike detection has been run on all recording directories

        Returns a list of bools in the order of ``self.rec_dirs``.
        '''
        out = []
        for rec in self.rec_dirs:
            try:
                spike_detect = SpikeDetection(rec, self.electrode)
                out.append(all(spike_detect._status))
            except FileNotFoundError:
                out.append(False)

        return out

    def run(self, n_pc=None, overwrite=False):
        '''Run GMM clustering for every cluster count from 2 to max_clusters.

        Solutions already saved on disk are reused unless ``overwrite`` is
        True. For each newly fit solution, spikes whose amplitude is at or
        below ``mean - wf_amplitude_sd_cutoff * sd`` of their cluster are
        relabelled -1, and waveform/ISI, feature-pair and Mahalanobis plots
        are saved.

        Parameters
        ----------
        n_pc : int (optional)
            components passed to the data transform; defaults to the value
            given at construction
        overwrite : bool
            re-cluster even if results already exist

        Returns
        -------
        bool
            always True
        '''
        if self.clustered and not overwrite:
            return True

        if n_pc is None:
            n_pc = self._n_pc

        gmm = ClusterGMM(self.params['max_iterations'],
                         self.params['num_restarts'], self.params['threshold'])

        # Collect data from all recordings
        waveforms, spike_times, spike_map, fs, _offsets = self.get_spike_data()

        # Save array to map spikes and predictions back to original recordings
        np.save(self._files['spike_map'], spike_map)

        data, data_columns = self._data_transform(waveforms, n_pc)
        amplitudes = get_waveform_amplitudes(waveforms)

        # Run GMM for each number of clusters from 2 to max_clusters
        tested_clusters = np.arange(2, self.params['max_clusters']+1)
        clust_results = pd.DataFrame(columns=['clusters', 'converged',
                                              'BIC', 'spikes_per_cluster'],
                                     index=tested_clusters)
        for n_clust in tested_clusters:
            data_dir = os.path.join(self._data_dir, '%i_clusters' % n_clust)
            plot_dir = os.path.join(self._plot_dir, '%i_clusters' % n_clust)
            wave_plot_dir = os.path.join(self._plot_dir, '%i_clusters_waveforms_ISIs' % n_clust)
            bic_file = os.path.join(data_dir, 'bic.npy')
            pred_file = os.path.join(data_dir, 'predictions.npy')

            if os.path.isfile(bic_file) and os.path.isfile(pred_file) and not overwrite:
                bic = np.load(bic_file)
                predictions = np.load(pred_file)
                # Count clusters 0..n_clust-1 exactly as a fresh fit does:
                # rejected spikes (-1) are excluded and empty clusters show 0.
                spikes_per_clust = [int(np.sum(predictions == c))
                                    for c in range(n_clust)]
                clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]
                continue

            for directory in (wave_plot_dir, data_dir, plot_dir):
                os.makedirs(directory, exist_ok=True)

            model, predictions, bic = gmm.fit(data, n_clust)
            if model is None:
                # Nothing converged; values in the same order as the columns
                clust_results.loc[n_clust] = [n_clust, False, bic, [0]]
                continue

            # Go through each cluster and throw out any spikes too far from the
            # mean
            spikes_per_clust = []
            for c in range(n_clust):
                idx = np.where(predictions == c)[0]
                if len(idx) == 0:
                    spikes_per_clust.append(0)
                    continue

                cluster_amps = amplitudes[idx]
                cutoff_amp = np.mean(cluster_amps) - (np.std(cluster_amps) *
                                                      self.params['wf_amplitude_sd_cutoff'])
                predictions[idx[cluster_amps <= cutoff_amp]] = -1

                idx = np.where(predictions == c)[0]
                spikes_per_clust.append(len(idx))
                if len(idx) == 0:
                    continue

                # Plot waveforms and ISIs of cluster
                ISIs, violations_1ms, violations_2ms = get_ISI_and_violations(spike_times[idx], fs, spike_map[idx])
                isi_fn = os.path.join(wave_plot_dir, 'Cluster%i_ISI.png' % c)
                wave_fn = os.path.join(wave_plot_dir, 'Cluster%i_waveforms.png' % c)
                title_str = ('Cluster%i\nviolations_1ms = %i, '
                             'violations_2ms = %i\n'
                             'Number of waveforms = %i' %
                             (c, violations_1ms, violations_2ms, len(idx)))
                dplt.plot_waveforms(waveforms[idx], title=title_str, save_file=wave_fn)
                if len(ISIs) > 0:
                    dplt.plot_ISIs(ISIs, total_spikes=len(idx), save_file=isi_fn)

            clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]

            # Plot feature pairs
            for f1, f2 in it.combinations(range(data.shape[1]), 2):
                fn = '%sVS%s.png' % (data_columns[f1], data_columns[f2])
                fn = os.path.join(plot_dir, fn)
                dplt.plot_cluster_features(data[:, [f1, f2]], predictions,
                                           x_label=data_columns[f1],
                                           y_label=data_columns[f2],
                                           save_file=fn)

            # For each cluster plot mahanalobis distances to all other clusters
            for c in range(n_clust):
                distances = get_mahalanobis_distances_to_cluster(data, model,
                                                                 predictions, c)
                fn = os.path.join(plot_dir, 'Mahalanobis_cluster%i.png' % c)
                title = ('Mahalanobis distance of Cluster %i from all other clusters' % c)
                dplt.plot_mahalanobis_to_cluster(distances, title=title, save_file=fn)

            # Save data
            np.save(bic_file, bic)
            np.save(pred_file, predictions)

        # Save results table
        self.results = clust_results
        wt.write_pandas_to_table(clust_results,
                                 self._files['clustering_results'],
                                 overwrite=True)
        self.clustered = True
        return True

    def get_spike_data(self):
        '''Collect spike waveforms and times from every recording in the rec key.

        Recordings are visited in ascending rec-key order. Each recording is
        given a sample offset: the next offset is the previous one plus the
        last spike time plus 3 s worth of samples.

        Returns
        -------
        waveforms : np.array, row per spike
        spike_times : np.array, spike times in samples within each recording
        spike_map : np.array, rec key id of each spike
        fs : dict, rec key id -> sampling rate
        offsets : dict, rec key id -> sample offset

        Raises
        ------
        ValueError
            if no recording has any spikes, or the spike count differs from
            the saved spike map
        '''
        tmp_waves = []
        tmp_times = []
        tmp_id = []
        fs = dict.fromkeys(self._rec_key.keys())
        offsets = dict.fromkeys(self._rec_key.keys())
        offset = 0
        for i in sorted(self._rec_key.keys()):
            rec = self._rec_key[i]
            spike_detect = SpikeDetection(rec, self.electrode)
            t = spike_detect.get_spike_times()
            fs[i] = spike_detect.params['sampling_rate']
            offsets[i] = int(offset)
            if t is None:
                offset = offset + 3*fs[i]
                continue

            tmp_waves.append(spike_detect.get_spike_waveforms())
            tmp_times.append(t)
            tmp_id.append(np.ones((t.shape[0],))*i)
            offset = offset + t.max() + 3*fs[i]

        if len(tmp_waves) == 0:
            raise ValueError('No spikes detected on electrode %i in any recording'
                             % self.electrode)

        waveforms = np.vstack(tmp_waves)
        spike_times = np.hstack(tmp_times)
        spike_map = np.hstack(tmp_id)

        # Double check that spike_map matches up with existing spike_map
        if os.path.isfile(self._files['spike_map']):
            orig_map = np.load(self._files['spike_map'])
            if len(orig_map) != len(spike_map):
                raise ValueError('Spike detection has changed, please re-cluster with overwrite=True')

        return waveforms, spike_times, spike_map, fs, offsets

    def get_clusters(self, solution_num, cluster_nums):
        '''Build SpikeCluster objects for clusters of a saved solution.

        Parameters
        ----------
        solution_num : int
            number of clusters in the GMM solution to read
        cluster_nums : int or list/tuple/range/np.ndarray of int
            cluster labels to extract

        Returns
        -------
        list of SpikeCluster
            one per requested cluster that has spikes; empty if the solution
            has no saved predictions
        '''
        if isinstance(cluster_nums, (list, tuple, range, np.ndarray)):
            cluster_nums = list(cluster_nums)
        else:
            cluster_nums = [cluster_nums]

        waveforms, times, spike_map, fs, offsets = self.get_spike_data()
        predictions = self.get_predictions(solution_num)
        out = []
        if predictions is None:
            return out

        for c in cluster_nums:
            idx = np.where(predictions == c)[0]
            if len(idx) == 0:
                continue

            tmp_clust = SpikeCluster('Cluster_%i' % c,
                                     self.electrode,
                                     solution_num,
                                     c,
                                     1,
                                     waveforms[idx],
                                     times[idx],
                                     spike_map[idx],
                                     self._rec_key.copy(),
                                     fs.copy(),
                                     offsets.copy(),
                                     manipulations='')
            out.append(tmp_clust)

        return out

    def get_predictions(self, n_clusters):
        '''Load saved cluster labels for the ``n_clusters`` solution, or None if absent.'''
        fn = os.path.join(self._data_dir, '%i_clusters' % n_clusters,
                          'predictions.npy')
        if not os.path.isfile(fn):
            return None
        return np.load(fn)

Methods

def get_clusters(self, solution_num, cluster_nums)
Expand source code
def get_clusters(self, solution_num, cluster_nums):
    '''Build SpikeCluster objects for clusters of a saved solution.

    Parameters
    ----------
    solution_num : int
        number of clusters in the GMM solution to read
    cluster_nums : int or list/tuple/range/np.ndarray of int
        cluster labels to extract

    Returns
    -------
    list of SpikeCluster
        one per requested cluster that has spikes; empty if the solution
        has no saved predictions
    '''
    if isinstance(cluster_nums, (list, tuple, range, np.ndarray)):
        cluster_nums = list(cluster_nums)
    else:
        cluster_nums = [cluster_nums]

    waveforms, times, spike_map, fs, offsets = self.get_spike_data()
    predictions = self.get_predictions(solution_num)
    out = []
    if predictions is None:
        return out

    for c in cluster_nums:
        idx = np.where(predictions == c)[0]
        if len(idx) == 0:
            continue

        tmp_clust = SpikeCluster('Cluster_%i' % c,
                                 self.electrode,
                                 solution_num,
                                 c,
                                 1,
                                 waveforms[idx],
                                 times[idx],
                                 spike_map[idx],
                                 self._rec_key.copy(),
                                 fs.copy(),
                                 offsets.copy(),
                                 manipulations='')
        out.append(tmp_clust)

    return out
def get_predictions(self, n_clusters)
Expand source code
def get_predictions(self, n_clusters):
    '''Load saved cluster labels for the ``n_clusters`` solution.

    Returns the array stored in ``<_data_dir>/<n>_clusters/predictions.npy``,
    or None if that file does not exist.
    '''
    pred_file = os.path.join(self._data_dir, '%i_clusters' % n_clusters,
                             'predictions.npy')
    if not os.path.isfile(pred_file):
        return None
    return np.load(pred_file)
def get_spike_data(self)
Expand source code
def get_spike_data(self):
    '''Collect spike waveforms and times from every recording in the rec key.

    Recordings are visited in ascending rec-key order. Each recording is
    given a sample offset: the next offset is the previous one plus the last
    spike time plus 3 s worth of samples.

    Returns
    -------
    waveforms : np.array, row per spike
    spike_times : np.array, spike times in samples within each recording
    spike_map : np.array, rec key id of each spike
    fs : dict, rec key id -> sampling rate
    offsets : dict, rec key id -> sample offset

    Raises
    ------
    ValueError
        if no recording has any spikes, or the spike count differs from the
        saved spike map
    '''
    tmp_waves = []
    tmp_times = []
    tmp_id = []
    fs = dict.fromkeys(self._rec_key.keys())
    offsets = dict.fromkeys(self._rec_key.keys())
    offset = 0
    for i in sorted(self._rec_key.keys()):
        rec = self._rec_key[i]
        spike_detect = SpikeDetection(rec, self.electrode)
        t = spike_detect.get_spike_times()
        fs[i] = spike_detect.params['sampling_rate']
        offsets[i] = int(offset)
        if t is None:
            offset = offset + 3*fs[i]
            continue

        tmp_waves.append(spike_detect.get_spike_waveforms())
        tmp_times.append(t)
        tmp_id.append(np.ones((t.shape[0],))*i)
        offset = offset + t.max() + 3*fs[i]

    if len(tmp_waves) == 0:
        raise ValueError('No spikes detected on electrode %i in any recording'
                         % self.electrode)

    waveforms = np.vstack(tmp_waves)
    spike_times = np.hstack(tmp_times)
    spike_map = np.hstack(tmp_id)

    # Double check that spike_map matches up with existing spike_map
    if os.path.isfile(self._files['spike_map']):
        orig_map = np.load(self._files['spike_map'])
        if len(orig_map) != len(spike_map):
            raise ValueError('Spike detection has changed, please re-cluster with overwrite=True')

    return waveforms, spike_times, spike_map, fs, offsets
def run(self, n_pc=None, overwrite=False)
Expand source code
def run(self, n_pc=None, overwrite=False):
    '''Run GMM clustering for every cluster count from 2 to max_clusters.

    For each solution the BIC and predictions are saved under
    ``<data_dir>/<n>_clusters`` and diagnostic plots are written to the plot
    directory. Solutions whose BIC and predictions files already exist are
    reloaded instead of refit unless ``overwrite`` is True. Spikes whose
    amplitude falls ``wf_amplitude_sd_cutoff`` SDs below their cluster's mean
    amplitude are relabelled -1.

    Parameters
    ----------
    n_pc : int, optional
        number of principal components passed to ``_data_transform``;
        defaults to ``self._n_pc``
    overwrite : bool
        re-cluster even if clustering was already done or saved

    Returns
    -------
    bool
        always True
    '''
    if self.clustered and not overwrite:
        return True

    if n_pc is None:
        n_pc = self._n_pc

    GMM = ClusterGMM(self.params['max_iterations'],
                     self.params['num_restarts'], self.params['threshold'])

    # Collect data from all recordings
    waveforms, spike_times, spike_map, fs, offsets = self.get_spike_data()

    # Save array to map spikes and predictions back to original recordings
    np.save(self._files['spike_map'], spike_map)

    data, data_columns = self._data_transform(waveforms, n_pc)
    amplitudes = get_waveform_amplitudes(waveforms)

    # Run GMM for each number of clusters from 2 to max_clusters
    tested_clusters = np.arange(2, self.params['max_clusters']+1)
    # Every row assigned below must follow this column order
    clust_results = pd.DataFrame(columns=['clusters','converged',
                                          'BIC','spikes_per_cluster'],
                                 index=tested_clusters)
    for n_clust in tested_clusters:
        data_dir = os.path.join(self._data_dir, '%i_clusters' % n_clust)
        plot_dir = os.path.join(self._plot_dir, '%i_clusters' % n_clust)
        wave_plot_dir = os.path.join(self._plot_dir, '%i_clusters_waveforms_ISIs' % n_clust)
        bic_file = os.path.join(data_dir, 'bic.npy')
        pred_file = os.path.join(data_dir, 'predictions.npy')

        if os.path.isfile(bic_file) and os.path.isfile(pred_file) and not overwrite:
            bic = np.load(bic_file)
            predictions = np.load(pred_file)
            # Count labels 0..n_clust-1 (rejected spikes are -1 and excluded)
            # so cached rows match rows of freshly fit solutions
            spikes_per_clust = [int(np.sum(predictions == c))
                                for c in range(n_clust)]
            clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]
            continue

        for out_dir in (wave_plot_dir, data_dir, plot_dir):
            os.makedirs(out_dir, exist_ok=True)

        model, predictions, bic = GMM.fit(data, n_clust)
        if model is None:
            # Nothing converged (order: clusters, converged, BIC, spikes)
            clust_results.loc[n_clust] = [n_clust, False, bic, [0]]
            continue

        # Go through each cluster and throw out any spikes whose amplitude is
        # too far below the cluster mean
        sd_cutoff = self.params['wf_amplitude_sd_cutoff']
        spikes_per_clust = []
        for c in range(n_clust):
            idx = np.where(predictions == c)[0]
            if len(idx) == 0:
                spikes_per_clust.append(0)
                continue

            clust_amps = amplitudes[idx]
            cutoff_amp = np.mean(clust_amps) - (np.std(clust_amps) * sd_cutoff)
            predictions[idx[clust_amps <= cutoff_amp]] = -1

            idx = np.where(predictions == c)[0]
            spikes_per_clust.append(len(idx))

            if len(idx) == 0:
                continue

            # Plot waveforms and ISIs of cluster
            ISIs, violations_1ms, violations_2ms = get_ISI_and_violations(spike_times[idx], fs, spike_map[idx])
            cluster_waves = waveforms[idx]
            isi_fn = os.path.join(wave_plot_dir, 'Cluster%i_ISI.png' % c)
            wave_fn = os.path.join(wave_plot_dir, 'Cluster%i_waveforms.png' % c)
            title_str = ('Cluster%i\nviolations_1ms = %i, '
                         'violations_2ms = %i\n'
                         'Number of waveforms = %i' %
                         (c, violations_1ms, violations_2ms, len(idx)))
            dplt.plot_waveforms(cluster_waves, title=title_str, save_file=wave_fn)
            if len(ISIs) > 0:
                dplt.plot_ISIs(ISIs, total_spikes=len(idx), save_file=isi_fn)

        clust_results.loc[n_clust] = [n_clust, True, bic, spikes_per_clust]

        # Plot feature pairs
        feature_pairs = it.combinations(list(range(data.shape[1])), 2)
        for f1, f2 in feature_pairs:
            fn = '%sVS%s.png' % (data_columns[f1], data_columns[f2])
            fn = os.path.join(plot_dir, fn)
            dplt.plot_cluster_features(data[:, [f1,f2]], predictions,
                                       x_label = data_columns[f1],
                                       y_label = data_columns[f2],
                                       save_file = fn)

        # For each cluster plot mahanalobis distances to all other clusters
        for c in range(n_clust):
            distances = get_mahalanobis_distances_to_cluster(data,  model,
                                                             predictions, c)
            fn = os.path.join(plot_dir, 'Mahalanobis_cluster%i.png' % c)
            title = ('Mahalanobis distance of Cluster %i from all other clusters' % c)
            dplt.plot_mahalanobis_to_cluster(distances, title=title, save_file=fn)

        # Save data
        np.save(bic_file, bic)
        np.save(pred_file, predictions)

    # Save results table
    self.results = clust_results
    wt.write_pandas_to_table(clust_results,
                             self._files['clustering_results'],
                             overwrite=True)
    self.clustered = True
    return True
class ClusterGMM (n_iters, n_restarts, thresh)
Expand source code
class ClusterGMM(object):
    '''Fit Gaussian mixture models with several random restarts and keep the
    converged model with the lowest BIC.

    Parameters
    ----------
    n_iters : int
        maximum number of EM iterations per fit (``max_iter``)
    n_restarts : int
        number of fits to attempt; restart i uses ``random_state=i``
    thresh : float
        convergence tolerance (``tol``)
    '''
    def __init__(self, n_iters, n_restarts, thresh):
        self.params = {'iterations': n_iters,
                       'restarts': n_restarts,
                       'thresh': thresh}

    def fit(self, data, n_clusters):
        '''Fit full-covariance GMMs to data and keep the best converged one.

        Parameters
        ----------
        data : array-like, row for each sample
        n_clusters : int or None
            number of components; None reuses the previously set value

        Returns
        -------
        model, predictions, bic
            best converged GaussianMixture, its label for every row of data
            and its BIC. All three are None if no restart converged.
        '''
        if n_clusters is not None:
            self.params['clusters'] = n_clusters

        min_bic = None
        best_model = None
        for i in range(self.params['restarts']):
            model = GaussianMixture(n_components = self.params['clusters'],
                                    covariance_type = 'full',
                                    tol = self.params['thresh'],
                                    random_state = i,
                                    max_iter = self.params['iterations'])
            model.fit(data)
            if not model.converged_:
                continue

            new_bic = model.bic(data)
            if min_bic is None or new_bic < min_bic:
                best_model = model
                min_bic = new_bic

        # Report failure as None rather than calling predict on None
        predictions = None if best_model is None else best_model.predict(data)
        self._model = best_model
        self._predictions = predictions
        self._bic = min_bic
        return best_model, predictions, min_bic

Methods

def fit(self, data, n_clusters)
Expand source code
def fit(self, data, n_clusters):
    '''Fit full-covariance GMMs to data and keep the best converged one.

    Parameters
    ----------
    data : array-like, row for each sample
    n_clusters : int or None
        number of components; None reuses the previously set value

    Returns
    -------
    model, predictions, bic
        best converged GaussianMixture, its label for every row of data
        and its BIC. All three are None if no restart converged.
    '''
    if n_clusters is not None:
        self.params['clusters'] = n_clusters

    min_bic = None
    best_model = None
    for i in range(self.params['restarts']):
        model = GaussianMixture(n_components = self.params['clusters'],
                                covariance_type = 'full',
                                tol = self.params['thresh'],
                                random_state = i,
                                max_iter = self.params['iterations'])
        model.fit(data)
        if not model.converged_:
            continue

        new_bic = model.bic(data)
        if min_bic is None or new_bic < min_bic:
            best_model = model
            min_bic = new_bic

    # Report failure as None rather than calling predict on None
    predictions = None if best_model is None else best_model.predict(data)
    self._model = best_model
    self._predictions = predictions
    self._bic = min_bic
    return best_model, predictions, min_bic
class SpikeCluster (name, electrode, solution, cluster, cluster_id, waves, times, spike_map, rec_key, fs={0: 30000}, offsets={0: 0}, manipulations='')

dict() -> new empty dictionary dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs dict(iterable) -> new dictionary initialized as if via: d = {} for k, v in iterable: d[k] = v dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)

Expand source code
class SpikeCluster(dict):
    '''Dictionary describing one sorted cluster of spikes.

    Keys: Cluster_Name, electrode_num, solution_num, cluster_num, cluster_id,
    spike_waveforms, spike_times, spike_map, rec_key, fs, offsets and
    manipulations. ``spike_map`` gives the recording number of each spike;
    ``rec_key``, ``fs`` and ``offsets`` are dicts keyed by recording number.
    ``fs`` defaults to {0: 30000} and ``offsets`` to {0: 0}.

    Raises
    ------
    ValueError
        if rec_key, fs or offsets lack an entry for a recording in spike_map,
        or if waves, times and spike_map differ in length
    '''
    def __init__(self, name, electrode, solution, cluster, cluster_id, waves, times,
                 spike_map, rec_key, fs=None, offsets=None, manipulations=''):
        # Build default dicts per instance: dicts in the signature would be a
        # single shared object stored in every SpikeCluster using them.
        if fs is None:
            fs = {0: 30000}
        if offsets is None:
            offsets = {0: 0}

        # Confirm spike_map, rec_key, fs and offsets are all in sync
        rec_nums = np.unique(spike_map)
        if (not all([x in rec_key.keys() for x in rec_nums]) or
            not all([x in fs.keys() for x in rec_nums]) or
            not all([x in offsets.keys() for x in rec_nums])):
            raise ValueError('rec_key, fs and offsets must have entries for '
                             'each unique element of spike_map')

        # Confirm same number of waves, times and map entries
        if waves.shape[0] != len(times) or len(times) != len(spike_map):
            raise ValueError('Must have same number of waves, times and map entries')

        super(SpikeCluster, self).__init__(Cluster_Name=name,
                                           electrode_num=electrode,
                                           solution_num=solution,
                                           cluster_num = cluster,
                                           cluster_id=cluster_id,
                                           spike_waveforms=waves,
                                           spike_times=times,
                                           spike_map=spike_map,
                                           rec_key=rec_key,
                                           fs=fs,
                                           offsets=offsets,
                                           manipulations=manipulations)

    def delete_spikes(self, idx, msg=None):
        '''Remove the spikes at idx from waveforms, times and spike_map.

        Parameters
        ----------
        idx : int or sequence of int
            indices of the spikes to delete
        msg : str, optional
            if given, appended to 'manipulations' together with the count of
            removed spikes
        '''
        # np.size also handles a scalar index, for which len() would fail
        n_removed = np.size(idx)
        self['spike_waveforms'] = np.delete(self['spike_waveforms'],
                                               idx, axis=0)
        self['spike_times'] = np.delete(self['spike_times'], idx)
        self['spike_map'] = np.delete(self['spike_map'], idx)
        print('deleted %i spikes.' % n_removed)
        if msg is not None:
            self['manipulations'] += '\n' + msg + '\n-Removed %i spikes' % n_removed

    def get_spike_time_vector(self, units='samples'):
        '''Return vector of all spike times with offsets added if multiple
        recordings are present

        Parameters
        ----------
        units : {'samples' (default), 'ms', 's'}, units for spike times returned
            (case-insensitive)

        Returns
        -------
        numpy.ndarray

        Raises
        ------
        ValueError
            if units is not one of the accepted values
        '''
        unit = units.lower()
        if unit == 'ms':
            times = np.array([(a + self['offsets'][b]) / (self['fs'][b] / 1000)
                              for a, b in
                              zip(self['spike_times'].astype('float64'), self['spike_map'])])
        elif unit == 's':
            times = np.array([(a + self['offsets'][b]) / self['fs'][b]
                              for a, b in
                              zip(self['spike_times'].astype('float64'), self['spike_map'])])
        elif unit == 'samples':
            times = np.array([(a + self['offsets'][b]) for a, b in
                              zip(self['spike_times'], self['spike_map'])])
        else:
            raise ValueError('units must be one of samples, ms or s')

        return times

    def __eq__(self, other):
        '''Clusters are equal when they contain the same offset spike times.'''
        if not isinstance(other, SpikeCluster):
            return NotImplemented

        times1 = np.sort(self.get_spike_time_vector(units='samples'))
        times2 = np.sort(other.get_spike_time_vector(units='samples'))
        return np.array_equal(times1, times2)

    def get_mean_waveform(self):
        '''Returns mean waveform of cluster. Also
        returns St. Dev. of waveforms and number of waveforms

        Returns
        -------
        np.ndarray, np.ndarray, int
        mean waveform, st. dev of waveforms, number of waveforms 
        '''
        mean_wave = np.mean(self['spike_waveforms'], axis=0)
        std_wave = np.std(self['spike_waveforms'], axis=0)
        n_waves = self['spike_waveforms'].shape[0]
        return mean_wave, std_wave, n_waves

    # TODO: Finish this section
    def _dist(self, other=None):
        if other is None:
            other = self

        pass

    def _add(self, other):
        pass

    def _subtract(self, other):
        pass

    def _divide(self, N, method='pca'):
        pass

    def _is_subcluster(self, other):
        '''less than?'''
        pass

    def _is_supercluster(self, other):
        ''' greater than?'''
        pass

Ancestors

  • builtins.dict

Methods

def delete_spikes(self, idx, msg=None)
Expand source code
def delete_spikes(self, idx, msg=None):
    '''Remove the spikes at idx from waveforms, times and spike_map.

    Parameters
    ----------
    idx : int or sequence of int
        indices of the spikes to delete
    msg : str, optional
        if given, appended to 'manipulations' together with the count of
        removed spikes
    '''
    # np.size also handles a scalar index, for which len() would fail
    n_removed = np.size(idx)
    self['spike_waveforms'] = np.delete(self['spike_waveforms'],
                                           idx, axis=0)
    self['spike_times'] = np.delete(self['spike_times'], idx)
    self['spike_map'] = np.delete(self['spike_map'], idx)
    print('deleted %i spikes.' % n_removed)
    if msg is not None:
        self['manipulations'] += '\n' + msg + '\n-Removed %i spikes' % n_removed
def get_mean_waveform(self)

Returns mean waveform of cluster. Also returns St. Dev. of waveforms and number of waveforms

Returns

np.ndarray, np.ndarray, int
 
mean waveform, st. dev of waveforms, number of waveforms
 
Expand source code
def get_mean_waveform(self):
    '''Summarise the cluster's waveforms.

    Returns
    -------
    np.ndarray, np.ndarray, int
        per-sample mean of the waveforms, per-sample standard deviation of
        the waveforms, and the number of waveforms (rows)
    '''
    all_waves = self['spike_waveforms']
    return (np.mean(all_waves, axis=0),
            np.std(all_waves, axis=0),
            all_waves.shape[0])
def get_spike_time_vector(self, units='samples')

Return vector of all spike times with offsets added if multiple recordings are present

Parameters

units : {'samples' (default), 'ms', 's'}, units for spike times returned
 

Returns

numpy.ndarray
 
Expand source code
def get_spike_time_vector(self, units='samples'):
    '''Return vector of all spike times with offsets added if multiple
    recordings are present

    Parameters
    ----------
    units : {'samples' (default), 'ms', 's'}, units for spike times returned
        (case-insensitive)

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        if units is not one of the accepted values
    '''
    unit = units.lower()
    if unit == 'ms':
        times = np.array([(a + self['offsets'][b]) / (self['fs'][b] / 1000)
                          for a, b in
                          zip(self['spike_times'].astype('float64'), self['spike_map'])])
    elif unit == 's':
        times = np.array([(a + self['offsets'][b]) / self['fs'][b]
                          for a, b in
                          zip(self['spike_times'].astype('float64'), self['spike_map'])])
    elif unit == 'samples':
        times = np.array([(a + self['offsets'][b]) for a, b in
                          zip(self['spike_times'], self['spike_map'])])
    else:
        raise ValueError('units must be one of samples, ms or s')

    return times
class SpikeDetection (file_dir, electrode, params=None, overwrite=False)

Interface to manage spike detection and data extraction in preparation for GMM clustering. Intended to help create and access the necessary files. Files that already exist are detected and not re-created unless overwrite is specified as True.

Expand source code
class SpikeDetection(object):
    '''Interface to manage spike detection and data extraction in preparation
    for GMM clustering. Intended to help create and access the necessary
    files. Files that already exist are detected and not re-created unless
    overwrite is specified as True.

    Parameters
    ----------
    file_dir : str
        recording directory; outputs go to
        ``file_dir/spike_detection/electrode_<electrode>``
    electrode : int
        electrode number
    params : dict, optional
        dataset clustering parameters with 'data_params', 'bandpass_params',
        'spike_snapshot' and 'sampling_rate' entries. Required when
        ``file_dir/analysis_params/spike_detection_params.json`` does not
        exist; when given it takes precedence and is written to that file.
    overwrite : bool
        delete this electrode's existing spike detection output first
    '''

    def __init__(self, file_dir, electrode, params=None, overwrite=False):
        # Setup paths to files and directories needed
        self._file_dir = file_dir
        self._electrode = electrode
        self._out_dir = os.path.join(file_dir, 'spike_detection',
                                     'electrode_%i' % electrode)
        self._data_dir = os.path.join(self._out_dir, 'data')
        self._plot_dir = os.path.join(self._out_dir, 'plots')
        self._files = {'params': os.path.join(file_dir,'analysis_params', 'spike_detection_params.json'),
                       'spike_waveforms': os.path.join(self._data_dir, 'spike_waveforms.npy'),
                       'spike_times' : os.path.join(self._data_dir, 'spike_times.npy'),
                       'energy' : os.path.join(self._data_dir, 'energy.npy'),
                       'spike_amplitudes' : os.path.join(self._data_dir, 'spike_amplitudes.npy'),
                       'pca_waveforms' : os.path.join(self._data_dir, 'pca_waveforms.npy'),
                       'slopes' : os.path.join(self._data_dir, 'spike_slopes.npy'),
                       'recording_cutoff' : os.path.join(self._data_dir, 'cutoff_time.txt'),
                       'detection_threshold' : os.path.join(self._data_dir, 'detection_threshold.txt')}

        self._status = dict.fromkeys(self._files.keys(), False)
        self._referenced = True

        # Delete existing data if overwrite is True
        if overwrite and os.path.isdir(self._out_dir):
            shutil.rmtree(self._out_dir)

        # See what data already exists
        self._check_existing_files()

        # Make directories if needed
        if not os.path.isdir(self._out_dir):
            os.makedirs(self._out_dir)

        if not os.path.isdir(self._data_dir):
            os.makedirs(self._data_dir)

        if not os.path.isdir(self._plot_dir):
            os.makedirs(self._plot_dir)

        if not os.path.isdir(os.path.join(file_dir, 'analysis_params')):
            os.makedirs(os.path.join(file_dir, 'analysis_params'))

        # grab recording cutoff time if it already exists
        # cutoff should be in seconds
        self.recording_cutoff = None
        if os.path.isfile(self._files['recording_cutoff']):
            self._status['recording_cutoff'] = True
            with open(self._files['recording_cutoff'], 'r') as f:
                self.recording_cutoff = float(f.read())

        self.detection_threshold = None
        if os.path.isfile(self._files['detection_threshold']):
            self._status['detection_threshold'] = True
            with open(self._files['detection_threshold'], 'r') as f:
                self.detection_threshold = float(f.read())

        # Read in parameters
        # Parameters passed as an argument will overshadow parameters saved in file
        # Input parameters should be formatted as dataset.clustering_parameters
        if params is None and os.path.isfile(self._files['params']):
            self.params = wt.read_dict_from_json(self._files['params'])
        elif params is None:
            raise FileNotFoundError('params must be provided if spike_detection_params.json does not exist.')
        else:
            self.params = {}
            self.params['voltage_cutoff'] = params['data_params']['V_cutoff for disconnected headstage']
            self.params['max_breach_rate'] = params['data_params']['Max rate of cutoff breach per second']
            self.params['max_secs_above_cutoff'] = params['data_params']['Max allowed seconds with a breach']
            self.params['max_mean_breach_rate_persec'] = params['data_params']['Max allowed breaches per second']
            band_lower = params['bandpass_params']['Lower freq cutoff']
            band_upper = params['bandpass_params']['Upper freq cutoff']
            self.params['bandpass'] = [band_lower, band_upper]
            snapshot_pre = params['spike_snapshot']['Time before spike (ms)']
            snapshot_post = params['spike_snapshot']['Time after spike (ms)']
            self.params['spike_snapshot'] = [snapshot_pre, snapshot_post]
            self.params['sampling_rate'] = params['sampling_rate']
            # Write params to json file
            wt.write_dict_to_json(self.params, self._files['params'])
            self._status['params'] = True

    def _check_existing_files(self):
        '''Checks which files already exist and updates _status so as to avoid
        re-creation later
        '''
        for k, v in self._files.items():
            if os.path.isfile(v):
                self._status[k] = True
            else:
                self._status[k] = False

    def run(self):
        '''Detect spikes on the electrode and compute clustering features.

        Steps whose output files already exist are skipped. Saves the
        recording cutoff, detection threshold, spike waveforms and times,
        amplitudes, slopes, energy and the PCA of the scaled waveforms.

        Returns
        -------
        tuple
            (electrode, flag, recording_cutoff) where flag is 1 on success and
            0 if the recording cutoff is under 60 s or no spikes were detected
        '''
        status = self._status
        file_dir = self._file_dir
        electrode = self._electrode
        params = self.params
        fs = params['sampling_rate']

        # Check if this even needs to be run
        if all(status.values()):
            return electrode, 1, self.recording_cutoff

        # Grab referenced electrode or raw if ref is not available
        ref_el = h5io.get_referenced_trace(file_dir, electrode)
        if ref_el is None:
            print('Could not find referenced data for electrode %i. Using raw.' % electrode)
            self._referenced = False
            ref_el = h5io.get_raw_trace(file_dir, electrode)
            if ref_el is None:
                raise KeyError('Neither referenced nor raw data found for electrode %i in %s' % (electrode, file_dir))

        # Filter electrode trace
        filt_el = clustering.get_filtered_electrode(ref_el, freq=params['bandpass'],
                                               sampling_rate = fs)
        del ref_el
        # Get recording cutoff
        if not status['recording_cutoff']:
            self.recording_cutoff = get_recording_cutoff(filt_el, **params)
            with open(self._files['recording_cutoff'], 'w') as f:
                f.write(str(self.recording_cutoff))

            status['recording_cutoff'] = True
            fn = os.path.join(self._plot_dir, 'cutoff_time.png')
            dplt.plot_recording_cutoff(filt_el, fs, self.recording_cutoff,
                                       out_file=fn)

        # Truncate electrode trace, deal with early cutoff (<60s)
        if self.recording_cutoff < 60:
            print('Immediate Cutoff for electrode %i...exiting' % electrode)
            return electrode, 0, self.recording_cutoff

        filt_el = filt_el[:int(self.recording_cutoff*fs)]

        if status['spike_waveforms'] and status['spike_times']:
            waves = np.load(self._files['spike_waveforms'])
            times = np.load(self._files['spike_times'])
        else:
            # Detect spikes and get dejittered times and waveforms
            # detect_spikes returns waveforms upsampled by 10x and times in units
            # of samples
            waves, times, threshold = detect_spikes(filt_el, params['spike_snapshot'], fs)
            self.detection_threshold = threshold
            if waves is None:
                print('No waveforms detected on electrode %i' % electrode)
                return electrode, 0, self.recording_cutoff

            # Save waveforms and times
            np.save(self._files['spike_waveforms'], waves)
            np.save(self._files['spike_times'], times)
            with open(self._files['detection_threshold'], 'w') as f:
                f.write(str(threshold))

            status['detection_threshold'] = True
            status['spike_waveforms'] = True
            status['spike_times'] = True

        # Get various metrics and scale waveforms
        if not status['spike_amplitudes']:
            amplitudes = get_waveform_amplitudes(waves)
            np.save(self._files['spike_amplitudes'], amplitudes)
            status['spike_amplitudes'] = True

        if not status['slopes']:
            slopes = get_spike_slopes(waves)
            np.save(self._files['slopes'], slopes)
            status['slopes'] = True

        if not status['energy']:
            energy = get_waveform_energy(waves)
            np.save(self._files['energy'], energy)
            status['energy'] = True
        else:
            energy=None

        # get pca of scaled waveforms
        if not status['pca_waveforms']:
            scaled_waves = scale_waveforms(waves, energy=energy)
            pca_waves, explained_variance_ratio = implement_pca(scaled_waves)
            # Persist the PCA so get_pca_waveforms can read it and later runs
            # see this step as done
            np.save(self._files['pca_waveforms'], pca_waves)
            status['pca_waveforms'] = True

            # Plot explained variance
            fn = os.path.join(self._plot_dir, 'pca_variance.png')
            dplt.plot_explained_pca_variance(explained_variance_ratio,
                                             out_file = fn)

        return electrode, 1, self.recording_cutoff

    def get_spike_waveforms(self):
        '''Returns spike waveforms if they have been extracted, None otherwise
        Dejittered waveforms upsampled to 10 x sampling_rate

        Returns
        -------
        numpy.array
        '''
        if os.path.isfile(self._files['spike_waveforms']):
            return np.load(self._files['spike_waveforms'])
        else:
            return None

    def get_spike_times(self):
        '''Returns spike times if they have been extracted, None otherwise
        In units of samples.

        Returns
        -------
        numpy.array
        '''
        if os.path.isfile(self._files['spike_times']):
            return np.load(self._files['spike_times'])
        else:
            return None

    def get_energy(self):
        '''Returns spike energies if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        if os.path.isfile(self._files['energy']):
            return np.load(self._files['energy'])
        else:
            return None

    def get_spike_amplitudes(self):
        '''Returns spike amplitudes if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        if os.path.isfile(self._files['spike_amplitudes']):
            return np.load(self._files['spike_amplitudes'])
        else:
            return None

    def get_spike_slopes(self):
        '''Returns spike slopes if they have been extracted, None otherwise

        Returns
        -------
        numpy.array
        '''
        if os.path.isfile(self._files['slopes']):
            return np.load(self._files['slopes'])
        else:
            return None

    def get_pca_waveforms(self):
        '''Returns PCA of scaled spike waveforms if they have been extracted,
        None otherwise
        Dejittered waveforms upsampled to 10 x sampling_rate, scaled to energy
        and transformed via PCA

        Returns
        -------
        numpy.array
        '''
        # Read the PCA output file, not the raw waveforms file
        if os.path.isfile(self._files['pca_waveforms']):
            return np.load(self._files['pca_waveforms'])
        else:
            return None

    def get_clustering_metrics(self, n_pc=3):
        '''Returns array of metrics to use for feature based clustering
        Row for each waveform with columns:
            - amplitude, energy, spike slope, PC1, PC2, etc
        '''
        amplitude = self.get_spike_amplitudes()
        energy = self.get_energy()
        slopes = self.get_spike_slopes()
        pca_waves = self.get_pca_waveforms()
        out = np.vstack((amplitude, energy, slopes)).T
        out = np.hstack((out, pca_waves[:,:n_pc]))
        return out

    def __str__(self):
        out = []
        out.append('SpikeDetection\n--------------')
        out.append('Recording Directory: %s' % self._file_dir)
        out.append('Electrode: %i' % self._electrode)
        out.append('Output Directory: %s' % self._out_dir)
        out.append('###################################\n')
        out.append('Status:')
        out.append(pt.print_dict(self._status))
        out.append('-------------------\n')
        out.append('Parameters:')
        out.append(pt.print_dict(self.params))
        out.append('-------------------\n')
        out.append('Data files:')
        out.append(pt.print_dict(self._files))
        return '\n'.join(out)

Methods

def get_clustering_metrics(self, n_pc=3)

Returns array of metrics to use for feature based clustering Row for each waveform with columns: - amplitude, energy, spike slope, PC1, PC2, etc

Expand source code
def get_clustering_metrics(self, n_pc=3):
    '''Build the feature matrix used for feature-based clustering.

    Parameters
    ----------
    n_pc : int
        number of leading PCA columns to append

    Returns
    -------
    numpy.ndarray
        row for each waveform; columns are amplitude, energy, spike slope,
        then the first n_pc PCA components
    '''
    scalar_features = np.vstack((self.get_spike_amplitudes(),
                                 self.get_energy(),
                                 self.get_spike_slopes())).T
    leading_pcs = self.get_pca_waveforms()[:, :n_pc]
    return np.hstack((scalar_features, leading_pcs))
def get_energy(self)

Returns spike energies if they have been extracted, None otherwise

Returns

numpy.array
 
Expand source code
def get_energy(self):
    '''Load the saved per-spike energies.

    Returns
    -------
    numpy.array or None
        contents of the 'energy' file, or None if it does not exist yet
    '''
    energy_path = self._files['energy']
    return np.load(energy_path) if os.path.isfile(energy_path) else None
def get_pca_waveforms(self)

Returns PCA of scaled spike waveforms if they have been extracted, None otherwise. Dejittered waveforms upsampled to 10 x sampling_rate, scaled to energy and transformed via PCA

Returns

numpy.array
 
Expand source code
def get_pca_waveforms(self):
    '''Returns PCA of scaled spike waveforms if they have been extracted,
    None otherwise
    Dejittered waveforms upsampled to 10 x sampling_rate, scaled to energy
    and transformed via PCA

    Returns
    -------
    numpy.array
    '''
    # Read the PCA output file, not the raw waveforms file
    if os.path.isfile(self._files['pca_waveforms']):
        return np.load(self._files['pca_waveforms'])
    else:
        return None
def get_spike_amplitudes(self)

Returns spike amplitudes if they have been extracted, None otherwise

Returns

numpy.array
 
Expand source code
def get_spike_amplitudes(self):
    '''Load the saved per-spike amplitudes.

    Returns
    -------
    numpy.array or None
        contents of the 'spike_amplitudes' file, or None if it does not exist
    '''
    amp_path = self._files['spike_amplitudes']
    if not os.path.isfile(amp_path):
        return None
    return np.load(amp_path)
def get_spike_slopes(self)

Returns spike slopes if they have been extracted, None otherwise

Returns

numpy.array
 
Expand source code
def get_spike_slopes(self):
    '''Load the saved per-spike slopes.

    Returns
    -------
    numpy.array or None
        contents of the 'slopes' file, or None if it does not exist yet
    '''
    slope_path = self._files['slopes']
    have_slopes = os.path.isfile(slope_path)
    return np.load(slope_path) if have_slopes else None
def get_spike_times(self)

Returns spike times if they have been extracted, None otherwise In units of samples.

Returns

numpy.array
 
Expand source code
def get_spike_times(self):
    '''Load the saved spike times (in samples).

    Returns
    -------
    numpy.array or None
        contents of the 'spike_times' file, or None if it does not exist yet
    '''
    times_path = self._files['spike_times']
    if os.path.isfile(times_path):
        loaded = np.load(times_path)
    else:
        loaded = None
    return loaded
def get_spike_waveforms(self)

Returns spike waveforms if they have been extracted, None otherwise Dejittered waveforms upsampled to 10 x sampling_rate

Returns

numpy.array
 
Expand source code
def get_spike_waveforms(self):
    '''Load the saved spike waveforms (dejittered, upsampled to
    10 x sampling_rate).

    Returns
    -------
    numpy.array or None
        contents of the 'spike_waveforms' file, or None if it does not exist
    '''
    wave_path = self._files['spike_waveforms']
    if not os.path.isfile(wave_path):
        return None
    return np.load(wave_path)
def run(self)
Expand source code
def run(self):
    '''Run (or resume) spike detection and feature extraction for this electrode.

    Each stage is skipped when its flag in ``self._status`` is already True,
    so a partially completed run resumes where it left off. The stages are:
    band-pass filter the referenced (or, failing that, raw) trace; find the
    recording cutoff; detect spikes; then compute amplitudes, slopes, energy
    and a PCA of the scaled waveforms. Results are written to the paths in
    ``self._files``, and plots are written to ``self._plot_dir``.

    Returns
    -------
    tuple
        (electrode, flag, recording_cutoff). flag is 1 when detection
        completed or had already been done. flag is 0 when the electrode was
        abandoned, either because the recording cutoff is under 60 s or
        because no waveforms were detected. recording_cutoff is in seconds.

    Raises
    ------
    KeyError
        if neither referenced nor raw data exist for the electrode
    '''
    # status is this object's own flag dict, so flags set below persist on self
    status = self._status
    file_dir = self._file_dir
    electrode = self._electrode
    params = self.params
    fs = params['sampling_rate']

    # Check if this even needs to be run
    if all(status.values()):
        return electrode, 1, self.recording_cutoff

    # Grab referenced electrode or raw if ref is not available
    ref_el = h5io.get_referenced_trace(file_dir, electrode)
    if ref_el is None:
        print('Could not find referenced data for electrode %i. Using raw.' % electrode)
        self._referenced = False
        ref_el = h5io.get_raw_trace(file_dir, electrode)
        if ref_el is None:
            raise KeyError('Neither referenced nor raw data found for electrode %i in %s' % (electrode, file_dir))

    # Filter electrode trace
    filt_el = clustering.get_filtered_electrode(ref_el, freq=params['bandpass'],
                                           sampling_rate = fs)
    # Free the unfiltered trace; only filt_el is needed from here on
    del ref_el
    # Get recording cutoff
    if not status['recording_cutoff']:
        self.recording_cutoff = get_recording_cutoff(filt_el, **params)
        with open(self._files['recording_cutoff'], 'w') as f:
            f.write(str(self.recording_cutoff))

        status['recording_cutoff'] = True
        fn = os.path.join(self._plot_dir, 'cutoff_time.png')
        dplt.plot_recording_cutoff(filt_el, fs, self.recording_cutoff,
                                   out_file=fn)

    # Truncate electrode trace, deal with early cutoff (<60s)
    if self.recording_cutoff < 60:
        print('Immediate Cutoff for electrode %i...exiting' % electrode)
        return electrode, 0, self.recording_cutoff

    # recording_cutoff is in seconds; multiply by fs to get a sample index
    filt_el = filt_el[:int(self.recording_cutoff*fs)]

    if status['spike_waveforms'] and status['spike_times']:
        waves = np.load(self._files['spike_waveforms'])
        times = np.load(self._files['spike_times'])
    else:
        # Detect spikes and get dejittered times and waveforms
        # detect_spikes returns waveforms upsampled by 10x and times in units
        # of samples
        waves, times, threshold = detect_spikes(filt_el, params['spike_snapshot'], fs)
        self.detection_threshold = threshold
        # NOTE(review): assumes detect_spikes signals "no spikes" with
        # waves=None; its tail is not visible here, so confirm it does not
        # return an empty array instead
        if waves is None:
            print('No waveforms detected on electrode %i' % electrode)
            return electrode, 0, self.recording_cutoff

        # Save waveforms and times
        np.save(self._files['spike_waveforms'], waves)
        np.save(self._files['spike_times'], times)
        with open(self._files['detection_threshold'], 'w') as f:
            f.write(str(threshold))

        status['detection_threshold'] = True
        status['spike_waveforms'] = True
        status['spike_times'] = True

    # Get various metrics and scale waveforms
    if not status['spike_amplitudes']:
        amplitudes = get_waveform_amplitudes(waves)
        np.save(self._files['spike_amplitudes'], amplitudes)
        status['spike_amplitudes'] = True

    if not status['slopes']:
        slopes = get_spike_slopes(waves)
        np.save(self._files['slopes'], slopes)
        status['slopes'] = True

    if not status['energy']:
        energy = get_waveform_energy(waves)
        np.save(self._files['energy'], energy)
        status['energy'] = True
    else:
        # Saved energy is not reloaded; scale_waveforms receives None and
        # presumably recomputes it (TODO confirm against scale_waveforms)
        energy=None

    # get pca of scaled waveforms
    if not status['pca_waveforms']:
        scaled_waves = scale_waveforms(waves, energy=energy)
        pca_waves, explained_variance_ratio = implement_pca(scaled_waves)

        # Plot explained variance
        fn = os.path.join(self._plot_dir, 'pca_variance.png')
        dplt.plot_explained_pca_variance(explained_variance_ratio,
                                         out_file = fn)
        # NOTE(review): pca_waves is neither saved nor is
        # status['pca_waveforms'] set to True, so this stage (and the early
        # all(status.values()) exit above) never completes. Confirm intent.

    return electrode, 1, self.recording_cutoff
class SpikeSorter (rec_dirs, electrode, clustering_dir=None, shell=False)
Expand source code
class SpikeSorter(object):
    '''Interactive editor for the clusters of one electrode after clustering.

    Loads a finished BlechClust clustering of ``electrode`` across
    ``rec_dirs``. It keeps a working list of "active" clusters, populated by
    ``set_active_clusters``, that can be split, merged, discarded, plotted
    and finally saved as units into the recordings' h5 files. Split, merge,
    discard and save record enough state for a single-step ``undo``.

    Parameters
    ----------
    rec_dirs : str or list of str
        recording directories that were clustered together
    electrode : int
        electrode number
    clustering_dir : str, optional
        directory holding the clustering output. For one recording the
        default is ``<rec_dir>/BlechClust/electrode_<electrode>``. For
        several recordings it is ``BlechClust/electrode_<electrode>`` inside
        the parent directory of the first recording.
    shell : bool, optional
        forwarded as ``shell`` to userIO selection prompts

    Raises
    ------
    ValueError
        if the recordings have not been clustered yet
    '''

    def __init__(self, rec_dirs, electrode, clustering_dir=None, shell=False):
        if isinstance(rec_dirs, str):
            rec_dirs = [rec_dirs]

        # Drop trailing separators so os.path.dirname returns the parent dir
        rec_dirs = [x[:-1] if x.endswith(os.sep) else x for x in rec_dirs]
        self.rec_dirs = rec_dirs
        self.electrode = electrode
        if clustering_dir is None:
            if len(rec_dirs) > 1:
                top = os.path.dirname(rec_dirs[0])
            else:
                top = rec_dirs[0]

            clustering_dir = os.path.join(top, 'BlechClust', 'electrode_%i' % electrode)

        self.clustering_dir = clustering_dir
        try:
            clust = BlechClust(rec_dirs, electrode, out_dir=clustering_dir,
                               no_write=True)
        except FileNotFoundError:
            clust = None

        if clust is None or not clust.clustered:
            raise ValueError('Recordings have not been clustered yet.')

        # Match recording directory ordering to clustering
        self.rec_dirs = clust.rec_dirs
        self.clustering = clust
        self._current_solution = None
        self._active = None
        self._last_saved = None  # Dict of rec index -> list of saved unit names
        self._previous = None
        self._shell = shell
        self._split_results = None
        self._split_starter = None
        self._split_index = None
        self._last_umap_embedding = None
        self._last_action = None
        self._last_popped = None  # Dict of indices to clusters
        self._last_added = None  # List of indices

        # Read thresholds in the clustering's recording order (self.rec_dirs)
        # so index i refers to the same recording as everywhere else
        self._detection_thresholds = [SpikeDetection(rd, electrode).detection_threshold
                                      for rd in self.rec_dirs]

    def undo(self):
        '''Revert the most recent split, merge, discard or save.

        Only one level of undo is kept; calling again after a successful undo
        does nothing.
        '''
        if self._last_action is None:
            return

        if self._last_action == 'save':
            self.undo_last_save()
            return

        # Remove clusters added by the last action, highest index first so
        # the remaining indices stay valid
        for k in sorted(self._last_added, reverse=True):
            self._active.pop(k)

        # Re-insert removed clusters in ascending index order so each lands
        # back at its original position
        for k in sorted(self._last_popped):
            self._active.insert(k, self._last_popped[k])

        # reset
        self._last_action = None
        self._last_popped = None
        self._last_added = None

    def set_active_clusters(self, solution_num):
        '''Load the clusters of one clustering solution as the active set.

        Parameters
        ----------
        solution_num : int
            solution to load; clusters 0 .. solution_num-1 are requested

        Raises
        ------
        ValueError
            if no clusters are returned for that solution
        '''
        self._current_solution = solution_num
        cluster_nums = list(range(solution_num))
        clusters = self.clustering.get_clusters(solution_num, cluster_nums)
        if len(clusters) == 0:
            raise ValueError('Solution or clusters not found')

        self._active = clusters
        self._last_action = None
        self._last_popped = None
        self._last_added = None

    def save_clusters(self, target_clusters, single_unit, pyramidal, interneuron):
        '''Saves active clusters as cells, write them to the h5_files in the
        appropriate recording directories

        Each target cluster is split by its ``spike_map``. Every recording
        that contains spikes of the cluster gets a new unit via
        ``h5io.add_new_unit``, plus a ``cluster.info`` log under
        ``<rec>/sorted_unit_metrics/<unit_name>``. Saved clusters are removed
        from the active list.

        Parameters
        ----------
        target_clusters: list of int
            indicies of active clusters to save
        single_unit : list of bool
            elements in list must correspond to elements in target_clusters
        pyramidal : list of bool
        interneuron : list of bool

        Raises
        ------
        ValueError
            if a target index is out of range or the flag lists do not match
            target_clusters in length
        '''
        if self._active is None:
            return

        if any(i >= len(self._active) for i in target_clusters):
            raise ValueError('Target cluster is out of range.')

        n_clusters = len(target_clusters)
        if (len(single_unit) != n_clusters or len(pyramidal) != n_clusters or
                len(interneuron) != n_clusters):
            raise ValueError('Length of input lists must match number of '
                             'active clusters. Expected %i' % n_clusters)

        self._last_action = 'save'
        self._last_popped = {i: self._active[i] for i in target_clusters}
        self._last_added = []
        clusters = [self._active[i] for i in target_clusters]
        rec_key = self.clustering._rec_key
        self._last_saved = dict.fromkeys(rec_key.keys(), None)

        for clust, single, pyr, intr in zip(clusters, single_unit,
                                            pyramidal, interneuron):
            # Scalar metadata for the cluster.info log. Array-valued entries
            # and the per-recording dicts are left out. This does not depend
            # on the recording, so build it once per cluster.
            print_clust = {k: v for k, v in clust.items()
                           if not isinstance(v, np.ndarray)}
            print_clust.pop('rec_key')
            print_clust.pop('fs')

            for i, rec in rec_key.items():
                idx = np.where(clust['spike_map'] == i)[0]
                if len(idx) == 0:
                    continue

                waves = clust['spike_waveforms'][idx]
                times = clust['spike_times'][idx]
                unit_name = h5io.add_new_unit(rec, self.electrode, waves,
                                              times, single, pyr, intr)
                if self._last_saved[i] is None:
                    self._last_saved[i] = [unit_name]
                else:
                    self._last_saved[i].append(unit_name)

                metrics_dir = os.path.join(rec, 'sorted_unit_metrics', unit_name)
                if not os.path.isdir(metrics_dir):
                    os.makedirs(metrics_dir)

                # Write cluster info to file
                clust_info_file = os.path.join(metrics_dir, 'cluster.info')
                with open(clust_info_file, 'a+') as log:
                    print('%s sorted on %s'
                          % (unit_name,
                             dt.datetime.today().strftime('%m/%d/%y %H:%M')),
                          file=log)
                    print('Cluster info: \n----------', file=log)
                    print(pt.print_dict(print_clust), file=log)
                    print('Saved metrics to %s' % metrics_dir, file=log)
                    print('--------------\n', file=log)

        userIO.tell_user('Target clusters successfully saved to recording '
                         'directories.', shell=True)
        saved = set(target_clusters)
        self._active = [c for i, c in enumerate(self._active) if i not in saved]

    def undo_last_save(self):
        '''Delete the units written by the last ``save_clusters`` call and
        restore the saved clusters to the active list.'''
        if self._last_saved is None:
            return

        rec_key = self.clustering._rec_key
        for i, rec in rec_key.items():
            saved_units = self._last_saved.get(i)
            # save_clusters leaves None for recordings that received no spikes
            if not saved_units:
                continue

            for unit in sorted(saved_units, reverse=True):
                h5io.delete_unit(rec, unit)

        if self._last_popped:
            for k in sorted(self._last_popped):
                self._active.insert(k, self._last_popped[k])

        self._last_saved = None
        self._last_popped = None
        self._last_added = None
        self._last_action = None

    def split_cluster(self, target_clust, n_iter, n_restart, thresh, n_clust,
                      store_split=False, umap=False):
        '''splits the target active cluster using a GMM

        Parameters
        ----------
        target_clust : int
            index of the active cluster to split
        n_iter, n_restart, thresh :
            passed to ClusterGMM
        n_clust : int
            number of components requested from the GMM fit
        store_split : bool
            if True, the sub-clusters are stored for ``set_split`` and
            returned. Otherwise the user is prompted to choose which to keep.
        umap : bool
            passed to compute_waveform_metrics

        Returns
        -------
        list of SpikeCluster if store_split, else True
        '''
        if target_clust >= len(self._active):
            raise ValueError('Invalid target. Only %i active clusters' % len(self._active))

        cluster = self._active.pop(target_clust)
        self._split_starter = cluster
        self._split_index = target_clust

        try:
            GMM = ClusterGMM(n_iter, n_restart, thresh)
            waves = cluster['spike_waveforms']
            data, data_columns = compute_waveform_metrics(waves, umap=umap)
            model, predictions, bic = GMM.fit(data, n_clust)
            new_clusts = []
            for i in np.unique(predictions):
                idx = np.where(predictions == i)[0]
                edit_str = (cluster['manipulations'] + '\nSplit %s into %i '
                            'clusters. This is sub-cluster %i'
                            % (cluster['Cluster_Name'], n_clust, i))
                tmp_clust = SpikeCluster(cluster['Cluster_Name'] + '-%i' % i,
                                         cluster['electrode_num'],
                                         cluster['solution_num'],
                                         cluster['cluster_num'],
                                         cluster['cluster_id']*10+i,
                                         waves[idx],
                                         cluster['spike_times'][idx],
                                         cluster['spike_map'][idx],
                                         cluster['rec_key'].copy(),
                                         cluster['fs'].copy(),
                                         cluster['offsets'].copy(),
                                         manipulations=edit_str)
                new_clusts.append(tmp_clust)

            # Plot cluster and ask to choose which to keep
            figs = []
            for i, c in enumerate(new_clusts):
                _, viol_1ms, viol_2ms = get_ISI_and_violations(c['spike_times'], c['fs'], c['spike_map'])
                plot_title = ('Index: %i\n1ms violations: %i, 2ms violations: %i\n'
                              'Total Waveforms: %i'
                              % (i, viol_1ms, viol_2ms, len(c['spike_times'])))
                tmp_fig, _ = dplt.plot_waveforms(c['spike_waveforms'], title=plot_title)
                figs.append(tmp_fig)
                tmp_fig.show()

            f2 = dplt.plot_waveforms_pca([c['spike_waveforms'] for c in new_clusts])
            figs.append(f2)
            f2.show()
        except BaseException:
            # Put the cluster back so it isn't lost, then propagate the error
            self._active.insert(target_clust, cluster)
            self._split_starter = None
            self._split_index = None
            raise

        if store_split:
            self._split_results = new_clusts
            return new_clusts

        self._split_starter = None
        self._split_index = None
        selection_list = ['all'] + ['%i' % i for i in range(len(new_clusts))]
        prompt = 'Select split clusters to keep\nCancel to reset.'
        ans = userIO.select_from_list(prompt, selection_list,
                                      multi_select=True, shell=self._shell)
        if ans is None or 'all' in ans:
            print('Reset to before split')
            self._active.insert(target_clust, cluster)
        else:
            keepers = [new_clusts[int(i)] for i in ans]
            start_idx = len(self._active)
            self._last_added = list(range(start_idx, start_idx+len(keepers)))
            self._last_popped = {target_clust: cluster}
            self._last_action = 'split'
            self._active.extend(keepers)

        return True

    def set_split(self, choices):
        '''Resolve a split stored by ``split_cluster(..., store_split=True)``.

        Parameters
        ----------
        choices : list of int
            indices of stored sub-clusters to keep. An empty list restores the
            original cluster.
        '''
        if self._split_starter is None:
            raise ValueError('No split stored.')

        if len(choices) == 0:
            self._active.insert(self._split_index, self._split_starter)
        else:
            keepers = [self._split_results[i] for i in choices]
            start_idx = len(self._active)
            self._last_added = list(range(start_idx, start_idx+len(keepers)))
            self._last_popped = {self._split_index: self._split_starter}
            self._last_action = 'split'
            self._active.extend(keepers)

        self._split_index = None
        self._split_results = None
        self._split_starter = None

    def merge_clusters(self, target_clusters):
        '''Merge several active clusters into one new active cluster.

        The merged cluster is a deep copy of the first target with the other
        targets' spikes appended. Spikes are then sorted by recording
        (``spike_map``) and by spike time within each recording. The sources
        are removed from the active list and the merged cluster is appended.

        Parameters
        ----------
        target_clusters : list of int
            indices into the active cluster list; an empty list is a no-op

        Raises
        ------
        ValueError
            if any target index is out of range
        '''
        if len(target_clusters) == 0:
            return

        if any(i >= len(self._active) for i in target_clusters):
            raise ValueError('Target cluster is out of range.')

        first, *others = target_clusters
        sources = [self._active[c] for c in others]

        self._last_action = 'merge'
        self._last_popped = {c: self._active[c] for c in target_clusters}
        self._last_added = []

        new_clust = deepcopy(self._active[first])
        if sources:
            spike_map = np.hstack([new_clust['spike_map']] +
                                  [c['spike_map'] for c in sources])
            spike_times = np.hstack([new_clust['spike_times']] +
                                    [c['spike_times'] for c in sources])
            spike_waveforms = np.vstack([new_clust['spike_waveforms']] +
                                        [c['spike_waveforms'] for c in sources])

            # Sort by recording, then by spike time within each recording
            # (np.lexsort uses the last key as the primary key)
            order = np.lexsort((spike_times, spike_map))
            new_clust['spike_map'] = spike_map[order]
            new_clust['spike_times'] = spike_times[order]
            new_clust['spike_waveforms'] = spike_waveforms[order]

            for clust in sources:
                name = clust['Cluster_Name']
                new_clust['manipulations'] += '\nMerged with %s.' % name
                new_clust['Cluster_Name'] += '+' + name.replace('Cluster_', '')

        merged = set(target_clusters)
        self._active = [c for i, c in enumerate(self._active) if i not in merged]
        self._last_added = [len(self._active)]
        self._active.append(new_clust)

    def discard_clusters(self, target_clusters):
        '''Remove active clusters, keeping them for ``undo``.

        Parameters
        ----------
        target_clusters : int or list of int
            indices into the active cluster list
        '''
        if isinstance(target_clusters, int):
            target_clusters = [target_clusters]

        if len(target_clusters) == 0:
            return

        self._last_action = 'discard'
        self._last_popped = {i: self._active[i] for i in target_clusters}
        self._last_added = []
        dropped = set(target_clusters)
        self._active = [c for i, c in enumerate(self._active) if i not in dropped]

    def plot_clusters_waveforms(self, target_clusters):
        '''Show a waveform plot for each target active cluster, titled with its
        ISI violations and waveform count.'''
        if len(target_clusters) == 0:
            return

        for i in target_clusters:
            c = self._active[i]
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'], c['fs'], c['spike_map'])
            title = ('Index : %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (i, v1, v2, len(c['spike_waveforms'])))
            fig, ax = dplt.plot_waveforms(c['spike_waveforms'], title=title,
                                          threshold=self._detection_thresholds[0])
            fig.show()

    def split_by_rec(self, target_cluster):
        '''Split one active cluster into one cluster per recording.

        Clusters whose spikes all come from a single recording are left
        untouched. The new clusters are appended to the active list and the
        split can be reverted with ``undo``.

        Parameters
        ----------
        target_cluster : int or single-element list of int
        '''
        if isinstance(target_cluster, list):
            if len(target_cluster) != 1:
                return
            target_cluster = target_cluster[0]

        sm = self._active[target_cluster]['spike_map']
        recs = np.unique(sm)
        # Nothing to split when all spikes come from one recording
        if len(recs) <= 1:
            return

        clust = self._active.pop(target_cluster)
        st = clust['spike_times']
        sw = clust['spike_waveforms']
        keepers = []
        for i in recs:
            idx = np.where(sm == i)[0]
            new_clust = deepcopy(clust)
            new_clust['spike_times'] = st[idx]
            new_clust['spike_waveforms'] = sw[idx, :]
            new_clust['cluster_id'] = clust['cluster_id']*10 + i
            new_clust['spike_map'] = sm[idx]
            # Append to the edit history, as split_cluster does
            new_clust['manipulations'] += '\nSplit by recording'
            keepers.append(new_clust)

        start_idx = len(self._active)
        self._last_added = list(range(start_idx, start_idx+len(keepers)))
        self._last_popped = {target_cluster: clust}
        self._last_action = 'split'
        self._active.extend(keepers)

    def plot_cluster_waveforms_by_rec(self, target_cluster):
        '''Plot the waveforms of one active cluster separately per recording.'''
        if isinstance(target_cluster, list):
            if len(target_cluster) != 1:
                return
            target_cluster = target_cluster[0]

        c = self._active[target_cluster]
        sm = c['spike_map']
        for i in np.unique(sm):
            idx = np.where(sm == i)[0]
            waves = c['spike_waveforms'][idx, :]
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'][idx], c['fs'][i])
            title = ('Index : %i, Rec: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (target_cluster, i, v1, v2, len(waves)))
            fig, ax = dplt.plot_waveforms(waves, title=title)
            fig.show()

    def plot_cluster_waveforms_over_time(self, target_cluster, interval):
        '''Plot the waveforms of one active cluster in consecutive time windows.

        Parameters
        ----------
        target_cluster : int or single-element list of int
        interval : float
            window length in seconds; at most 10 windows are plotted
        '''
        if isinstance(target_cluster, list):
            if len(target_cluster) != 1:
                return
            target_cluster = target_cluster[0]

        c = self._active[target_cluster]
        spike_times = c.get_spike_time_vector('s')
        # Check before indexing: an empty cluster has no time range to bin
        if len(spike_times) == 0:
            return

        start_times = np.arange(np.min(spike_times), np.max(spike_times)+1, interval)
        if len(start_times) > 10:
            userIO.tell_user('This would open more than 10 figures, choose a larger interval')
            return

        for i, start_time in enumerate(start_times):
            idx = np.where((spike_times >= start_time) & (spike_times < start_time+interval))[0]
            waves = c['spike_waveforms'][idx, :]
            # Pass the spike map so ISIs respect recording boundaries, like
            # the other ISI computations in this class
            isi, v1, v2 = get_ISI_and_violations(c['spike_times'][idx], c['fs'],
                                                 c['spike_map'][idx])
            title = ('Index : %i, Interval: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                     '\ntotal waveforms: %i'
                     % (target_cluster, i, v1, v2, len(waves)))
            fig, ax = dplt.plot_waveforms(waves, title=title,
                                          threshold=self._detection_thresholds[0])
            fig.show()

    def plot_clusters_pca(self, target_clusters):
        '''Plot the target active clusters' waveforms in PCA space.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig = dplt.plot_waveforms_pca(waves, cluster_ids=target_clusters)
        fig.show()

    def plot_clusters_umap(self, target_clusters):
        '''Plot the target active clusters' waveforms in a UMAP embedding.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig = dplt.plot_waveforms_umap(waves, cluster_ids=target_clusters)
        fig.show()

    def plot_clusters_wavelets(self, target_clusters):
        '''Plot the target active clusters using a wavelet transform of their
        waveforms.'''
        if len(target_clusters) == 0:
            return

        waves = [self._active[i]['spike_waveforms'] for i in target_clusters]
        fig, ax = dplt.plot_waveforms_wavelet_tranform(waves,
                                                       cluster_ids=target_clusters,
                                                       n_pc=4)
        fig.show()

    def plot_clusters_raster(self, target_clusters):
        '''Raster plot of the target active clusters, with vertical lines at
        the recording offsets.'''
        if len(target_clusters) == 0:
            return

        clusters = [self._active[i] for i in target_clusters]
        spike_times = []
        spike_waves = []
        vlines = {}
        for c in clusters:
            # Times in seconds; presumably offset per recording by
            # get_spike_time_vector so recordings don't overlap (TODO confirm)
            st = c.get_spike_time_vector(units='s')
            vlines = {i: c['offsets'][i] / c['fs'][i] for i in c['fs'].keys()}
            spike_times.append(st)
            spike_waves.append(c['spike_waveforms'])

        fig, ax = dplt.plot_spike_raster(spike_times, spike_waves, target_clusters)
        ax.set_xlabel('Time (s)')
        for x in vlines.values():
            ax.axvline(x, color='black', linewidth=2)

        fig.show()

    def plot_clusters_ISI(self, target_clusters):
        '''Show an inter-spike-interval plot for each target active cluster.'''
        if len(target_clusters) == 0:
            return

        for i in target_clusters:
            cluster = self._active[i]
            isi, v1, v2 = get_ISI_and_violations(cluster['spike_times'],
                                                 cluster['fs'],
                                                 cluster['spike_map'])
            fig, ax = dplt.plot_ISIs(isi, total_spikes=len(cluster['spike_times']))
            ax.set_title('Index: %i\n%s' % (i, ax.get_title()))
            fig.show()

    def plot_clusters_acorr(self, target_clusters):
        '''Show an autocorrelogram for each target active cluster.'''
        if len(target_clusters) == 0 or not all(x < len(self._active) for x in target_clusters):
            return

        for i in target_clusters:
            cluster = self._active[i]
            acf, bin_centers, edges = sas.spike_time_acorr(cluster.get_spike_time_vector(units='ms'))
            fig, ax = dplt.plot_correlogram(acf, bin_centers, edges)
            ax.set_title('Index: %i\nAutocorrelogram' % (i))
            fig.show()

    def plot_clusters_xcorr(self, target_clusters):
        '''Show a cross-correlogram for every pair of target active clusters.'''
        if len(target_clusters) == 0 or not all(x < len(self._active) for x in target_clusters):
            return

        for x, y in it.combinations(target_clusters, 2):
            clust1 = self._active[x]
            clust2 = self._active[y]
            xcf, bin_centers, edges = sas.spike_time_xcorr(clust1.get_spike_time_vector(units='ms'),
                                                           clust2.get_spike_time_vector(units='ms'))
            fig, ax = dplt.plot_correlogram(xcf, bin_centers, edges)
            ax.set_title('Cross-correlogram\n%i vs %i' % (x, y))
            fig.show()

    def get_mean_waveform(self, target_cluster):
        '''Returns mean waveform of target_cluster in active clusters. Also
        returns St. Dev. of waveforms
        '''
        return self._active[target_cluster].get_mean_waveform()

    def get_possible_solutions(self):
        '''Return index labels of converged clustering solutions, ignoring rows
        of the results table with missing values.'''
        results = self.clustering.results.dropna()
        return list(results[results['converged']].index)

Methods

def discard_clusters(self, target_clusters)
Expand source code
def discard_clusters(self, target_clusters):
    '''Drop the given active clusters, remembering them for ``undo``.

    Parameters
    ----------
    target_clusters : int or list of int
        indices into the active cluster list; an empty list does nothing
    '''
    targets = [target_clusters] if isinstance(target_clusters, int) else target_clusters
    if len(targets) == 0:
        return

    self._last_action = 'discard'
    self._last_popped = dict((idx, self._active[idx]) for idx in targets)
    self._last_added = []
    self._active = [clust for idx, clust in enumerate(self._active)
                    if idx not in targets]
def get_mean_waveform(self, target_cluster)

Returns the mean waveform of target_cluster among the active clusters, along with the standard deviation of its waveforms.

Expand source code
def get_mean_waveform(self, target_cluster):
    '''Mean waveform (and St. Dev. of waveforms) of one active cluster.

    Delegates to the cluster's own ``get_mean_waveform``.

    Parameters
    ----------
    target_cluster : int
        index into the active cluster list
    '''
    chosen = self._active[target_cluster]
    mean_and_sd = chosen.get_mean_waveform()
    return mean_and_sd
def get_possible_solutions(self)
Expand source code
def get_possible_solutions(self):
    '''List the clustering solutions that converged.

    Rows of ``self.clustering.results`` containing any missing value are
    ignored.

    Returns
    -------
    list
        index labels of the remaining rows whose ``converged`` entry is True
    '''
    finished = self.clustering.results.dropna()
    mask = finished['converged']
    solutions = finished[mask].index
    return list(solutions)
def merge_clusters(self, target_clusters)
Expand source code
def merge_clusters(self, target_clusters):
    '''Merge several active clusters into one new active cluster.

    The merged cluster is a deep copy of the first target with the other
    targets' spikes appended. Spikes are then sorted by recording
    (``spike_map``) and by spike time within each recording. The sources are
    removed from the active list, the merged cluster is appended at the end,
    and state is recorded for ``undo``.

    Parameters
    ----------
    target_clusters : list of int
        indices into the active cluster list; an empty list is a no-op

    Raises
    ------
    ValueError
        if any target index is out of range
    '''
    # Previously an empty list appended a bare [] to the active clusters
    if len(target_clusters) == 0:
        return

    if any(i >= len(self._active) for i in target_clusters):
        raise ValueError('Target cluster is out of range.')

    first, *others = target_clusters
    sources = [self._active[c] for c in others]

    self._last_action = 'merge'
    self._last_popped = {c: self._active[c] for c in target_clusters}
    self._last_added = []

    new_clust = deepcopy(self._active[first])
    if sources:
        spike_map = np.hstack([new_clust['spike_map']] +
                              [c['spike_map'] for c in sources])
        spike_times = np.hstack([new_clust['spike_times']] +
                                [c['spike_times'] for c in sources])
        spike_waveforms = np.vstack([new_clust['spike_waveforms']] +
                                    [c['spike_waveforms'] for c in sources])

        # Sort by recording, then by spike time within each recording
        # (np.lexsort uses the last key as the primary key)
        order = np.lexsort((spike_times, spike_map))
        new_clust['spike_map'] = spike_map[order]
        new_clust['spike_times'] = spike_times[order]
        new_clust['spike_waveforms'] = spike_waveforms[order]

        for clust in sources:
            name = clust['Cluster_Name']
            new_clust['manipulations'] += '\nMerged with %s.' % name
            new_clust['Cluster_Name'] += '+' + name.replace('Cluster_', '')

    merged = set(target_clusters)
    self._active = [c for i, c in enumerate(self._active) if i not in merged]
    self._last_added = [len(self._active)]
    self._active.append(new_clust)
def plot_cluster_waveforms_by_rec(self, target_cluster)
Expand source code
def plot_cluster_waveforms_by_rec(self, target_cluster):
    '''Plot one active cluster's waveforms separately for each recording.

    Parameters
    ----------
    target_cluster : int or single-element list of int
        index into the active cluster list; lists of any other length are
        ignored
    '''
    if isinstance(target_cluster, list):
        if len(target_cluster) != 1:
            return
        target_cluster = target_cluster[0]

    cluster = self._active[target_cluster]
    rec_map = cluster['spike_map']
    for rec in np.unique(rec_map):
        in_rec = np.where(rec_map == rec)[0]
        rec_waves = cluster['spike_waveforms'][in_rec, :]
        _, viol_1ms, viol_2ms = get_ISI_and_violations(cluster['spike_times'][in_rec],
                                                       cluster['fs'][rec])
        title = ('Index : %i, Rec: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                 '\ntotal waveforms: %i'
                 % (target_cluster, rec, viol_1ms, viol_2ms, len(rec_waves)))
        fig, _ = dplt.plot_waveforms(rec_waves, title=title)
        fig.show()
def plot_cluster_waveforms_over_time(self, target_cluster, interval)
Expand source code
def plot_cluster_waveforms_over_time(self, target_cluster, interval):
    '''Plot one active cluster's waveforms in consecutive time windows.

    Parameters
    ----------
    target_cluster : int or single-element list of int
        index into the active cluster list; lists of any other length are
        ignored
    interval : float
        window length in seconds. If more than 10 windows would be needed,
        the user is told to choose a larger interval and nothing is plotted.
    '''
    if isinstance(target_cluster, list):
        if len(target_cluster) != 1:
            return
        target_cluster = target_cluster[0]

    c = self._active[target_cluster]
    spike_times = c.get_spike_time_vector('s')
    # Check before indexing: an empty cluster has no time range to bin
    if len(spike_times) == 0:
        return

    start_times = np.arange(np.min(spike_times), np.max(spike_times)+1, interval)
    if len(start_times) > 10:
        userIO.tell_user('This would open more than 10 figures, choose a larger interval')
        return

    for i, start_time in enumerate(start_times):
        idx = np.where((spike_times >= start_time) & (spike_times < start_time+interval))[0]
        waves = c['spike_waveforms'][idx, :]
        # Pass the spike map and per-recording fs dict so ISIs respect
        # recording boundaries (fs[0] alone is wrong for other recordings)
        isi, v1, v2 = get_ISI_and_violations(c['spike_times'][idx], c['fs'],
                                             c['spike_map'][idx])
        title = ('Index : %i, Interval: %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                 '\ntotal waveforms: %i'
                 % (target_cluster, i, v1, v2, len(waves)))
        fig, ax = dplt.plot_waveforms(waves, title=title,
                                      threshold=self._detection_thresholds[0])
        fig.show()
def plot_clusters_ISI(self, target_clusters)
Expand source code
def plot_clusters_ISI(self, target_clusters):
    '''Show an inter-spike-interval plot for each target active cluster.

    Parameters
    ----------
    target_clusters : list of int
        indices into the active cluster list; an empty list does nothing
    '''
    if len(target_clusters) == 0:
        return

    for idx in target_clusters:
        clust = self._active[idx]
        times = clust['spike_times']
        isi, _, _ = get_ISI_and_violations(times, clust['fs'], clust['spike_map'])
        fig, ax = dplt.plot_ISIs(isi, total_spikes=len(times))
        # Prefix the plot's own title with the cluster index
        ax.set_title('Index: %i\n%s' % (idx, ax.get_title()))
        fig.show()
def plot_clusters_acorr(self, target_clusters)
Expand source code
def plot_clusters_acorr(self, target_clusters):
    '''Plot a spike-time autocorrelogram for each target cluster.

    Nothing is done when no clusters are given or when any index is not
    below the number of active clusters.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to plot
    '''
    if len(target_clusters) == 0:
        return

    n_active = len(self._active)
    if any(idx >= n_active for idx in target_clusters):
        return

    for idx in target_clusters:
        times_ms = self._active[idx].get_spike_time_vector(units='ms')
        acf, bin_centers, edges = sas.spike_time_acorr(times_ms)
        fig, ax = dplt.plot_correlogram(acf, bin_centers, edges)
        ax.set_title('Index: %i\nAutocorrelogram' % (idx))
        fig.show()
def plot_clusters_pca(self, target_clusters)
Expand source code
def plot_clusters_pca(self, target_clusters):
    '''Show a PCA comparison of the waveforms of the target clusters.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to compare; nothing is done if empty
    '''
    if len(target_clusters) < 1:
        return

    waveform_sets = []
    for idx in target_clusters:
        waveform_sets.append(self._active[idx]['spike_waveforms'])

    pca_fig = dplt.plot_waveforms_pca(waveform_sets,
                                      cluster_ids=target_clusters)
    pca_fig.show()
def plot_clusters_raster(self, target_clusters)
Expand source code
def plot_clusters_raster(self, target_clusters):
    '''Draw a spike raster of the target clusters, in seconds.

    A vertical black line is drawn at each recording's start time, computed
    as offset / fs. The boundaries are taken from the last cluster in the
    list.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to plot; nothing is done if empty
    '''
    if len(target_clusters) < 1:
        return

    selected = [self._active[idx] for idx in target_clusters]
    all_times, all_waves, rec_starts = [], [], {}
    for clust in selected:
        # NOTE(review): get_spike_time_vector is presumably what applies the
        # recording offsets so recordings do not overlap -- not visible here
        times_s = clust.get_spike_time_vector(units='s')
        rates = clust['fs']
        rec_starts = {rec: clust['offsets'][rec] / rates[rec]
                      for rec in rates.keys()}
        all_times.append(times_s)
        all_waves.append(clust['spike_waveforms'])

    fig, ax = dplt.plot_spike_raster(all_times, all_waves, target_clusters)
    ax.set_xlabel('Time (s)')
    for boundary in rec_starts.values():
        ax.axvline(boundary, color='black', linewidth=2)

    fig.show()
def plot_clusters_umap(self, target_clusters)
Expand source code
def plot_clusters_umap(self, target_clusters):
    '''Show a UMAP embedding comparing the waveforms of the target clusters.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to compare; nothing is done if empty
    '''
    if len(target_clusters) < 1:
        return

    waveform_sets = []
    for idx in target_clusters:
        waveform_sets.append(self._active[idx]['spike_waveforms'])

    umap_fig = dplt.plot_waveforms_umap(waveform_sets,
                                        cluster_ids=target_clusters)
    umap_fig.show()
def plot_clusters_waveforms(self, target_clusters)
Expand source code
def plot_clusters_waveforms(self, target_clusters):
    '''Open one waveform figure per target cluster.

    Each title lists the cluster's 1 ms / 2 ms ISI violations and its
    waveform count. The first detection threshold is passed to the plot.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to plot; nothing is done if empty
    '''
    if len(target_clusters) < 1:
        return

    threshold_idx = 0
    for idx in target_clusters:
        clust = self._active[idx]
        _, viol_1ms, viol_2ms = get_ISI_and_violations(clust['spike_times'],
                                                       clust['fs'],
                                                       clust['spike_map'])
        waves = clust['spike_waveforms']
        title = ('Index : %i\n1ms violations: %0.1f, 2ms violations: %0.1f'
                 '\ntotal waveforms: %i'
                 % (idx, viol_1ms, viol_2ms, len(waves)))
        threshold = self._detection_thresholds[threshold_idx]
        fig, _ = dplt.plot_waveforms(waves, title=title, threshold=threshold)
        fig.show()
def plot_clusters_wavelets(self, target_clusters)
Expand source code
def plot_clusters_wavelets(self, target_clusters):
    '''Show the wavelet-transform comparison of the target clusters.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters to compare; nothing is done if empty
    '''
    if len(target_clusters) < 1:
        return

    waveform_sets = []
    for idx in target_clusters:
        waveform_sets.append(self._active[idx]['spike_waveforms'])

    fig, _ = dplt.plot_waveforms_wavelet_tranform(waveform_sets,
                                                  cluster_ids=target_clusters,
                                                  n_pc=4)
    fig.show()
def plot_clusters_xcorr(self, target_clusters)
Expand source code
def plot_clusters_xcorr(self, target_clusters):
    '''Plot a spike-time cross-correlogram for every pair of target clusters.

    Nothing is done when no clusters are given or when any index is not
    below the number of active clusters.

    Parameters
    ----------
    target_clusters : list of int
        indices of active clusters; each unordered pair gets one figure
    '''
    if len(target_clusters) == 0:
        return

    n_active = len(self._active)
    if any(idx >= n_active for idx in target_clusters):
        return

    for first, second in it.combinations(target_clusters, 2):
        times_a = self._active[first].get_spike_time_vector(units='ms')
        times_b = self._active[second].get_spike_time_vector(units='ms')
        xcf, bin_centers, edges = sas.spike_time_xcorr(times_a, times_b)
        fig, ax = dplt.plot_correlogram(xcf, bin_centers, edges)
        ax.set_title('Cross-correlogram\n%i vs %i' % (first, second))
        fig.show()
def save_clusters(self, target_clusters, single_unit, pyramidal, interneuron)

Saves active clusters as cells and writes them to the h5 files in the appropriate recording directories

Parameters

target_clusters : list of int
indices of active clusters to save
single_unit : list of bool
one entry per element of target_clusters
pyramidal : list of bool
one entry per element of target_clusters
interneuron : list of bool
one entry per element of target_clusters
Expand source code
def save_clusters(self, target_clusters, single_unit, pyramidal, interneuron):
    '''Saves active clusters as cells, writing them to the h5 files in the
    appropriate recording directories.

    Each target cluster's spikes are grouped by recording (via its
    ``spike_map``). One unit is added with ``h5io.add_new_unit`` to every
    recording that contributed spikes. A ``cluster.info`` summary is appended
    under ``<rec>/sorted_unit_metrics/<unit_name>``. Saved clusters are then
    removed from the active list, and the action is recorded so that
    ``undo_last_save`` can revert it.

    Parameters
    ----------
    target_clusters: list of int
        indices of active clusters to save
    single_unit : list of bool
        one entry per element of target_clusters
    pyramidal : list of bool
        one entry per element of target_clusters
    interneuron : list of bool
        one entry per element of target_clusters

    Raises
    ------
    ValueError
        if any target index is outside ``[0, len(active))`` or if the flag
        lists do not have one entry per target cluster
    '''
    if self._active is None:
        return

    n_active = len(self._active)
    # Negative indices would be saved but never removed from the active list
    if any(i < 0 or i >= n_active for i in target_clusters):
        raise ValueError('Target cluster is out of range.')

    n_clusters = len(target_clusters)
    if (len(single_unit) != n_clusters or len(pyramidal) != n_clusters or
            len(interneuron) != n_clusters):
        raise ValueError('Length of input lists must match number of '
                         'target clusters. Expected %i' % n_clusters)

    self._last_action = 'save'
    self._last_popped = {i: self._active[i] for i in target_clusters}
    self._last_added = []
    clusters = [self._active[i] for i in target_clusters]
    rec_key = self.clustering._rec_key
    # One entry per recording; stays None for recordings that get no unit
    self._last_saved = dict.fromkeys(rec_key.keys(), None)

    for clust, single, pyr, intr in zip(clusters, single_unit,
                                        pyramidal, interneuron):
        # Text summary of the cluster (arrays and per-recording dicts
        # dropped). It does not depend on the recording, so build it once.
        print_clust = clust.copy()
        for k, v in clust.items():
            if isinstance(v, np.ndarray):
                print_clust.pop(k)

        print_clust.pop('rec_key', None)
        print_clust.pop('fs', None)

        for i, rec in rec_key.items():
            idx = np.where(clust['spike_map'] == i)[0]
            if len(idx) == 0:
                continue

            waves = clust['spike_waveforms'][idx]
            times = clust['spike_times'][idx]
            unit_name = h5io.add_new_unit(rec, self.electrode, waves,
                                          times, single, pyr, intr)
            if self._last_saved[i] is None:
                self._last_saved[i] = [unit_name]
            else:
                self._last_saved[i].append(unit_name)

            metrics_dir = os.path.join(rec, 'sorted_unit_metrics', unit_name)
            if not os.path.isdir(metrics_dir):
                os.makedirs(metrics_dir)

            # Write cluster info to file
            clust_info_file = os.path.join(metrics_dir, 'cluster.info')
            with open(clust_info_file, 'a+') as log:
                print('%s sorted on %s'
                      % (unit_name,
                         dt.datetime.today().strftime('%m/%d/%y %H:%M')),
                      file=log)
                print('Cluster info: \n----------', file=log)
                print(pt.print_dict(print_clust), file=log)
                print('Saved metrics to %s' % metrics_dir, file=log)
                print('--------------\n', file=log)

    userIO.tell_user('Target clusters successfully saved to recording '
                     'directories.', shell=True)
    saved = set(target_clusters)
    self._active = [c for i, c in enumerate(self._active) if i not in saved]
def set_active_clusters(self, solution_num)
Expand source code
def set_active_clusters(self, solution_num):
    '''Load every cluster of a GMM solution as the active cluster list.

    Cluster numbers ``0 .. solution_num - 1`` are requested from
    ``self.clustering``. The undo history is cleared.

    Parameters
    ----------
    solution_num : int
        solution to load; also used as the number of clusters to request

    Raises
    ------
    ValueError
        if no clusters are returned for the solution
    '''
    self._current_solution = solution_num
    loaded = self.clustering.get_clusters(solution_num, [*range(solution_num)])
    if not len(loaded):
        raise ValueError('Solution or clusters not found')

    self._active = loaded
    # Fresh cluster set: nothing to undo
    self._last_action = self._last_popped = self._last_added = None
def set_split(self, choices)
Expand source code
def set_split(self, choices):
    '''Finalize a split stored by ``split_cluster(..., store_split=True)``.

    With no choices, the original cluster is put back at its old index.
    Otherwise, the chosen sub-clusters are appended to the active list and
    the action is recorded so that ``undo`` can revert it. In both cases the
    stored split is cleared.

    Parameters
    ----------
    choices : list of int or str
        indices into the stored split results to keep. Numeric strings (as
        produced by the selection prompt in ``split_cluster``) are also
        accepted.

    Raises
    ------
    ValueError
        if no split is currently stored
    '''
    if self._split_starter is None:
        raise ValueError('No split stored.')

    if len(choices) == 0:
        # Nothing kept: restore the original cluster where it was
        self._active.insert(self._split_index, self._split_starter)
    else:
        keepers = [self._split_results[int(i)] for i in choices]
        start_idx = len(self._active)
        self._last_added = list(range(start_idx, start_idx+len(keepers)))
        self._last_popped = {self._split_index: self._split_starter}
        self._last_action = 'split'
        self._active.extend(keepers)

    self._split_index = None
    self._split_results = None
    self._split_starter = None
def split_by_rec(self, target_cluster)
Expand source code
def split_by_rec(self, target_cluster):
    '''Split an active cluster into one new cluster per recording.

    The target cluster is replaced by one sub-cluster for each distinct value
    in its ``spike_map``, and the new clusters are appended to the end of the
    active list. Each sub-cluster gets cluster_id ``old_id*10 + rec`` and a
    line appended to its ``manipulations`` history. The action is recorded
    so that ``undo`` can revert it. Nothing happens if the cluster's spikes
    come from at most one recording.

    Parameters
    ----------
    target_cluster : int or list of int
        index of the active cluster to split; a list is accepted only if it
        contains exactly one index, otherwise nothing is done
    '''
    if isinstance(target_cluster, list) and len(target_cluster) != 1:
        return
    elif isinstance(target_cluster, list):
        target_cluster = target_cluster[0]

    clust = self._active[target_cluster]
    sm = clust['spike_map']
    recs = np.unique(sm)
    # Splitting only makes sense when spikes come from 2+ recordings
    if len(recs) <= 1:
        return

    st = clust['spike_times']
    sw = clust['spike_waveforms']
    keepers = []
    for i in recs:
        idx = np.where(sm == i)[0]
        new_clust = deepcopy(clust)
        new_clust['spike_times'] = st[idx]
        new_clust['spike_waveforms'] = sw[idx, :]
        new_clust['cluster_id'] = clust['cluster_id']*10 + i
        new_clust['spike_map'] = sm[idx]
        # Append to, rather than overwrite, the cluster's edit history
        new_clust['manipulations'] = clust['manipulations'] + '\nSplit by recording'
        keepers.append(new_clust)

    # Only remove the original once all sub-clusters were built successfully
    self._active.pop(target_cluster)
    start_idx = len(self._active)
    self._last_added = list(range(start_idx, start_idx+len(keepers)))
    self._last_popped = {target_cluster: clust}
    self._last_action = 'split'
    self._active.extend(keepers)
def split_cluster(self, target_clust, n_iter, n_restart, thresh, n_clust, store_split=False, umap=False)

Splits the target active cluster using a Gaussian mixture model (GMM).

Expand source code
def split_cluster(self, target_clust, n_iter, n_restart, thresh, n_clust,
                  store_split=False, umap=False):
    '''Splits the target active cluster using a GMM.

    The target cluster is removed from the active list while it is split.
    Waveform features come from ``compute_waveform_metrics`` and are fit with
    ``ClusterGMM``. Each resulting sub-cluster is plotted, followed by a PCA
    comparison of all sub-clusters. If anything fails up to that point, the
    original cluster is restored at its index and the exception is re-raised.

    Parameters
    ----------
    target_clust : int
        index of the active cluster to split
    n_iter, n_restart, thresh :
        passed to ``ClusterGMM``
    n_clust : int
        passed to ``ClusterGMM.fit`` as the number of clusters
    store_split : bool, optional
        if True, the sub-clusters are stored in ``self._split_results`` and
        returned without prompting; ``set_split`` must then be called to
        choose which to keep. If False (default), the user is prompted here.
        Cancelling or selecting 'all' restores the original cluster.
    umap : bool, optional
        passed to ``compute_waveform_metrics`` (note: shadows the module-level
        ``umap`` import inside this function)

    Returns
    -------
    list or bool
        the sub-clusters when ``store_split`` is True, otherwise True

    Raises
    ------
    ValueError
        if `target_clust` is not a valid active cluster index
    '''
    if target_clust < 0 or target_clust >= len(self._active):
        raise ValueError('Invalid target. Only %i active clusters' % len(self._active))

    cluster = self._active.pop(target_clust)
    self._split_starter = cluster
    self._split_index = target_clust

    try:
        GMM = ClusterGMM(n_iter, n_restart, thresh)
        waves = cluster['spike_waveforms']
        data, _ = compute_waveform_metrics(waves, umap=umap)
        _, predictions, _ = GMM.fit(data, n_clust)
        new_clusts = []
        for i in np.unique(predictions):
            idx = np.where(predictions == i)[0]
            edit_str = (cluster['manipulations'] + '\nSplit %s into %i '
                        'clusters. This is sub-cluster %i'
                        % (cluster['Cluster_Name'], n_clust, i))
            tmp_clust = SpikeCluster(cluster['Cluster_Name'] + '-%i' % i,
                                     cluster['electrode_num'],
                                     cluster['solution_num'],
                                     cluster['cluster_num'],
                                     cluster['cluster_id']*10+i,
                                     waves[idx],
                                     cluster['spike_times'][idx],
                                     cluster['spike_map'][idx],
                                     cluster['rec_key'].copy(),
                                     cluster['fs'].copy(),
                                     cluster['offsets'].copy(),
                                     manipulations=edit_str)
            new_clusts.append(tmp_clust)

        # Plot each sub-cluster so the user can choose which to keep.
        # Violations are formatted with %0.1f, as elsewhere in this class,
        # rather than truncated to integers.
        for i, c in enumerate(new_clusts):
            _, viol_1ms, viol_2ms = get_ISI_and_violations(c['spike_times'], c['fs'], c['spike_map'])
            plot_title = ('Index: %i\n1ms violations: %0.1f, 2ms violations: %0.1f\n'
                          'Total Waveforms: %i'
                          % (i, viol_1ms, viol_2ms, len(c['spike_times'])))
            tmp_fig, _ = dplt.plot_waveforms(c['spike_waveforms'], title=plot_title)
            tmp_fig.show()

        f2 = dplt.plot_waveforms_pca([c['spike_waveforms'] for c in new_clusts])
        f2.show()
    except BaseException:
        # Put the cluster back so it isn't lost, then propagate the error
        self._active.insert(target_clust, cluster)
        self._split_starter = None
        self._split_index = None
        raise

    if store_split:
        self._split_results = new_clusts
        return new_clusts
    else:
        self._split_starter = None
        self._split_index = None
        selection_list = ['all'] + ['%i' % i for i in range(len(new_clusts))]
        prompt = 'Select split clusters to keep\nCancel to reset.'
        ans = userIO.select_from_list(prompt, selection_list,
                                      multi_select=True, shell=self._shell)
        if ans is None or 'all' in ans:
            print('Reset to before split')
            self._active.insert(target_clust, cluster)
        else:
            keepers = [new_clusts[int(i)] for i in ans]
            start_idx = len(self._active)
            self._last_added = list(range(start_idx, start_idx+len(keepers)))
            self._last_popped = {target_clust: cluster}
            self._last_action = 'split'
            self._active.extend(keepers)

        return True
def undo(self)
Expand source code
def undo(self):
    '''Revert the most recent split/merge-style action on the active list.

    A 'save' action is delegated to ``undo_last_save``. Otherwise, the
    clusters recorded in ``_last_added`` are removed (highest index first, so
    earlier indices stay valid). The clusters in ``_last_popped`` are then
    re-inserted at their original indices, lowest first. The undo record is
    cleared afterwards.
    '''
    action = self._last_action
    if action is None:
        return

    if action == 'save':
        self.undo_last_save()
        return

    for pos in sorted(self._last_added, reverse=True):
        self._active.pop(pos)

    for pos, clust in sorted(self._last_popped.items()):
        self._active.insert(pos, clust)

    self._last_action = self._last_popped = self._last_added = None
def undo_last_save(self)
Expand source code
def undo_last_save(self):
    '''Revert the most recent ``save_clusters`` call.

    Every unit recorded in ``self._last_saved`` is deleted from its recording
    with ``h5io.delete_unit``. The saved clusters are then re-inserted into
    the active list at their original indices, and the undo record is
    cleared. Does nothing if no save is recorded.
    '''
    if self._last_saved is None:
        return

    rec_key = self.clustering._rec_key
    last_saved = self._last_saved
    for i, rec in rec_key.items():
        units = last_saved.get(i)
        # save_clusters leaves None for recordings that received no unit;
        # np.sort(None) would raise, so skip those recordings
        if not units:
            continue

        # Delete in reverse-sorted name order, as before
        for unit in sorted(units, reverse=True):
            h5io.delete_unit(rec, unit)

    for k in sorted(self._last_popped.keys()):
        self._active.insert(k, self._last_popped[k])

    self._last_saved = None
    self._last_popped = None
    self._last_added = None
    self._last_action = None