Module blechpy.analysis.poissonHMM_deprecated

Expand source code
import os
import math
import numpy as np
import itertools as it
import pylab as plt
import seaborn as sns
import pandas as pd
import multiprocessing as mp
import tables
#from scipy.spatial.distance import euclidean
from numba import njit
from blechpy.utils.particles import HMMInfoParticle
from blechpy import load_dataset
from blechpy.dio import h5io
from blechpy.plotting import hmm_plot as hmmplt
from joblib import Parallel, delayed, Memory
from appdirs import user_cache_dir
cachedir = user_cache_dir('blechpy')
memory = Memory(cachedir, verbose=0)



TEST_PARAMS = {'n_cells': 10, 'n_states': 4, 'state_seq_length': 5,
               'trial_time': 3.5, 'dt': 0.001, 'max_rate': 50, 'n_trials': 15,
               'min_state_dur': 0.05, 'noise': 0.01, 'baseline_dur': 1}

FACTORIAL_LOOKUP = np.array([math.factorial(x) for x in range(20)])


@njit
def fast_factorial(x):
    if x < len(FACTORIAL_LOOKUP):
        return FACTORIAL_LOOKUP[x]
    else:
        y = 1
        for i in range(1,x+1):
            y = y*i

        return y


@njit
def poisson(rate, n, dt):
    '''Gives the probability of each neuron's spike count assuming Poisson spiking
    '''
    tmp = np.power(rate*dt, n) / np.array([fast_factorial(x) for x in n])
    tmp = tmp * np.exp(-rate*dt)
    return tmp
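
# Illustrative sketch (not part of the original module): evaluating the Poisson
# observation probabilities for a single time bin. All numbers below are
# assumptions chosen only to show the shapes involved.
def _example_poisson_bin():
    rates = np.array([5.0, 20.0, 1.0])    # assumed firing rates (Hz), one per neuron
    counts = np.array([0, 1, 0])          # assumed spike counts in one 1 ms bin
    dt = 0.001
    # returns P(count_i | rate_i * dt) for each neuron; the product over neurons
    # is the per-bin observation likelihood used by forward/backward/viterbi
    return poisson(rates, counts, dt)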


@njit
def forward(spikes, dt, PI, A, B):
    '''Run forward algorithm to compute alpha = P(Xt = i| o1...ot, pi)
    Gives the probabilities of being in a specific state at each time point
    given the past observations and initial probabilities

    Parameters
    ----------
    spikes : np.array
        N x T matrix of spike counts with each entry ((i,j)) holding the # of
        spikes from neuron i in time bin j
    dt : float, timebin in seconds (i.e. 0.001)
    PI : np.array
        nStates x 1 vector of initial state probabilities
    A : np.array
        nStates x nStates state transition matrix with each entry ((i,j))
        giving the probability of transitioning from state i to state j
    B : np.array
        N x nStates rate matrix. Each entry ((i,j)) gives the predicted rate
        of neuron i in state j

    Returns
    -------
    alpha : np.array
        nStates x T matrix of forward probabilities. Each entry (i,j) gives
        P(Xt = i | o1,...,oj, pi)
    norms : np.array
        1 x T vector of norms used to normalize alpha to be a probability
        distribution and also to scale the outputs of the backward algorithm.
        norms(t) = sum(alpha(:,t))
    '''
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]

    # For each state, use the initial state distribution and spike counts
    # to initialize alpha(:,1)
    row = np.array([PI[i] * np.prod(poisson(B[:,i], spikes[:,0], dt))
                    for i in range(nStates)])
    alpha = np.zeros((nStates, nTimeSteps))
    norms = [np.sum(row)]
    alpha[:, 0] = row/norms[0]
    for t in range(1, nTimeSteps):
        tmp = np.array([np.prod(poisson(B[:, s], spikes[:, t], dt)) *
                        np.sum(alpha[:, t-1] * A[:,s])
                        for s in range(nStates)])
        tmp_norm = np.sum(tmp)
        norms.append(tmp_norm)
        tmp = tmp / tmp_norm
        alpha[:, t] = tmp

    return alpha, norms
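
# Illustrative sketch (not part of the original module): running the forward
# pass on toy data. Shapes follow the docstring above; all values are
# assumptions.
def _example_forward_pass():
    n_cells, n_states, n_steps, dt = 3, 2, 100, 0.001
    spikes = np.random.poisson(0.02, (n_cells, n_steps))    # toy spike counts
    PI = np.ones(n_states) / n_states                        # uniform initial distribution
    A = np.array([[0.99, 0.01], [0.01, 0.99]])               # sticky transition matrix
    B = np.random.rand(n_cells, n_states) * 20               # rates in Hz
    alpha, norms = forward(spikes, dt, PI, A, B)
    return alpha.shape, len(norms)                           # (n_states, n_steps), n_steps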


@njit
def backward(spikes, dt, A, B, norms):
    ''' Runs the backward algorithm to compute beta = P(ot+1...oT | Xt=s)
    Computes the probability of observing all future observations given the
    current state at each time point

    Parameters
    ----------
    spikes : np.array, N x T matrix of spike counts
    dt : float, timebin size in seconds
    A : np.array, nStates x nStates matrix of transition probabilities
    B : np.array, N x nStates matrix of estimated spike rates for each neuron
    norms : np.array, 1 x T vector of scaling factors from the forward pass

    Returns
    -------
    beta : np.array, nStates x T matrix of backward probabilities
    '''
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]
    beta = np.zeros((nStates, nTimeSteps))
    beta[:, -1] = 1  # Initialize final beta to 1 for all states
    tStep = list(range(nTimeSteps-1))
    tStep.reverse()
    for t in tStep:
        for s in range(nStates):
            beta[s,t] = np.sum((beta[:, t+1] * A[s,:]) *
                               np.prod(poisson(B[:, s], spikes[:, t+1], dt)))

        beta[:, t] = beta[:, t] / norms[t+1]

    return beta


@njit
def baum_welch(spikes, dt, A, B, alpha, beta):
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]
    gamma = np.zeros((nStates, nTimeSteps))
    epsilons = np.zeros((nStates, nStates, nTimeSteps-1))
    for t in range(nTimeSteps):
        if t < nTimeSteps-1:
            gamma[:, t] = (alpha[:, t] * beta[:, t]) / np.sum(alpha[:,t] * beta[:,t])
            epsilonNumerator = np.zeros((nStates, nStates))
            for si in range(nStates):
                for sj in range(nStates):
                    probs = np.prod(poisson(B[:,sj], spikes[:, t+1], dt))
                    epsilonNumerator[si, sj] = (alpha[si, t]*A[si, sj]*
                                                beta[sj, t]*probs)

            epsilons[:, :, t] = epsilonNumerator / np.sum(epsilonNumerator)

    return gamma, epsilons
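
# Illustrative sketch (not part of the original module): one full E-step for a
# single trial, chaining forward, backward and baum_welch (wrap_baum_welch
# further below packages the same sequence).
def _example_e_step(spikes, dt, PI, A, B):
    alpha, norms = forward(spikes, dt, PI, A, B)   # nStates x T
    beta = backward(spikes, dt, A, B, norms)       # nStates x T
    gamma, epsilons = baum_welch(spikes, dt, A, B, alpha, beta)
    return gamma, epsilons                         # nStates x T and nStates x nStates x (T-1)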


def isNotConverged(oldPI, oldA, oldB, PI, A, B, thresh=1e-4):
    dPI = np.sqrt(np.sum(np.power(oldPI - PI, 2)))
    dA = np.sqrt(np.sum(np.power(oldA - A, 2)))
    dB = np.sqrt(np.sum(np.power(oldB - B, 2)))
    print('dPI = %f,  dA = %f,  dB = %f' % (dPI, dA, dB))
    if all([x < thresh for x in [dPI, dA, dB]]):
        return False
    else:
        return True

def poisson_viterbi(spikes, dt, PI, A, B):
    '''
    Parameters
    ----------
    spikes : np.array, Neuron X Time matrix of spike counts
    PI : np.array, nStates x 1 vector of initial state probabilities
    A : np.array, nStates X nStates matrix of state transition probabilities
    B : np.array, Neuron X States matrix of estimated firing rates
    dt : float, time step size in seconds

    Returns
    -------
    bestPath : np.array
        1 x Time vector of states representing the most likely hidden state
        sequence
    maxPathLogProb : float
        Log probability of the most likely state sequence
    T1 : np.array
        State X Time matrix where each entry (i,j) gives the log probability of
        the most likely path so far ending in state i that generates
        observations o1,..., oj
    T2: np.array
        State X Time matrix of back pointers where each entry (i,j) gives the
        state x(j-1) on the most likely path so far ending in state i
    '''
    if A.shape[0] != A.shape[1]:
        raise ValueError('Transition matrix is not square')

    nStates = A.shape[0]
    nCells, nTimeSteps = spikes.shape
    T1 = np.zeros((nStates, nTimeSteps))
    T2 = np.zeros((nStates, nTimeSteps))
    T1[:, 0] = np.array([np.log(PI[i]) +
                         np.log(np.prod(poisson(B[:,i], spikes[:, 0], dt)))
                         for i in range(nStates)])
    for t, s in it.product(range(1,nTimeSteps), range(nStates)):
        probs = np.log(np.prod(poisson(B[:, s], spikes[:, t], dt)))
        vec2 = T1[:, t-1] + np.log(A[:,s])
        vec1 = vec2 + probs
        T1[s, t] = np.max(vec1)
        idx = np.argmax(vec2)
        T2[s, t] = idx

    bestPathEndState = np.argmax(T1[:, -1])
    maxPathLogProb = T1[bestPathEndState, -1]
    bestPath = np.zeros((nTimeSteps,))
    bestPath[-1] = bestPathEndState
    tStep = list(range(nTimeSteps-1))
    tStep.reverse()
    for t in tStep:
        bestPath[t] = T2[int(bestPath[t+1]), t+1]

    return bestPath, maxPathLogProb, T1, T2
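
# Illustrative sketch (not part of the original module): decoding the most
# likely state sequence for one toy trial. Parameter values are assumptions.
def _example_viterbi_decode():
    dt = 0.001
    spikes = np.random.poisson(0.02, (4, 200))            # 4 neurons, 200 time bins
    PI = np.array([0.9, 0.1])
    A = np.array([[0.995, 0.005], [0.005, 0.995]])
    B = np.array([[2.0, 20.0], [5.0, 15.0],
                  [1.0, 30.0], [10.0, 10.0]])              # Hz per neuron per state
    path, log_prob, T1, T2 = poisson_viterbi(spikes, dt, PI, A, B)
    return path, log_prob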


class TestData(object):
    def __init__(self, params=None):
        if params is None:
            params = TEST_PARAMS.copy()
            param_str = '\t'+'\n\t'.join(repr(params)[1:-1].split(', '))
            print('Using default parameters:\n%s' % param_str)

        self.params = params.copy()
        self.generate()

    def generate(self, params=None):
        print('-'*80)
        print('Simulating Data')
        print('-'*80)
        if params is not None:
            self.params.update(params)

        params = self.params
        param_str = '\t'+'\n\t'.join(repr(params)[1:-1].split(', '))
        print('Parameters:\n%s' % param_str)

        self._generate_ground_truth()
        self._generate_spike_trains()

    def _generate_ground_truth(self):
        print('Generating ground truth state sequence...')
        params = self.params
        nStates = params['n_states']
        seqLen = params['state_seq_length']
        minSeqDur = params['min_state_dur']
        baseline_dur = params['baseline_dur']
        maxFR = params['max_rate']
        nCells = params['n_cells']
        trialTime = params['trial_time']
        nTrials = params['n_trials']
        dt = params['dt']
        nTimeSteps = int(trialTime/dt)

        T = trialTime
        # Figure out a random state sequence and state durations
        stateSeq = np.array([0, *np.random.randint(0,nStates, seqLen-1)])
        stateDurs = np.zeros((nTrials, seqLen))
        for i in range(nTrials):
            tmp = np.abs(np.random.rand(seqLen-1))
            tmp = tmp * ((trialTime - baseline_dur) / np.sum(tmp))
            stateDurs[i, :] = np.array([baseline_dur, *tmp])

        # Make vector of state at each time point
        stateVec = np.zeros((nTrials, nTimeSteps))
        for trial in range(nTrials):
            t0 = 0
            for state, dur in zip(stateSeq, stateDurs[trial]):
                tn = int(dur/dt)
                stateVec[trial, t0:t0+tn] = state
                t0 += tn

        # Determine firing rates per neuron per state
        # For each neuron generate a mean firing rate and then draw state
        # firing rates from a normal distribution around it with a standard
        # deviation of half the mean rate
        mean_rates = np.random.rand(nCells) * maxFR
        stateRates = np.zeros((nCells, nStates))
        for i, r in enumerate(mean_rates):
            stateRates[i, :] = np.array([r, *np.abs(np.random.normal(r, .5*r, nStates-1))])

        self.ground_truth = {'state_sequence': stateSeq,
                             'state_durations': stateDurs,
                             'firing_rates': stateRates,
                             'state_vectors': stateVec}

    def _generate_spike_trains(self):
        print('Generating new spike trains...')
        params = self.params
        nCells = params['n_cells']
        trialTime = params['trial_time']
        dt = params['dt']
        nTrials = params['n_trials']
        noise = params['noise']
        nTimeSteps = int(trialTime/dt)

        stateRates = self.ground_truth['firing_rates']
        stateVec = self.ground_truth['state_vectors']


        # Make spike arrays
        # Trial x Neuron x Time
        random_nums = np.abs(np.random.rand(nTrials, nCells, nTimeSteps))
        rate_arr = np.zeros((nTrials, nCells, nTimeSteps))
        for trial, cell, t in  it.product(range(nTrials), range(nCells), range(nTimeSteps)):
            state = int(stateVec[trial, t])
            mean_rate = stateRates[cell, state]
            # draw noisy rates from normal distrib with mean rate from ground
            # truth and width as noise*mean_rate
            r = np.random.normal(mean_rate, mean_rate*noise)
            rate_arr[trial, cell, t] = r

        spikes = (random_nums <= rate_arr *dt).astype('int')

        self.spike_trains = spikes

    def get_spike_trains(self):
        if not hasattr(self, 'spike_trains'):
            self._generate_spike_trains()

        return self.spike_trains

    def get_ground_truth(self):
        if not hasattr(self, 'ground_truth'):
            self._generate_ground_truth()

        return self.ground_truth

    def plot_state_rates(self, ax=None):
        fig, ax = plot_state_rates(self.ground_truth['firing_rates'], ax=ax)
        return fig, ax

    def plot_state_raster(self, ax=None):
        fig, ax = plot_state_raster(self.spike_trains,
                                    self.ground_truth['state_vectors'],
                                    self.params['dt'], ax=ax)
        return fig, ax
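
# Illustrative sketch (not part of the original module): simulating data with
# TestData and fitting a PoissonHMM to it. Parameter values come from
# TEST_PARAMS; the iteration cap is an assumption.
def _example_fit_simulated_data():
    truth = TestData()                                   # uses TEST_PARAMS defaults
    spikes = truth.get_spike_trains()                    # Trial x Neuron x Time
    hmm = PoissonHMM(truth.params['n_states'], spikes, truth.params['dt'])
    hmm.fit(max_iter=200, convergence_thresh=1e-4)
    return compare_hmm_to_truth(truth, hmm)              # plots truth vs decoded states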


class PoissonHMM(object):
    '''Poisson implementation of Hidden Markov Model for fitting spike data
    from a neuronal population
    Author: Roshan Nanu
    Adapted from code by Ben Ballintyn
    '''
    def __init__(self, n_predicted_states, spikes, dt,
                 max_history=500, cost_window=0.25, set_data=None):
        if len(spikes.shape) == 2:
            spikes = np.array([spikes])

        self.data = spikes.astype('int32')
        self.dt = dt
        self._rate_data = None
        self.n_states = n_predicted_states
        self._cost_window = cost_window
        self._max_history = max_history
        self.cost = None
        self.BIC = None
        self.best_sequences = None
        self.max_log_prob = None
        self._rate_data = None
        self.history = None
        self._compute_data_rate_array()
        if set_data is None:
            self.randomize()
        else:
            self.fitted = set_data['fitted']
            self.initial_distribution = set_data['initial_distribution']
            self.transition = set_data['transition']
            self.emission = set_data['emission']
            self.iteration = 0
            self._update_cost()


    def randomize(self):
        nStates = self.n_states
        spikes = self.data
        dt = self.dt
        n_trials, n_cells, n_steps = spikes.shape
        total_time = n_steps * dt

        # Initialize transition matrix with high stay probability
        print('Randomizing')
        diag = np.abs(np.random.normal(.99, .01, nStates))
        A = np.abs(np.random.normal(0.01/(nStates-1), 0.01, (nStates, nStates)))
        for i in range(nStates):
            A[i, i] = diag[i]
            A[i,:] = A[i,:] / np.sum(A[i,:])

        # Initialize rate matrix ("Emission" matrix)
        spike_counts = np.sum(spikes, axis=2) / total_time
        mean_rates = np.mean(spike_counts, axis=0)
        std_rates = np.std(spike_counts, axis=0)
        B = np.vstack([np.abs(np.random.normal(x, y, nStates))
                       for x,y in zip(mean_rates, std_rates)])
        # B = np.random.rand(nCells, nStates)

        self.transition = A
        self.emission = B
        self.initial_distribution = np.ones((nStates,)) / nStates
        self.iteration = 0
        self.fitted = False
        self.history = None
        self._update_cost()

    def fit(self, spikes=None, dt=None, max_iter = 1000, convergence_thresh = 1e-4,
            parallel=False):
        '''Using multiprocessing across trials actually seems to slow down
        fitting (with 15 trials). It might still be useful if there is a very
        large number of trials.
        '''
        if self.fitted:
            return

        if spikes is not None:
            spikes = spikes.astype('int32')
            self.data = spikes
            self.dt = dt
        else:
            spikes = self.data
            dt = self.dt

        while (not self.isConverged(convergence_thresh) and
               (self.iteration < max_iter)):
            self._step(spikes, dt, parallel=parallel)
            print('Iter #%i complete.' % self.iteration)

        self.fitted = True

    def _step(self, spikes, dt, parallel=False):
        if len(spikes.shape) == 2:
            spikes = np.array([spikes])

        nTrials, nCells, nTimeSteps = spikes.shape

        A = self.transition
        B = self.emission
        PI = self.initial_distribution
        nStates = self.n_states

        # For multiple trials we need to compute gamma and epsilon for every
        # trial and then update
        gammas = np.zeros((nTrials, nStates, nTimeSteps))
        epsilons = np.zeros((nTrials, nStates, nStates, nTimeSteps-1))
        if parallel:
            def update(ans):
                idx = ans[0]
                gammas[idx, :, :] = ans[1]
                epsilons[idx, :, :, :] = ans[2]

            def error(ans):
                raise RuntimeError(ans)

            n_cores = mp.cpu_count() - 1
            pool = mp.get_context('spawn').Pool(n_cores)
            for i, trial in enumerate(spikes):
                pool.apply_async(wrap_baum_welch,
                                 (i, trial, dt, PI, A, B),
                                 callback=update, error_callback=error)

            pool.close()
            pool.join()
        else:
            for i, trial in enumerate(spikes):
                _, tmp_gamma, tmp_epsilons = wrap_baum_welch(i, trial, dt, PI, A, B)
                gammas[i, :, :] = tmp_gamma
                epsilons[i, :, :, :] = tmp_epsilons

        # Store old parameters for convergence check
        self.update_history()

        PI, A, B = compute_new_matrices(spikes, dt, gammas, epsilons)
        self.transition = A
        self.emission = B
        self.initial_distribution = PI
        self.iteration += 1
        self._update_cost()

    def update_history(self):
        A = self.transition
        B = self.emission
        PI = self.initial_distribution
        BIC = self.BIC
        cost = self.cost
        iteration = self.iteration

        if self.history is None:
            self.history = {}
            self.history['A'] = [A]
            self.history['B'] = [B]
            self.history['PI'] = [PI]
            self.history['iterations'] = [iteration]
            self.history['cost'] = [cost]
            self.history['BIC'] = [BIC]
        else:
            if iteration in self.history['iterations']:
                return self.history

            self.history['A'].append(A)
            self.history['B'].append(B)
            self.history['PI'].append(PI)
            self.history['iterations'].append(iteration)
            self.history['cost'].append(cost)
            self.history['BIC'].append(BIC)

        if len(self.history['iterations']) > self._max_history:
            nmax = self._max_history
            for k, v in self.history.items():
                self.history[k] = v[-nmax:]

        return self.history

    def isConverged(self, thresh):
        if self.history is None:
            return False

        oldPI = self.history['PI'][-1]
        oldA = self.history['A'][-1]
        oldB = self.history['B'][-1]
        oldCost = self.history['cost'][-1]

        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        cost = self.cost

        dPI = np.sqrt(np.sum(np.power(oldPI - PI, 2)))
        dA = np.sqrt(np.sum(np.power(oldA - A, 2)))
        dB = np.sqrt(np.sum(np.power(oldB - B, 2)))
        dCost = cost-oldCost
        print('dPI = %f,  dA = %f,  dB = %f, dCost = %f, cost = %f'
              % (dPI, dA, dB, dCost, cost))

        # TODO: determine if this is reasonable
        # dB takes much longer to converge than the rest, so the threshold is
        # effectively doubled for it by halving dB
        dB = dB/2

        if not all([x < thresh for x in [dPI, dA, dB]]):
            return False
        else:
            return True

    def get_best_paths(self):
        if self.best_sequences is not None:
            return self.best_sequences, self.max_log_prob

        spikes = self.data
        dt = self.dt
        PI = self.initial_distribution
        A = self.transition
        B = self.emission

        bestPaths, pathProbs = compute_best_paths(spikes, dt, PI, A, B)
        self.best_sequences = bestPaths
        self.max_log_prob = np.max(pathProbs)
        return bestPaths, self.max_log_prob

    def get_forward_probabilities(self):
        alphas = []
        for trial in self.data:
            tmp, _ = forward(trial, self.dt, self.initial_distribution,
                             self.transition, self.emission)
            alphas.append(tmp)

        return np.array(alphas)

    def get_backward_probabilities(self):
        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        betas = []
        for trial in self.data:
            alpha, norms = forward(trial, self.dt, PI, A, B)
            tmp = backward(trial, self.dt, A, B, norms)
            betas.append(tmp)

        return np.array(betas)

    def get_gamma_probabilities(self):
        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        gammas = []
        for i, trial in enumerate(self.data):
            _, tmp, _ = wrap_baum_welch(i, trial, self.dt, PI, A, B)
            gammas.append(tmp)

        return np.array(gammas)

    def get_BIC(self):
        if self.BIC is not None:
            return self.BIC

        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        BIC, bestPaths, max_log_prob = compute_BIC(self.data, self.dt, PI, A, B)
        self.BIC = BIC
        self.best_sequences = bestPaths
        self.max_log_prob = max_log_prob
        return BIC, bestPaths, max_log_prob

    def _compute_data_rate_array(self):
        if self._rate_data is not None:
            return self._rate_data

        win_size = self._cost_window
        rate_array = convert_spikes_to_rates(self.data, self.dt,
                                             win_size, step_size=win_size)
        self._rate_data = rate_array

    def _compute_predicted_rate_array(self):
        B = self.emission
        bestPaths, _ = self.get_best_paths()
        bestPaths = bestPaths.astype('int32')
        win_size = self._cost_window
        dt = self.dt
        mean_rates = generate_rate_array_from_state_seq(bestPaths, B,
                                                        dt, win_size,
                                                        step_size=win_size)
        return mean_rates

    def set_to_lowest_cost(self):
        hist = self.update_history()
        idx = np.argmin(hist['cost'])
        iteration = hist['iterations'][idx]
        self.roll_back(iteration)

    def set_to_lowest_BIC(self):
        hist = self.update_history()
        idx = np.argmin(hist['BIC'])
        iteration = hist['iterations'][idx]
        self.roll_back(iteration)

    def find_best_in_history(self):
        hist = self.update_history()
        PIs = hist['PI']
        As = hist['A']
        Bs = hist['B']
        iters = hist['iterations']
        BICs = hist['BIC']
        idx = np.argmin(BICs)
        out = {'PI': PIs[idx], 'A': As[idx], 'B': Bs[idx]}
        return out, iters[idx], BICs

    def roll_back(self, iteration):
        hist = self.history
        try:
            idx = hist['iterations'].index(iteration)
        except ValueError:
            raise ValueError('Iteration %i not found in history' % iteration)

        self.initial_distribution = hist['PI'][idx]
        self.transition = hist['A'][idx]
        self.emission = hist['B'][idx]
        self.iteration = iteration
        self._update_cost()

    def set_matrices(self, new_mats):
        self.initial_distribution = new_mats['PI']
        self.transition = new_mats['A']
        self.emission = new_mats['B']
        if 'iteration' in new_mats:
            self.iteration = new_mats['iteration']
        self._update_cost()

    def set_data(self, new_data, dt):
        self.data = new_data
        self.dt = dt
        self._compute_data_rate_array()
        self._update_cost()

    def plot_state_raster(self, ax=None, state_map=None):
        bestPaths, _ = self.get_best_paths()
        if state_map is not None:
            bestPaths = convert_path_state_numbers(bestPaths, state_map)

        data = self.data
        fig, ax = plot_state_raster(data, bestPaths, self.dt, ax=ax)
        return fig, ax

    def plot_state_rates(self, ax=None, state_map=None):
        rates = self.emission
        if state_map:
            idx = [state_map[k] for k in sorted(state_map.keys())]
            maxState = np.max(list(state_map.values()))
            newRates = np.zeros((rates.shape[0], maxState+1))
            for k, v in state_map.items():
                newRates[:, v] = rates[:, k]

            rates = newRates

        fig, ax = plot_state_rates(rates, ax=ax)
        return fig, ax

    def reorder_states(self, state_map):
        idx = [state_map[k] for k in sorted(state_map.keys())]
        PI = self.initial_distribution
        A = self.transition
        B = self.emission

        newPI = PI[idx]
        newB = B[:, idx]
        newA = np.zeros(A.shape)
        for x in range(A.shape[0]):
            for y in range(A.shape[1]):
                i = state_map[x]
                j = state_map[y]
                newA[i,j] = A[x,y]

        self.initial_distribution = newPI
        self.transition = newA
        self.emission = newB
        self._update_cost()

    def _update_cost(self):
        spikes = self.data
        win_size = self._cost_window
        dt = self.dt
        PI = self.initial_distribution
        A  = self.transition
        B  = self.emission
        true_rates = self._rate_data
        cost, BIC, bestPaths, maxLogProb = compute_hmm_cost(spikes, dt, PI, A, B,
                                                            win_size=win_size,
                                                            true_rates=true_rates)
        self.cost = cost
        self.BIC = BIC
        self.best_sequences = bestPaths
        self.max_log_prob = maxLogProb

    def get_cost(self):
        if self.cost is None:
            self._update_cost()

        return self.cost


def compute_BIC(spikes, dt, PI, A, B):
    bestPaths, maxLogProb = compute_best_paths(spikes, dt, PI, A, B)
    maxLogProb = np.max(maxLogProb)

    nParams = (A.shape[0]*(A.shape[1]-1) +
               (PI.shape[0]-1) +
               B.shape[0]*(B.shape[1]-1))

    nPts = spikes.shape[-1]
    BIC = -2 * maxLogProb + nParams * np.log(nPts)
    return BIC, bestPaths, maxLogProb
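
# Illustrative note (not part of the original module): compute_BIC counts
# nStates*(nStates-1) transition terms, (nStates-1) initial-distribution terms
# and nCells*(nStates-1) rate terms as free parameters, then uses
# BIC = -2*log(L) + nParams*log(T). A minimal sketch of that count:
def _example_bic_param_count(n_cells=5, n_states=3):
    n_params = (n_states * (n_states - 1) +     # transition matrix rows sum to 1
                (n_states - 1) +                # initial distribution sums to 1
                n_cells * (n_states - 1))       # rate terms as counted above
    return n_params                             # 5 cells, 3 states -> 18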


def compute_hmm_cost(spikes, dt, PI, A, B, win_size=0.25, true_rates=None):
    if true_rates is None:
        true_rates = convert_spikes_to_rates(spikes, dt, win_size)

    BIC, bestPaths, maxLogProb = compute_BIC(spikes, dt, PI, A, B)
    hmm_rates = generate_rate_array_from_state_seq(bestPaths, B, dt, win_size,
                                                   step_size=win_size)
    RMSE = compute_rate_rmse(true_rates, hmm_rates)
    return RMSE, BIC, bestPaths, maxLogProb


def compute_best_paths(spikes, dt, PI, A, B):
    if len(spikes.shape) == 2:
        spikes = np.array([spikes])

    nTrials, nCells, nTimeSteps = spikes.shape
    bestPaths = np.zeros((nTrials, nTimeSteps))-1
    pathProbs = np.zeros((nTrials,))

    for i, trial in enumerate(spikes):
        bestPaths[i,:], pathProbs[i], _, _ = poisson_viterbi(trial, dt, PI,
                                                             A, B)
    return bestPaths, pathProbs


def convert_path_state_numbers(paths, state_map):
    newPaths = np.zeros(paths.shape)
    for k,v in state_map.items():
        idx = np.where(paths == k)
        newPaths[idx] = v

    return newPaths


@njit
def compute_rate_rmse(rates1, rates2):
    # Compute RMSE per trial
    # Mean over trials
    n_trials, n_cells, n_steps = rates1.shape
    RMSE = np.zeros((n_trials,))
    for i in range(n_trials):
        t1 = rates1[i, :, :]
        t2 = rates2[i, :, :]
        # Compute RMSE from euclidean distances at each time point
        distances = np.zeros((n_steps,))
        for j in range(n_steps):
            distances[j] =  euclidean(t1[:,j], t2[:,j])

        RMSE[i] = np.sqrt(np.mean(np.power(distances,2)))

    return np.mean(RMSE)


def plot_state_raster(data, stateVec, dt, ax=None):
    if len(data.shape) == 2:
        data = np.array([data])

    nTrials, nCells, nTimeSteps = data.shape
    nStates = np.max(stateVec) +1

    gradient = np.array([0 + i/(nCells+1) for i in range(nCells)])
    time = np.arange(0, nTimeSteps * dt * 1000, dt * 1000)
    colors = [plt.cm.jet(i) for i in np.linspace(0,1,nStates)]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    for trial, spikes in enumerate(data):
        path = stateVec[trial]
        for i, row in enumerate(spikes):
            idx = np.where(row == 1)[0]
            ax.scatter(time[idx], row[idx]*trial + gradient[i],
                       c=[colors[int(x)] for x in path[idx]], marker='|')

    return fig, ax

def plot_state_rates(rates, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    nCells, nStates = rates.shape
    df = pd.DataFrame(rates, columns=['state %i' % i for i in range(nStates)])
    df['cell'] = ['cell %i' % i for i in df.index]
    df = pd.melt(df, 'cell', ['state %i' % i for i in range(nStates)], 'state', 'rate')
    sns.barplot(x='state', y='rate', hue='cell', data=df,
                palette='muted', ax=ax)

    return fig, ax

def compare_hmm_to_truth(truth_dat, hmm, state_map=None):
    if state_map is None:
        state_map = match_states(truth_dat.ground_truth['firing_rates'], hmm.emission)

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15,10))
    truth_dat.plot_state_raster(ax=ax[0,0])
    truth_dat.plot_state_rates(ax=ax[1,0])
    hmm.plot_state_raster(ax=ax[0,1], state_map=state_map)
    hmm.plot_state_rates(ax=ax[1,1], state_map=state_map)
    ax[0,0].set_title('Ground Truth States')
    ax[0,1].set_title('HMM Best Decoded States')
    ax[1,0].get_legend().remove()
    ax[1,1].legend(loc='upper center', bbox_to_anchor=[-0.4, -0.6, 0.5, 0.5], ncol=5)

    # Compute edit distances, histogram, return mean and median % correct
    truePaths = truth_dat.ground_truth['state_vectors']
    bestPaths, _ = hmm.get_best_paths()
    if state_map is not None:
        bestPaths = convert_path_state_numbers(bestPaths, state_map)

    edit_distances = np.zeros((truePaths.shape[0],))
    pool = mp.Pool(mp.cpu_count())
    def update(ans):
        edit_distances[ans[0]] = ans[1]

    print('Computing edit distances...')
    for i, x in enumerate(zip(truePaths, bestPaths)):
        pool.apply_async(levenshtein_mp, (i, *x), callback=update)

    pool.close()
    pool.join()
    print('Done!')

    nPts = truePaths.shape[1]
    mean_correct = 100*(nPts - np.mean(edit_distances)) / nPts
    median_correct = 100*(nPts - np.median(edit_distances)) / nPts

    # Plot:
    #   - edit distance histogram
    #   - side-by-side trial comparison
    h = 0.25
    dt = hmm.dt
    time = np.arange(0, nPts * (dt*1000), dt*1000)  # time in ms
    fig2, ax2 = plt.subplots(ncols=2, figsize=(15,10))
    ax2[0].hist(100*(nPts-edit_distances)/nPts)
    ax2[0].set_xlabel('Percent Correct')
    ax2[0].set_ylabel('Trial Count')
    ax2[0].set_title('Percent Correct based on edit distance\n'
                     'Mean Correct: %0.1f%%, Median: %0.1f%%'
                     % (mean_correct, median_correct))

    maxState = int(np.max((bestPaths, truePaths)))
    colors = [plt.cm.Paired(x) for x in np.linspace(0, 1, (maxState+1)*2)]
    trueCol = [colors[x] for x in np.arange(0, (maxState+1)*2, 2)]
    hmmCol = [colors[x] for x in np.arange(1, (maxState+1)*2, 2)]
    leg = {}
    leg['hmm'] = {k: None for k in np.unique((bestPaths, truePaths))}
    leg['truth'] = {k: None for k in np.unique((bestPaths, truePaths))}
    for i, x in enumerate(zip(truePaths, bestPaths)):
        y = x[0]
        z = x[1]
        t = 0
        while(t  < nPts):
            s = int(y[t])
            next_t = np.where(y[t:] != s)[0]
            if len(next_t) == 0:
                next_t = nPts - t
            else:
                next_t = next_t[0]

            t_start = time[t]
            t_end = time[t+next_t-1]
            tmp = ax2[1].fill_between([t_start, t_end], [i, i], [i+h, i+h], color=trueCol[s])
            if leg['truth'][s] is None:
                leg['truth'][s] = tmp

            t += next_t

        t = 0
        while(t  < nPts):
            s = int(z[t])
            next_t = np.where(z[t:] != s)[0]
            if len(next_t) == 0:
                next_t = nPts - t
            else:
                next_t = next_t[0]

            t_start = time[t]
            t_end = time[t+next_t-1]
            tmp = ax2[1].fill_between([t_start, t_end], [i, i], [i-h, i-h], color=hmmCol[s])
            if leg['hmm'][s] is None:
                leg['hmm'][s] = tmp

            t += next_t

        # Write % correct next to line
        t_str = '%0.1f%%' % (100 * (nPts - edit_distances[i])/nPts)
        ax2[1].text(nPts+5, i-h, t_str)

    ax2[1].set_xlim((0, nPts+int(nPts/3)))
    ax2[1].set_xlabel('Time (ms)')
    ax2[1].set_title('State Sequences')
    handles = list(leg['truth'].values()) + list(leg['hmm'].values())
    labels = (['True State %i' % i for i in leg['truth'].keys()] +
              ['HMM State %i' % i for i in leg['hmm'].keys()])
    ax2[1].legend(handles, labels, shadow=True,
                  bbox_to_anchor=(0.78, 0.5, 0.5, 0.5))

    fig.show()
    fig2.show()
    return fig, ax, fig2, ax2


def wrap_baum_welch(trial_id, trial_dat, dt, PI, A, B):
    alpha, norms = forward(trial_dat, dt, PI, A, B)
    beta = backward(trial_dat, dt, A, B, norms)
    tmp_gamma, tmp_epsilons = baum_welch(trial_dat, dt, A, B, alpha, beta)
    return trial_id, tmp_gamma, tmp_epsilons


def compute_new_matrices(spikes, dt, gammas, epsilons):
    nTrials, nCells, nTimeSteps = spikes.shape
    minFR = 1/(nTimeSteps*dt)

    PI = np.sum(gammas, axis=0)[:,1] / nTrials
    Anumer = np.sum(np.sum(epsilons, axis=3), axis=0)
    Adenom = np.sum(np.sum(gammas[:,:,:-1], axis=2), axis=0)
    A = Anumer/Adenom
    A = A/np.sum(A, axis=1)
    Bnumer = np.sum(np.array([np.matmul(tmp_y, tmp_g.T)
                              for tmp_y, tmp_g in zip(spikes, gammas)]),
                    axis=0)
    Bdenom =  np.sum(np.sum(gammas, axis=2), axis=0)
    B = (Bnumer / Bdenom)/dt
    B[B < minFR] = minFR

    return PI, A, B


def match_states(rates1, rates2):
    '''Takes 2 Cell X State firing rate matrices and determines which states
    are most similar. Returns dict mapping rates2 states to rates1 states
    '''
    distances = np.zeros((rates1.shape[1], rates2.shape[1]))
    for x, y in it.product(range(rates1.shape[1]), range(rates2.shape[1])):
        tmp = euclidean(rates1[:, x], rates2[:, y])
        distances[x, y] = tmp

    states = list(range(rates2.shape[1]))
    out = {}
    for i in range(rates2.shape[1]):
        s = np.argmin(distances[:,i])
        r = np.argmin(distances[s, :])
        if r == i and s in states:
            out[i] = s
            idx = np.where(states == s)[0]
            states.pop(int(idx))

    for i in range(rates2.shape[1]):
        if i not in out:
            s = np.argmin(distances[states, i])
            out[i] = states[s]

    return out
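
# Illustrative sketch (not part of the original module): relabeling decoded
# paths so HMM state numbers line up with ground-truth state numbers, as done
# inside compare_hmm_to_truth. Inputs are assumptions.
def _example_match_and_relabel(truth_rates, hmm):
    state_map = match_states(truth_rates, hmm.emission)       # {hmm state: truth state}
    best_paths, _ = hmm.get_best_paths()
    return convert_path_state_numbers(best_paths, state_map)  # paths in truth numbering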


@njit
def levenshtein(seq1, seq2):
    ''' Computes edit distance between 2 sequences
    '''
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    matrix = np.zeros ((size_x, size_y))
    for x in range(size_x):
        matrix [x, 0] = x

    for y in range(size_y):
        matrix [0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x-1] == seq2[y-1]:
                matrix [x,y] = min(matrix[x-1, y] + 1, matrix[x-1, y-1],
                                   matrix[x, y-1] + 1)
            else:
                matrix [x,y] = min(matrix[x-1,y] + 1, matrix[x-1,y-1] + 1,
                                   matrix[x,y-1] + 1)

    return (matrix[size_x - 1, size_y - 1])
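
# Illustrative sketch (not part of the original module): percent of time bins
# correct between a true and a decoded state sequence, via edit distance.
def _example_percent_correct(true_path, decoded_path):
    dist = levenshtein(true_path, decoded_path)
    return 100 * (len(true_path) - dist) / len(true_path)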

@njit
def levenshtein_mp(i, seq1, seq2):
    ''' Computes edit distance between 2 sequences
    '''
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    matrix = np.zeros ((size_x, size_y))
    for x in range(size_x):
        matrix [x, 0] = x

    for y in range(size_y):
        matrix [0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x-1] == seq2[y-1]:
                matrix [x,y] = min(matrix[x-1, y] + 1, matrix[x-1, y-1],
                                   matrix[x, y-1] + 1)
            else:
                matrix [x,y] = min(matrix[x-1,y] + 1, matrix[x-1,y-1] + 1,
                                   matrix[x,y-1] + 1)

    return i, matrix[size_x - 1, size_y - 1]


def fit_hmm_mp(nStates, spikes, dt, max_iter=1000, thresh=1e-4):
    hmm = PoissonHMM(nStates, spikes, dt)
    hmm.fit(max_iter=max_iter, convergence_thresh=thresh, parallel=False)
    return hmm


@njit
def convert_spikes_to_rates(spikes, dt, win_size, step_size=None):
    if step_size is None:
        step_size = win_size

    n_trials, n_cells, n_steps = spikes.shape
    n_pts = int(win_size/dt)
    n_step_pts = int(step_size/dt)
    win_starts = np.arange(0, n_steps, n_step_pts)
    out = np.zeros((n_trials, n_cells, len(win_starts)))
    for i, w in enumerate(win_starts):
        out[:, :, i] = np.sum(spikes[:, :, w:w+n_pts], axis=2) / win_size

    return out
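
# Illustrative sketch (not part of the original module): binned firing rates
# from a Trial x Neuron x Time spike array, as used for the RMSE cost.
# Array sizes are assumptions.
def _example_binned_rates():
    spikes = np.random.poisson(0.02, (10, 5, 2000))        # 10 trials, 5 cells, 2 s at 1 ms
    return convert_spikes_to_rates(spikes, 0.001, 0.25)    # -> (10, 5, 8) rates in Hz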


@njit
def generate_rate_array_from_state_seq(bestPaths, B, dt, win_size,
                                       step_size=None):
    if not step_size:
        step_size = win_size

    n_trials, n_steps = bestPaths.shape
    n_cells, n_states = B.shape
    rates = np.zeros((n_trials, n_cells, n_steps))
    for j in range(n_trials):
        seq = bestPaths[j, :].astype(np.int64)
        rates[j, :, :] = B[:, seq]

    n_pts = int(win_size / dt)
    n_step_pts = int(step_size/dt)
    win_starts = np.arange(0, n_steps, n_step_pts)
    mean_rates = np.zeros((n_trials, n_cells, len(win_starts)))
    for i, w in enumerate(win_starts):
        mean_rates[:, :, i] = np.sum(rates[:, : , w:w+n_pts], axis=2) / n_pts

    return mean_rates


@njit
def euclidean(a, b):
    c = np.power(a-b,2)
    return np.sqrt(np.sum(c))


def rebin_spike_array(spikes, dt, time, new_dt):
    if spikes.ndim == 2:
        spikes = np.expand_dims(spikes,0)

    n_trials, n_cells, n_steps = spikes.shape
    n_bins = int(new_dt/dt)
    # time is in ms while dt and new_dt are in seconds, so step through the
    # time vector in units of its own spacing
    bin_step = (time[1] - time[0]) * n_bins
    new_time = np.arange(time[0], time[-1], bin_step)
    new_spikes = np.zeros((n_trials, n_cells, len(new_time)))
    for i, w in enumerate(new_time):
        idx = np.where((time >= w) & (time < w + bin_step))[0]
        new_spikes[:,:,i] = np.sum(spikes[:,:,idx], axis=-1)

    return new_spikes, new_time
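
# Illustrative sketch (not part of the original module): rebinning a 1 ms spike
# array into 10 ms bins, as done when a coarser dt is requested for fitting.
# The time window is an assumption.
def _example_rebin():
    dt, new_dt = 0.001, 0.010                              # seconds
    time = np.arange(-1000, 2000)                          # ms, 1 ms spacing
    spikes = np.random.poisson(0.02, (10, 5, len(time)))
    new_spikes, new_time = rebin_spike_array(spikes, dt, time, new_dt)
    return new_spikes.shape, new_time.shape                # (10, 5, 300), (300,)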


HMM_PARAMS = {'unit_type': 'single', 'dt': 0.001, 'threshold': 1e-4, 'max_iter': 1000,
              'time_start': 0, 'time_end': 2000, 'n_repeats': 3, 'n_states': 3}


class HmmHandler(object):
    def __init__(self, dat, params=None, save_dir=None):
        '''Takes a blechpy dataset object and fits HMMs for each tastant

        Parameters
        ----------
        dat: blechpy.dataset
        params: dict or list of dicts
            each dict must have fields:
                time_start: int, start of window to cut around stimuli in ms
                time_end: int, end of window to cut around stimuli in ms
                threshold: float, convergence threshold
                max_iter: int
                n_repeats: int
                unit_type: str, {'single', 'pyramidal', 'interneuron', 'all'}
                dt: float, time bin for spike array when fitting, in seconds
                n_states: int, predicted number of states to fit
        '''
        if isinstance(params, dict):
            params = [params]

        if isinstance(dat, str):
            dat = load_dataset(dat)
            if dat is None:
                raise FileNotFoundError('No dataset.p file found given directory')

        if save_dir is None:
            save_dir = os.path.join(dat.root_dir,
                                    '%s_analysis' % dat.data_name)

        self._dataset = dat
        self.root_dir = dat.root_dir
        self.save_dir = save_dir
        self.h5_file = os.path.join(save_dir, '%s_HMM_Analysis.hdf5' % dat.data_name)
        dim = dat.dig_in_mapping.query('exclude==False')
        tastes = dim['name'].tolist()
        if params is None:
            # Load params and fitted models
            self.load_data()
        else:
            self.init_params(params)

        self.params = params

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        self.plot_dir = os.path.join(save_dir, 'HMM_Plots')
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        self._setup_hdf5()

    def init_params(self, params):
        dat = self._dataset
        dim = dat.dig_in_mapping.query('exclude == False')
        tastes = dim['name'].tolist()
        dim = dim.set_index('name')
        if not hasattr(dat, 'dig_in_trials'):
            dat.create_trial_list()

        trials = dat.dig_in_trials
        data_params = []
        fit_objs = []
        fit_params = []
        for i, X in enumerate(it.product(params,tastes)):
            p = X[0].copy()
            t = X[1]
            p['hmm_id'] = i
            p['taste'] = t
            p['channel'] = dim.loc[t, 'channel']
            unit_names = query_units(dat, p['unit_type'])
            p['n_cells'] = len(unit_names)
            p['n_trials'] = len(trials.query('name == @t'))

            data_params.append(p)
            # Make fit object for each repeat
            # During fitting compare HMM as ones with the same ID are returned
            for i in range(p['n_repeats']):
                hmmFit = HMMFit(dat.root_dir, p)
                fit_objs.append(hmmFit)
                fit_params.append(p)

        self._fit_objects = fit_objs
        self._data_params = data_params
        self._fit_params = fit_params
        self._fitted_models = dict.fromkeys([x['hmm_id'] for x in data_params])
        self.write_overview_to_hdf5()

    def load_data(self):
        h5_file = self.h5_file
        if not os.path.isfile(h5_file):
            raise ValueError('No params to load')

        rec_dir = self._dataset.root_dir
        params = []
        fit_objs = []
        fit_params = []
        fitted_models = {}
        with tables.open_file(h5_file, 'r') as hf5:
            table = hf5.root.data_overview
            col_names = table.colnames
            for row in table[:]:
                p = {}
                for k in col_names:
                    if table.coltypes[k] == 'string':
                        p[k] = row[k].decode('utf-8')
                    else:
                        p[k] = row[k]

                params.append(p)
                for i in range(p['n_repeats']):
                    hmmFit = HMMFit(rec_dir, p)
                    fit_objs.append(hmmFit)
                    fit_params.append(p)

        for p in params:
            hmm_id = p['hmm_id']
            fitted_models[hmm_id] = read_hmm_from_hdf5(h5_file, hmm_id, rec_dir)

        self._data_params = params
        self._fit_objects = fit_objs
        self._fitted_models = fitted_models
        self._fit_params = fit_params


    def write_overview_to_hdf5(self):
        params = self._data_params
        h5_file = self.h5_file
        if hasattr(self, '_fitted_models'):
            models = self._fitted_models
        else:
            models = dict.fromkeys([x['hmm_id'] for x in params])
            self._fitted_models = models


        if not os.path.isfile(h5_file):
            self._setup_hdf5()

        print('Writing data overview table to hdf5...')
        with tables.open_file(h5_file, 'a') as hf5:
            table = hf5.root.data_overview
            # Clear old table
            table.remove_rows(start=0)

            # Add new rows
            for p in params:
                row = table.row
                for k, v in p.items():
                    row[k] = v

                if models[p['hmm_id']] is not None:
                   hmm = models[p['hmm_id']]
                   row['n_iterations'] = hmm.iteration
                   row['BIC'] = hmm.BIC
                   row['cost'] = hmm.cost
                   row['converged'] = hmm.isConverged(p['threshold'])
                   row['fitted'] = hmm.fitted

                row.append()

            table.flush()
            hf5.flush()

        print('Done!')

    def _setup_hdf5(self):
        h5_file = self.h5_file

        with tables.open_file(h5_file, 'a') as hf5:
            # Taste -> PI, A, B, BIC, state_sequences, nStates, nCells, dt
            if not 'data_overview' in hf5.root:
                # Contains taste, channel, n_cells, n_trials, n_states, dt, BIC
                table = hf5.create_table('/', 'data_overview', HMMInfoParticle,
                                         'Basic info for each digital_input')
                table.flush()


            if hasattr(self, '_data_params') and self._data_params is not None:
                for p in self._data_params:
                    hmm_str = 'hmm_%i' % p['hmm_id']
                    if hmm_str not in hf5.root:
                        hf5.create_group('/', hmm_str, 'Data for HMM #%i' % p['hmm_id'])

            hf5.flush()

    def run(self, parallel=True):
        self.write_overview_to_hdf5()
        h5_file = self.h5_file
        rec_dir = self._dataset.root_dir
        fit_objs = self._fit_objects
        fit_params = self._fit_params
        self._fitted_models = dict.fromkeys([x['hmm_id'] for x in self._data_params])
        errors = []

        # def update(ans):
        #     hmm_id = ans[0]
        #     hmm = ans[1]
        #     if self._fitted_models[hmm_id] is not None:
        #         best_hmm = pick_best_hmm([HMMs[hmm_id], hmm])
        #         self._fitted_models[hmm_id] = best_hmm
        #         write_hmm_to_hdf5(h5_file, hmm_id, best_hmm)
        #         del hmm, best_hmm
        #     else:
        #         # Check history for lowest BIC
        #         self._fitted_models[hmm_id] = hmm.set_to_lowest_BIC()
        #         write_hmm_to_hdf5(h5_file, hmm_id, hmm)
        #         del hmm

        # def error_call(e):
        #     errors.append(e)

        # if parallel:
        #     n_cpu = np.min((mp.cpu_count()-1, len(fit_objs)))
        #     if n_cpu > 10:
        #         pool = mp.get_context('spawn').Pool(n_cpu)
        #     else:
        #         pool = mp.Pool(n_cpu)

        #     for f in fit_objs:
        #         pool.apply_async(f.run, callback=update, error_callback=error_call)

        #     pool.close()
        #     pool.join()
        # else:
        #     for f in fit_objs:
        #         try:
        #             ans = f.run()
        #             update(ans)
        #         except Exception as e:
        #             raise Exception(e)
        #             error_call(e)
        print('Running fittings')
        if parallel:
            n_cpu = np.min((mp.cpu_count()-1, len(fit_params)))
        else:
            n_cpu = 1

        results = Parallel(n_jobs=n_cpu, verbose=20)(delayed(hmm_fit_mp)(rec_dir, p) for p in fit_params)
        for hmm_id, hmm in results:
            if self._fitted_models[hmm_id] is None:
                self._fitted_models[hmm_id] = hmm
            else:
                new_hmm = pick_best_hmm([hmm, self._fitted_models[hmm_id]])
                self._fitted_models[hmm_id] = new_hmm

        self.write_overview_to_hdf5()
        self.save_fitted_models()
        # if len(errors) > 0:
        #     print('Encountered errors: ')
        #     for e in errors:
        #         print(e)

    def save_fitted_models(self):
        models = self._fitted_models
        for k, v in models.items():
            write_hmm_to_hdf5(self.h5_file, k, v)
            plot_dir = os.path.join(self.plot_dir, 'HMM_%i' % k)
            if not os.path.isdir(plot_dir):
                os.makedirs(plot_dir)

            ids = [x['hmm_id'] for x in self._data_params]
            idx = ids.index(k)
            params = self._data_params[idx]
            time_window = [params['time_start'], params['time_end']]
            hmmplt.plot_hmm_figures(v, time_window, save_dir=plot_dir)
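
# Illustrative sketch (not part of the original module): a typical HmmHandler
# session. The recording directory is a hypothetical placeholder and the
# parameter values simply restate HMM_PARAMS above.
def _example_handler_session():
    rec_dir = '/path/to/recording_dir'           # hypothetical dataset directory
    handler = HmmHandler(rec_dir, params=HMM_PARAMS.copy())
    handler.run(parallel=True)                   # fits every taste/parameter combination
    return handler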


@memory.cache
def get_hmm_spike_data(rec_dir, unit_type, channel, time_start=None, time_end=None, dt=None):
    units = query_units(rec_dir, unit_type)
    time, spike_array = h5io.get_spike_data(rec_dir, units, channel)
    curr_dt = np.unique(np.diff(time))[0] / 1000
    if dt is not None and curr_dt < dt:
        spike_array, time = rebin_spike_array(spike_array, curr_dt, time, dt)
    elif dt is not None and curr_dt > dt:
        raise ValueError('Cannot upsample spike array from %g s '
                         'bins to %g s bins' % (curr_dt, dt))
    else:
        dt = curr_dt

    if time_start and time_end:
        idx = np.where((time >= time_start) & (time < time_end))[0]
        time = time[idx]
        spike_array = spike_array[:, :, idx]

    return spike_array.astype('int32'), dt, time


def read_hmm_from_hdf5(h5_file, hmm_id, rec_dir):
    print('Loading HMM %i from hdf5' % hmm_id)
    with tables.open_file(h5_file, 'r') as hf5:
        h_str = 'hmm_%i' % hmm_id
        if h_str not in hf5.root or len(hf5.list_nodes('/'+h_str)) == 0:
            return None

        table = hf5.root.data_overview
        row = list(table.where('hmm_id == id', condvars={'id':hmm_id}))
        if len(row) == 0:
            raise ValueError('Parameters not found for hmm %i' % hmm_id)
        elif len(row) > 1:
            raise ValueError('Multiple parameters found for hmm %i' % hmm_id)

        row = row[0]
        units = query_units(rec_dir, row['unit_type'].decode('utf-8'))
        spikes, dt, time = get_spike_data(rec_dir, units, row['channel'],
                                          dt=row['dt'],
                                          time_start=row['time_start'],
                                          time_end=row['time_end'])
        tmp = hf5.root[h_str]
        mats = {'initial_distribution': tmp['initial_distribution'][:],
                'transition': tmp['transition'][:],
                'emission': tmp['emission'][:],
                'fitted': row['fitted']}
        hmm = PoissonHMM(row['n_states'], spikes, dt, set_data=mats)

    return hmm


def write_hmm_to_hdf5(h5_file, hmm_id, hmm):
    h_str = 'hmm_%i' % hmm_id
    print('Writing HMM %i to hdf5 file...' % hmm_id)
    with tables.open_file(h5_file, 'a') as hf5:
        if h_str in hf5.root:
            hf5.remove_node('/', h_str, recursive=True)

        hf5.create_group('/', h_str, 'Data for HMM #%i' % hmm_id)
        hf5.create_array('/'+h_str, 'initial_distribution',
                         hmm.initial_distribution)
        hf5.create_array('/'+h_str, 'transition', hmm.transition)
        hf5.create_array('/'+h_str, 'emission', hmm.emission)

        best_paths, _ = hmm.get_best_paths()
        hf5.create_array('/'+h_str, 'state_sequences', best_paths)


def query_units(dat, unit_type):
    '''Returns the units names of all units in the dataset that match unit_type

    Parameters
    ----------
    dat : blechpy.dataset or str
        Can either be a dataset object or the str path to the recording
        directory containing that data .h5 object
    unit_type : str, {'single', 'pyramidal', 'interneuron', 'all'}
        determines whether to return 'single' units, 'pyramidal' (regular
        spiking single) units, 'interneuron' (fast spiking single) units, or
        'all' units

    Returns
    -------
        list of str : unit_names
    '''
    if isinstance(dat, str):
        units = h5io.get_unit_table(dat)
    else:
        units = dat.get_unit_table()

    u_str = unit_type.lower()
    q_str = ''
    if u_str == 'single':
        q_str = 'single_unit == True'
    elif u_str == 'pyramidal':
        q_str = 'single_unit == True and regular_spiking == True'
    elif u_str == 'interneuron':
        q_str = 'single_unit == True and fast_spiking == True'
    elif u_str == 'all':
        return units['unit_name'].tolist()
    else:
        raise ValueError('Invalid unit_type %s. Must be '
                         'single, pyramidal, interneuron or all' % u_str)

    return units.query(q_str)['unit_name'].tolist()


    # Parameters
    # hmm_id
    # taste
    # channel
    # n_cells
    # unit_type
    # n_trials
    # dt
    # threshold
    # time_start
    # time_end
    # n_repeats
    # n_states
    # n_iterations
    # BIC
    # cost
    # converged
    # fitted
    #
    # Extras: unit_names, rec_dir
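
# Illustrative sketch (not part of the original module): selecting unit names
# for a fit. The recording directory is a hypothetical placeholder.
def _example_query_units():
    rec_dir = '/path/to/recording_dir'           # hypothetical
    single_units = query_units(rec_dir, 'single')
    pyramidal = query_units(rec_dir, 'pyramidal')
    return single_units, pyramidal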


class HMMFit(object):
    def __init__(self, rec_dir, params):
        self._rec_dir = rec_dir
        self._params = params

    def run(self, parallel=False):
        params = self._params
        spikes, dt, time = self.get_spike_data()
        hmm = PoissonHMM(params['n_states'], spikes, dt)
        hmm.fit(max_iter=params['max_iter'],
                convergence_thresh=params['threshold'],
                parallel=parallel)
        del spikes, dt, time
        return params['hmm_id'], hmm

    def get_spike_data(self):
        p = self._params
        units = query_units(self._rec_dir, p['unit_type'])
        # Get stored spike array, time is in ms, dt is usually 1 ms
        spike_array, dt, time = get_spike_data(self._rec_dir, units,
                                               p['channel'], dt=p['dt'],
                                               time_start=p['time_start'],
                                               time_end=p['time_end'])
        return spike_array, dt, time

def hmm_fit_mp(rec_dir, params):
    hmm_id = params['hmm_id']
    n_states = params['n_states']
    dt = params['dt']
    time_start = params['time_start']
    time_end = params['time_end']
    max_iter = params['max_iter']
    threshold = params['threshold']
    unit_type = params['unit_type']
    channel = params['channel']
    spikes, dt, time = get_hmm_spike_data(rec_dir, unit_type, channel,
                                          time_start=time_start,
                                          time_end=time_end, dt = dt)
    hmm = PoissonHMM(params['n_states'], spikes, dt)
    hmm.fit(max_iter=max_iter, convergence_thresh=threshold)
    return hmm_id, hmm


def get_spike_data(rec_dir, units, channel, dt=None, time_start=None, time_end=None):
    time, spike_array = h5io.get_spike_data(rec_dir, units, channel)
    curr_dt = np.unique(np.diff(time))[0] / 1000
    if dt is not None and curr_dt < dt:
        spike_array, time = rebin_spike_array(spike_array, curr_dt, time, dt)
    elif dt is not None and curr_dt > dt:
        raise ValueError('Cannot upsample spike array from %g s '
                         'bins to %g s bins' % (curr_dt, dt))
    else:
        dt = curr_dt

    if time_start and time_end:
        idx = np.where((time >= time_start) & (time < time_end))[0]
        time = time[idx]
        spike_array = spike_array[:, :, idx]

    return spike_array.astype('int32'), dt, time


def pick_best_hmm(HMMs):
    '''For each HMM, searches the history for the iteration with the lowest
    BIC, then compares the HMMs. Those with the same # of free parameters are
    compared by BIC; those with different # of free parameters (namely # of
    states) are compared by cost. The best HMM is returned.

    Parameters
    ----------
    HMMs : list of PoissonHmm objects

    Returns
    -------
    PoissonHmm
    '''
    # First optimize each HMMs and sort into groups based on # of states
    groups = {}
    for hmm in HMMs:
        hmm.set_to_lowest_BIC()
        if hmm.n_states in groups:
            groups[hmm.n_states].append(hmm)
        else:
            groups[hmm.n_states] = [hmm]

    best_per_state = {}
    for k, v in groups.items():
        BICs = np.array([x.get_BIC()[0] for x in v])
        idx = np.argmin(BICs)
        best_per_state[k] = v[idx]

    hmm_list = list(best_per_state.values())
    costs = np.array([x.get_cost() for x in hmm_list])
    idx = np.argmin(costs)
    return hmm_list[idx]

    # Compare HMMs with same number of states by BIC

Functions

def backward(spikes, dt, A, B, norms)

Runs the backward algorithm to compute beta = P(ot+1…oT | Xt=s). Computes the probability of observing all future observations given the current state at each time point.

Parameters

spikes : np.array, N x T matrix of spike counts
dt : float, timebin size in seconds
A : np.array, nStates x nStates matrix of transition probabilities
B : np.array, N x nStates matrix of estimated spike rates for each neuron
norms : np.array, 1 x T vector of scaling factors from the forward pass

Returns

beta : np.array, nStates x T matrix of backward probabilities
Expand source code
@njit
def backward(spikes, dt, A, B, norms):
    ''' Runs the backward algorithm to compute beta = P(ot+1...oT | Xt=s)
    Computes the probability of observing all future observations given the
    current state at each time point

    Paramters
    ---------
    spike : np.array, N x T matrix of spike counts
    nStates : int, # of hidden states predicted
    dt : float, timebin size in seconds
    A : np.array, nStates x nStates matrix of transition probabilities
    B : np.array, N x nStates matrix of estimated spike rates for each neuron

    Returns
    -------
    beta : np.array, nStates x T matrix of backward probabilities
    '''
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]
    beta = np.zeros((nStates, nTimeSteps))
    beta[:, -1] = 1  # Initialize final beta to 1 for all states
    tStep = list(range(nTimeSteps-1))
    tStep.reverse()
    for t in tStep:
        for s in range(nStates):
            # beta(s,t) = sum_j A[s,j] * P(spikes at t+1 | state j) * beta(j,t+1)
            tmp = 0.0
            for sj in range(nStates):
                tmp += (beta[sj, t+1] * A[s, sj] *
                        np.prod(poisson(B[:, sj], spikes[:, t+1], dt)))

            beta[s, t] = tmp

        beta[:, t] = beta[:, t] / norms[t+1]

    return beta
def baum_welch(spikes, dt, A, B, alpha, beta)
@njit
def baum_welch(spikes, dt, A, B, alpha, beta):
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]
    gamma = np.zeros((nStates, nTimeSteps))
    epsilons = np.zeros((nStates, nStates, nTimeSteps-1))
    for t in range(nTimeSteps):
        # gamma(i,t) = P(Xt = i | o1...oT), the state occupancy at time t
        gamma[:, t] = (alpha[:, t] * beta[:, t]) / np.sum(alpha[:,t] * beta[:,t])
        if t < nTimeSteps-1:
            # epsilon(i,j,t) = P(Xt = i, Xt+1 = j | o1...oT)
            epsilonNumerator = np.zeros((nStates, nStates))
            for si in range(nStates):
                for sj in range(nStates):
                    probs = np.prod(poisson(B[:,sj], spikes[:, t+1], dt))
                    epsilonNumerator[si, sj] = (alpha[si, t]*A[si, sj]*
                                                beta[sj, t+1]*probs)

            epsilons[:, :, t] = epsilonNumerator / np.sum(epsilonNumerator)

    return gamma, epsilons
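This routine is the core of the E-step: for each trial the forward and backward passes are run first and their outputs are handed to baum_welch(). A minimal sketch on made-up data (all sizes, rates and matrices below are purely illustrative):

import numpy as np

# Hypothetical trial: 5 neurons, 200 one-ms bins, 3 hidden states
n_cells, n_steps, n_states, dt = 5, 200, 3, 0.001
spikes = np.random.poisson(0.02, (n_cells, n_steps)).astype('int32')
PI = np.ones(n_states) / n_states                    # uniform initial distribution
A = np.full((n_states, n_states), 0.005)
np.fill_diagonal(A, 0.99)
A = A / A.sum(axis=1).reshape((-1, 1))               # rows sum to 1
B = np.random.uniform(1, 20, (n_cells, n_states))    # rates in Hz

alpha, norms = forward(spikes, dt, PI, A, B)    # P(state | spikes so far)
beta = backward(spikes, dt, A, B, norms)        # P(future spikes | state)
gamma, epsilons = baum_welch(spikes, dt, A, B, alpha, beta)
# gamma[:, t] sums to 1 at every bin; gamma and epsilons feed compute_new_matrices()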
def compare_hmm_to_truth(truth_dat, hmm, state_map=None)
def compare_hmm_to_truth(truth_dat, hmm, state_map=None):
    if state_map is None:
        state_map = match_states(truth_dat.ground_truth['firing_rates'], hmm.emission)

    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(15,10))
    truth_dat.plot_state_raster(ax=ax[0,0])
    truth_dat.plot_state_rates(ax=ax[1,0])
    hmm.plot_state_raster(ax=ax[0,1], state_map=state_map)
    hmm.plot_state_rates(ax=ax[1,1], state_map=state_map)
    ax[0,0].set_title('Ground Truth States')
    ax[0,1].set_title('HMM Best Decoded States')
    ax[1,0].get_legend().remove()
    ax[1,1].legend(loc='upper center', bbox_to_anchor=[-0.4, -0.6, 0.5, 0.5], ncol=5)

    # Compute edit distances, histogram, return mean and median % correct
    truePaths = truth_dat.ground_truth['state_vectors']
    bestPaths, _ = hmm.get_best_paths()
    if state_map is not None:
        bestPaths = convert_path_state_numbers(bestPaths, state_map)

    edit_distances = np.zeros((truePaths.shape[0],))
    pool = mp.Pool(mp.cpu_count())
    def update(ans):
        edit_distances[ans[0]] = ans[1]

    print('Computing edit distances...')
    for i, x in enumerate(zip(truePaths, bestPaths)):
        pool.apply_async(levenshtein_mp, (i, *x), callback=update)

    pool.close()
    pool.join()
    print('Done!')

    nPts = truePaths.shape[1]
    mean_correct = 100*(nPts - np.mean(edit_distances)) / nPts
    median_correct = 100*(nPts - np.median(edit_distances)) / nPts

    # Plot:
    #   - edit distance histogram
    #   - side-by-side trial comparison
    h = 0.25
    dt = hmm.dt
    time = np.arange(0, nPts * (dt*1000), dt*1000)  # time in ms
    fig2, ax2 = plt.subplots(ncols=2, figsize=(15,10))
    ax2[0].hist(100*(nPts-edit_distances)/nPts)
    ax2[0].set_xlabel('Percent Correct')
    ax2[0].set_ylabel('Trial Count')
    ax2[0].set_title('Percent Correct based on edit distance\n'
                     'Mean Correct: %0.1f%%, Median: %0.1f%%'
                     % (mean_correct, median_correct))

    maxState = int(np.max((bestPaths, truePaths)))
    colors = [plt.cm.Paired(x) for x in np.linspace(0, 1, (maxState+1)*2)]
    trueCol = [colors[x] for x in np.arange(0, (maxState+1)*2, 2)]
    hmmCol = [colors[x] for x in np.arange(1, (maxState+1)*2, 2)]
    leg = {}
    leg['hmm'] = {k: None for k in np.unique((bestPaths, truePaths))}
    leg['truth'] = {k: None for k in np.unique((bestPaths, truePaths))}
    for i, x in enumerate(zip(truePaths, bestPaths)):
        y = x[0]
        z = x[1]
        t = 0
        while(t  < nPts):
            s = int(y[t])
            next_t = np.where(y[t:] != s)[0]
            if len(next_t) == 0:
                next_t = nPts - t
            else:
                next_t = next_t[0]

            t_start = time[t]
            t_end = time[t+next_t-1]
            tmp = ax2[1].fill_between([t_start, t_end], [i, i], [i+h, i+h], color=trueCol[s])
            if leg['truth'][s] is None:
                leg['truth'][s] = tmp

            t += next_t

        t = 0
        while(t  < nPts):
            s = int(z[t])
            next_t = np.where(z[t:] != s)[0]
            if len(next_t) == 0:
                next_t = nPts - t
            else:
                next_t = next_t[0]

            t_start = time[t]
            t_end = time[t+next_t-1]
            tmp = ax2[1].fill_between([t_start, t_end], [i, i], [i-h, i-h], color=hmmCol[s])
            if leg['hmm'][s] is None:
                leg['hmm'][s] = tmp

            t += next_t

        # Write % correct next to line
        t_str = '%0.1f%%' % (100 * (nPts - edit_distances[i])/nPts)
        ax2[1].text(nPts+5, i-h, t_str)

    ax2[1].set_xlim((0, nPts+int(nPts/3)))
    ax2[1].set_xlabel('Time (ms)')
    ax2[1].set_title('State Sequences')
    handles = list(leg['truth'].values()) + list(leg['hmm'].values())
    labels = (['True State %i' % i for i in leg['truth'].keys()] +
              ['HMM State %i' % i for i in leg['hmm'].keys()])
    ax2[1].legend(handles, labels, shadow=True,
                  bbox_to_anchor=(0.78, 0.5, 0.5, 0.5))

    fig.show()
    fig2.show()
    return fig, ax, fig2, ax2
def compute_BIC(spikes, dt, PI, A, B)
def compute_BIC(spikes, dt, PI, A, B):
    bestPaths, maxLogProb = compute_best_paths(spikes, dt, PI, A, B)
    maxLogProb = np.max(maxLogProb)

    nParams = (A.shape[0]*(A.shape[1]-1) +
               (PI.shape[0]-1) +
               B.shape[0]*(B.shape[1]-1))

    nPts = spikes.shape[-1]
    BIC = -2 * maxLogProb + nParams * np.log(nPts)
    return BIC, bestPaths, maxLogProb
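A rough worked illustration of the returned BIC value; every number below is made up and the parameter count simply mirrors the expression in the source above (lower BIC indicates a better model):

import numpy as np

n_cells, n_states, n_bins = 10, 3, 3500      # hypothetical sizes
max_log_prob = -5000.0                       # hypothetical best-path log likelihood
n_params = n_states*(n_states-1) + (n_states-1) + n_cells*(n_states-1)   # = 28
BIC = -2*max_log_prob + n_params*np.log(n_bins)    # ~10228.5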
def compute_best_paths(spikes, dt, PI, A, B)
def compute_best_paths(spikes, dt, PI, A, B):
    if len(spikes.shape) == 2:
        spikes = np.array([spikes])

    nTrials, nCells, nTimeSteps = spikes.shape
    bestPaths = np.zeros((nTrials, nTimeSteps))-1
    pathProbs = np.zeros((nTrials,))

    for i, trial in enumerate(spikes):
        bestPaths[i,:], pathProbs[i], _, _ = poisson_viterbi(trial, dt, PI,
                                                             A, B)
    return bestPaths, pathProbs
def compute_hmm_cost(spikes, dt, PI, A, B, win_size=0.25, true_rates=None)
def compute_hmm_cost(spikes, dt, PI, A, B, win_size=0.25, true_rates=None):
    if true_rates is None:
        true_rates = convert_spikes_to_rates(spikes, dt, win_size)

    BIC, bestPaths, maxLogProb = compute_BIC(spikes, dt, PI, A, B)
    hmm_rates = generate_rate_array_from_state_seq(bestPaths, B, dt, win_size,
                                                   step_size=win_size)
    RMSE = compute_rate_rmse(true_rates, hmm_rates)
    return RMSE, BIC, bestPaths, maxLogProb
def compute_new_matrices(spikes, dt, gammas, epsilons)
def compute_new_matrices(spikes, dt, gammas, epsilons):
    nTrials, nCells, nTimeSteps = spikes.shape
    minFR = 1/(nTimeSteps*dt)

    PI = np.sum(gammas, axis=0)[:, 0] / nTrials
    Anumer = np.sum(np.sum(epsilons, axis=3), axis=0)
    Adenom = np.sum(np.sum(gammas[:,:,:-1], axis=2), axis=0)
    A = Anumer / Adenom.reshape((-1, 1))
    A = A / np.sum(A, axis=1).reshape((-1, 1))
    Bnumer = np.sum(np.array([np.matmul(tmp_y, tmp_g.T)
                              for tmp_y, tmp_g in zip(spikes, gammas)]),
                    axis=0)
    Bdenom = np.sum(np.sum(gammas, axis=2), axis=0)
    B = (Bnumer / Bdenom)/dt
    B[B < minFR] = minFR

    return PI, A, B
def compute_rate_rmse(rates1, rates2)
@njit
def compute_rate_rmse(rates1, rates2):
    # Compute RMSE per trial
    # Mean over trials
    n_trials, n_cells, n_steps = rates1.shape
    RMSE = np.zeros((n_trials,))
    for i in range(n_trials):
        t1 = rates1[i, :, :]
        t2 = rates2[i, :, :]
        # Compute RMSE from euclidean distances at each time point
        distances = np.zeros((n_steps,))
        for j in range(n_steps):
            distances[j] =  euclidean(t1[:,j], t2[:,j])

        RMSE[i] = np.sqrt(np.mean(np.power(distances,2)))

    return np.mean(RMSE)
def convert_path_state_numbers(paths, state_map)
def convert_path_state_numbers(paths, state_map):
    newPaths = np.zeros(paths.shape)
    for k,v in state_map.items():
        idx = np.where(paths == k)
        newPaths[idx] = v

    return newPaths
def convert_spikes_to_rates(spikes, dt, win_size, step_size=None)
@njit
def convert_spikes_to_rates(spikes, dt, win_size, step_size=None):
    if step_size is None:
        step_size = win_size

    n_trials, n_cells, n_steps = spikes.shape
    n_pts = int(win_size/dt)
    n_step_pts = int(step_size/dt)
    win_starts = np.arange(0, n_steps, n_step_pts)
    out = np.zeros((n_trials, n_cells, len(win_starts)))
    for i, w in enumerate(win_starts):
        out[:, :, i] = np.sum(spikes[:, :, w:w+n_pts], axis=2) / win_size

    return out
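A small usage sketch (toy spike array; only the shapes are meant to be instructive): 1 ms bins smoothed into non-overlapping 250 ms windows.

import numpy as np

spikes = np.random.poisson(0.01, (15, 8, 2000))        # trials x cells x 1 ms bins (made up)
rates = convert_spikes_to_rates(spikes, 0.001, 0.25)   # dt and win_size in seconds
print(rates.shape)                                     # (15, 8, 8), firing rates in Hz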
def euclidean(a, b)
@njit
def euclidean(a, b):
    c = np.power(a-b,2)
    return np.sqrt(np.sum(c))
def fast_factorial(x)
@njit
def fast_factorial(x):
    if x < len(FACTORIAL_LOOKUP):
        return FACTORIAL_LOOKUP[x]
    else:
        y = 1
        for i in range(1,x+1):
            y = y*i

        return y
def fit_hmm_mp(nStates, spikes, dt, max_iter=1000, thresh=0.0001)
def fit_hmm_mp(nStates, spikes, dt, max_iter=1000, thresh=1e-4):
    hmm = PoissonHMM(nStates, spikes, dt)
    hmm.fit(max_iter=max_iter, convergence_thresh=thresh, parallel=False)
    return hmm
def forward(spikes, dt, PI, A, B)

Run forward algorithm to compute alpha = P(Xt = i| o1…ot, pi) Gives the probabilities of being in a specific state at each time point given the past observations and initial probabilities

Parameters

spikes : np.array
N x T matrix of spike counts with each entry ((i,j)) holding the # of spikes from neuron i in timebin j
nStates : int, # of hidden states predicted to have generated the spikes
dt : float, timebin in seconds (i.e. 0.001)
PI : np.array
nStates x 1 vector of initial state probabilities
A : np.array
nStates x nStates state transition matrix with each entry ((i,j)) giving the probability of transitioning from state i to state j
B : np.array
N x nStates rate matrix. Each entry ((i,j)) gives the predicted rate of neuron i in state j

Returns

alpha : np.array
nStates x T matrix of forward probabilities. Each entry (i,j) gives P(Xt = i | o1,…,oj, pi)
norms : np.array
1 x T vector of norm used to normalize alpha to be a probability distribution and also to scale the outputs of the backward algorithm. norms(t) = sum(alpha(:,t))
@njit
def forward(spikes, dt, PI, A, B):
    '''Run forward algorithm to compute alpha = P(Xt = i| o1...ot, pi)
    Gives the probabilities of being in a specific state at each time point
    given the past observations and initial probabilities

    Parameters
    ----------
    spikes : np.array
        N x T matrix of spike counts with each entry ((i,j)) holding the # of
        spikes from neuron i in timebin j
    nStates : int, # of hidden states predicted to have generated the spikes
    dt : float, timebin in seconds (i.e. 0.001)
    PI : np.array
        nStates x 1 vector of initial state probabilities
    A : np.array
        nStates x nStates state transition matrix with each entry ((i,j))
        giving the probability of transitioning from state i to state j
    B : np.array
        N x nStates rate matrix. Each entry ((i,j)) gives the predicted rate
        of neuron i in state j

    Returns
    -------
    alpha : np.array
        nStates x T matrix of forward probabilities. Each entry (i,j) gives
        P(Xt = i | o1,...,oj, pi)
    norms : np.array
        1 x T vector of norm used to normalize alpha to be a probability
        distribution and also to scale the outputs of the backward algorithm.
        norms(t) = sum(alpha(:,t))
    '''
    nTimeSteps = spikes.shape[1]
    nStates = A.shape[0]

    # For each state, use the initial state distribution and spike counts
    # to initialize alpha(:,1)
    row = np.array([PI[i] * np.prod(poisson(B[:,i], spikes[:,0], dt))
                    for i in range(nStates)])
    alpha = np.zeros((nStates, nTimeSteps))
    norms = [np.sum(row)]
    alpha[:, 0] = row/norms[0]
    for t in range(1, nTimeSteps):
        tmp = np.array([np.prod(poisson(B[:, s], spikes[:, t], dt)) *
                        np.sum(alpha[:, t-1] * A[:,s])
                        for s in range(nStates)])
        tmp_norm = np.sum(tmp)
        norms.append(tmp_norm)
        tmp = tmp / tmp_norm
        alpha[:, t] = tmp

    return alpha, norms
def generate_rate_array_from_state_seq(bestPaths, B, dt, win_size, step_size=None)
@njit
def generate_rate_array_from_state_seq(bestPaths, B, dt, win_size,
                                       step_size=None):
    if not step_size:
        step_size = win_size

    n_trials, n_steps = bestPaths.shape
    n_cells, n_states = B.shape
    rates = np.zeros((n_trials, n_cells, n_steps))
    for j in range(n_trials):
        seq = bestPaths[j, :].astype(np.int64)
        rates[j, :, :] = B[:, seq]

    n_pts = int(win_size / dt)
    n_step_pts = int(step_size/dt)
    win_starts = np.arange(0, n_steps, n_step_pts)
    mean_rates = np.zeros((n_trials, n_cells, len(win_starts)))
    for i, w in enumerate(win_starts):
        mean_rates[:, :, i] = np.sum(rates[:, : , w:w+n_pts], axis=2) / n_pts

    return mean_rates
def get_hmm_spike_data(rec_dir, unit_type, channel, time_start=None, time_end=None, dt=None)
@memory.cache
def get_hmm_spike_data(rec_dir, unit_type, channel, time_start=None, time_end=None, dt=None):
    units = query_units(rec_dir, unit_type)
    time, spike_array = h5io.get_spike_data(rec_dir, units, channel)
    curr_dt = np.unique(np.diff(time))[0] / 1000
    if dt is not None and curr_dt < dt:
        spike_array, time = rebin_spike_array(spike_array, curr_dt, time, dt)
    elif dt is not None and curr_dt > dt:
        raise ValueError('Cannot upsample spike array from %f s '
                         'bins to %f s bins' % (curr_dt, dt))
    else:
        dt = curr_dt

    if time_start and time_end:
        idx = np.where((time >= time_start) & (time < time_end))[0]
        time = time[idx]
        spike_array = spike_array[:, :, idx]

    return spike_array.astype('int32'), dt, time
def get_spike_data(rec_dir, units, channel, dt=None, time_start=None, time_end=None)
def get_spike_data(rec_dir, units, channel, dt=None, time_start=None, time_end=None):
    time, spike_array = h5io.get_spike_data(rec_dir, units, channel)
    curr_dt = np.unique(np.diff(time))[0] / 1000
    if dt is not None and curr_dt < dt:
        spike_array, time = rebin_spike_array(spike_array, curr_dt, time, dt)
    elif dt is not None and curr_dt > dt:
        raise ValueError('Cannot upsample spike array from %f s '
                         'bins to %f s bins' % (curr_dt, dt))
    else:
        dt = curr_dt

    if time_start and time_end:
        idx = np.where((time >= time_start) & (time < time_end))[0]
        time = time[idx]
        spike_array = spike_array[:, :, idx]

    return spike_array.astype('int32'), dt, time
def hmm_fit_mp(rec_dir, params)
def hmm_fit_mp(rec_dir, params):
    hmm_id = params['hmm_id']
    n_states = params['n_states']
    dt = params['dt']
    time_start = params['time_start']
    time_end = params['time_end']
    max_iter = params['max_iter']
    threshold = params['threshold']
    unit_type = params['unit_type']
    channel = params['channel']
    spikes, dt, time = get_hmm_spike_data(rec_dir, unit_type, channel,
                                          time_start=time_start,
                                          time_end=time_end, dt = dt)
    hmm = PoissonHMM(params['n_states'], spikes, dt)
    hmm.fit(max_iter=max_iter, convergence_thresh=threshold)
    return hmm_id, hmm
def isNotConverged(oldPI, oldA, oldB, PI, A, B, thresh=0.0001)
def isNotConverged(oldPI, oldA, oldB, PI, A, B, thresh=1e-4):
    dPI = np.sqrt(np.sum(np.power(oldPI - PI, 2)))
    dA = np.sqrt(np.sum(np.power(oldA - A, 2)))
    dB = np.sqrt(np.sum(np.power(oldB - B, 2)))
    print('dPI = %f,  dA = %f,  dB = %f' % (dPI, dA, dB))
    if all([x < thresh for x in [dPI, dA, dB]]):
        return False
    else:
        return True
def levenshtein(seq1, seq2)

Computes edit distance between 2 sequences

@njit
def levenshtein(seq1, seq2):
    ''' Computes edit distance between 2 sequences
    '''
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    matrix = np.zeros ((size_x, size_y))
    for x in range(size_x):
        matrix [x, 0] = x

    for y in range(size_y):
        matrix [0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x-1] == seq2[y-1]:
                matrix [x,y] = min(matrix[x-1, y] + 1, matrix[x-1, y-1],
                                   matrix[x, y-1] + 1)
            else:
                matrix [x,y] = min(matrix[x-1,y] + 1, matrix[x-1,y-1] + 1,
                                   matrix[x,y-1] + 1)

    return (matrix[size_x - 1, size_y - 1])
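For example, two equal-length state sequences that disagree in two bins have an edit distance of 2 (a toy illustration):

import numpy as np

seq1 = np.array([0, 0, 1, 1, 2, 2])
seq2 = np.array([0, 1, 1, 1, 2, 0])
d = levenshtein(seq1, seq2)                        # 2.0
pct_match = 100 * (len(seq1) - d) / len(seq1)      # ~66.7% of bins correct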
def levenshtein_mp(i, seq1, seq2)

Computes edit distance between 2 sequences

@njit
def levenshtein_mp(i, seq1, seq2):
    ''' Computes edit distance between 2 sequences
    '''
    size_x = len(seq1) + 1
    size_y = len(seq2) + 1
    matrix = np.zeros ((size_x, size_y))
    for x in range(size_x):
        matrix [x, 0] = x

    for y in range(size_y):
        matrix [0, y] = y

    for x in range(1, size_x):
        for y in range(1, size_y):
            if seq1[x-1] == seq2[y-1]:
                matrix [x,y] = min(matrix[x-1, y] + 1, matrix[x-1, y-1],
                                   matrix[x, y-1] + 1)
            else:
                matrix [x,y] = min(matrix[x-1,y] + 1, matrix[x-1,y-1] + 1,
                                   matrix[x,y-1] + 1)

    return i, matrix[size_x - 1, size_y - 1]
def match_states(rates1, rates2)

Takes 2 Cell X State firing rate matrices and determines which states are most similar. Returns dict mapping rates2 states to rates1 states

def match_states(rates1, rates2):
    '''Takes 2 Cell X State firing rate matrices and determines which states
    are most similar. Returns dict mapping rates2 states to rates1 states
    '''
    distances = np.zeros((rates1.shape[1], rates2.shape[1]))
    for x, y in it.product(range(rates1.shape[1]), range(rates2.shape[1])):
        tmp = euclidean(rates1[:, x], rates2[:, y])
        distances[x, y] = tmp

    states = list(range(rates2.shape[1]))
    out = {}
    for i in range(rates2.shape[1]):
        s = np.argmin(distances[:,i])
        r = np.argmin(distances[s, :])
        if r == i and s in states:
            out[i] = s
            idx = states.index(s)
            states.pop(idx)

    for i in range(rates2.shape[1]):
        if i not in out:
            s = np.argmin(distances[states, i])
            out[i] = states[s]

    return out
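A toy sketch of the intended use: both inputs are Cell X State firing rate matrices and the returned dict maps each column of rates2 to its closest column of rates1. The numbers below are made up; in practice rates1 is a ground-truth rate matrix and rates2 is hmm.emission.

import numpy as np

true_rates = np.array([[ 2., 20.],    # cell 0 rate (Hz) in state 0, state 1
                       [15.,  5.]])   # cell 1
hmm_rates = np.array([[19.,  3.],     # same states recovered in reverse order
                      [ 6., 14.]])
state_map = match_states(true_rates, hmm_rates)
# state_map == {0: 1, 1: 0}; pass it to convert_path_state_numbers() to relabel paths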
def pick_best_hmm(HMMs)

For each HMM, searches its history for the parameters with the lowest BIC, then compares the HMMs. Those with the same # of free parameters are compared by BIC; those with different # of free parameters (namely # of states) are compared by cost. The best HMM is returned.

Parameters

HMMs : list of PoissonHmm objects
 

Returns

PoissonHmm
 
def pick_best_hmm(HMMs):
    '''For each HMM, searches its history for the parameters with the lowest
    BIC, then compares the HMMs. Those with the same # of free parameters are
    compared by BIC; those with different # of free parameters (namely # of
    states) are compared by cost. The best HMM is returned.

    Parameters
    ----------
    HMMs : list of PoissonHmm objects

    Returns
    -------
    PoissonHmm
    '''
    # First optimize each HMMs and sort into groups based on # of states
    groups = {}
    for hmm in HMMs:
        hmm.set_to_lowest_BIC()
        if hmm.n_states in groups:
            groups[hmm.n_states].append(hmm)
        else:
            groups[hmm.n_states] = [hmm]

    best_per_state = {}
    for k, v in groups.items():
        BICs = np.array([x.get_BIC()[0] for x in v])
        idx = np.argmin(BICs)
        best_per_state[k] = v[idx]

    hmm_list = list(best_per_state.values())
    costs = np.array([x.get_cost() for x in hmm_list])
    idx = np.argmin(costs)
    return hmm_list[idx]
def plot_state_raster(data, stateVec, dt, ax=None)
def plot_state_raster(data, stateVec, dt, ax=None):
    if len(data.shape) == 2:
        data = np.array([data])

    nTrials, nCells, nTimeSteps = data.shape
    nStates = np.max(stateVec) +1

    gradient = np.array([0 + i/(nCells+1) for i in range(nCells)])
    time = np.arange(0, nTimeSteps * dt * 1000, dt * 1000)
    colors = [plt.cm.jet(i) for i in np.linspace(0,1,nStates)]

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    for trial, spikes in enumerate(data):
        path = stateVec[trial]
        for i, row in enumerate(spikes):
            idx = np.where(row == 1)[0]
            ax.scatter(time[idx], row[idx]*trial + gradient[i],
                       c=[colors[int(x)] for x in path[idx]], marker='|')

    return fig, ax
def plot_state_rates(rates, ax=None)
def plot_state_rates(rates, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    nCells, nStates = rates.shape
    df = pd.DataFrame(rates, columns=['state %i' % i for i in range(nStates)])
    df['cell'] = ['cell %i' % i for i in df.index]
    df = pd.melt(df, 'cell', ['state %i' % i for i in range(nStates)], 'state', 'rate')
    sns.barplot(x='state', y='rate', hue='cell', data=df,
                palette='muted', ax=ax)

    return fig, ax
def poisson(rate, n, dt)

Gives the probability of each neuron's spike count assuming Poisson spiking

@njit
def poisson(rate, n, dt):
    '''Gives the probability of each neuron's spike count assuming Poisson spiking
    '''
    tmp = np.power(rate*dt, n) / np.array([fast_factorial(x) for x in n])
    tmp = tmp * np.exp(-rate*dt)
    return tmp
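For instance, a neuron firing at 20 Hz in 1 ms bins has rate*dt = 0.02 expected spikes per bin, so (a pure arithmetic illustration):

import numpy as np

rate = np.array([20.0])                  # Hz
dt = 0.001                               # 1 ms bins
p0 = poisson(rate, np.array([0]), dt)    # exp(-0.02) ~= 0.980
p1 = poisson(rate, np.array([1]), dt)    # 0.02 * exp(-0.02) ~= 0.0196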
def poisson_viterbi(spikes, dt, PI, A, B)

Parameters

spikes : np.array, Neuron X Time matrix of spike counts
 
PI : np.array, nStates x 1 vector of initial state probabilities
 
A : np.array, nStates X nStates matrix of state transition probabilities
 
B : np.array, Neuron X States matrix of estimated firing rates
 
dt : float, time step size in seconds
 

Returns

bestPath : np.array
1 x Time vector of states representing the most likely hidden state sequence
maxPathLogProb : float
Log probability of the most likely state sequence
T1 : np.array
State X Time matrix where each entry (i,j) gives the log probability of the most likely path so far ending in state i that generates observations o1,…, oj
T2 : np.array
State X Time matrix of back pointers where each entry (i,j) gives the state x(j-1) on the most likely path so far ending in state i
def poisson_viterbi(spikes, dt, PI, A, B):
    '''
    Parameters
    ----------
    spikes : np.array, Neuron X Time matrix of spike counts
    PI : np.array, nStates x 1 vector of initial state probabilities
    A : np.array, nStates X nStates matrix of state transition probabilities
    B : np.array, Neuron X States matrix of estimated firing rates
    dt : float, time step size in seconds

    Returns
    -------
    bestPath : np.array
        1 x Time vector of states representing the most likely hidden state
        sequence
    maxPathLogProb : float
        Log probability of the most likely state sequence
    T1 : np.array
        State X Time matrix where each entry (i,j) gives the log probability of
        the most likely path so far ending in state i that generates
        observations o1,..., oj
    T2: np.array
        State X Time matrix of back pointers where each entry (i,j) gives the
        state x(j-1) on the most likely path so far ending in state i
    '''
    if A.shape[0] != A.shape[1]:
        raise ValueError('Transition matrix is not square')

    nStates = A.shape[0]
    nCells, nTimeSteps = spikes.shape
    T1 = np.zeros((nStates, nTimeSteps))
    T2 = np.zeros((nStates, nTimeSteps))
    T1[:,0] = np.array([np.log(PI[i]) +
                        np.log(np.prod(poisson(B[:,i], spikes[:, 0], dt)))
                        for i in range(nStates)])
    for t, s in it.product(range(1,nTimeSteps), range(nStates)):
        probs = np.log(np.prod(poisson(B[:, s], spikes[:, t], dt)))
        vec2 = T1[:, t-1] + np.log(A[:,s])
        vec1 = vec2 + probs
        T1[s, t] = np.max(vec1)
        idx = np.argmax(vec2)
        T2[s, t] = idx

    bestPathEndState = np.argmax(T1[:, -1])
    maxPathLogProb = T1[bestPathEndState, -1]
    bestPath = np.zeros((nTimeSteps,))
    bestPath[-1] = bestPathEndState
    tStep = list(range(nTimeSteps-1))
    tStep.reverse()
    for t in tStep:
        bestPath[t] = T2[int(bestPath[t+1]), t+1]

    return bestPath, maxPathLogProb, T1, T2
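A usage sketch decoding one trial (every input below is made up; in practice spikes, PI, A and B come from a fitted PoissonHMM):

import numpy as np

spikes = np.random.poisson(0.02, (4, 100)).astype('int32')   # 4 neurons, 100 one-ms bins
PI = np.array([0.5, 0.5])
A = np.array([[0.99, 0.01],
              [0.01, 0.99]])
B = np.array([[5., 30.], [25., 2.], [10., 10.], [1., 40.]])  # Hz, Neuron X State
path, log_prob, T1, T2 = poisson_viterbi(spikes, 0.001, PI, A, B)
# path is a length-100 vector of state labels; log_prob is its log likelihood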
def query_units(dat, unit_type)

Returns the unit names of all units in the dataset that match unit_type

Parameters

dat : blechpy.dataset or str
Can either be a dataset object or the str path to the recording directory containing that data .h5 object
unit_type : str, {'single', 'pyramidal', 'interneuron', 'all'}
determines whether to return 'single' units, 'pyramidal' (regular spiking single) units, 'interneuron' (fast spiking single) units, or 'all' units

Returns

list of str : unit_names
def query_units(dat, unit_type):
    '''Returns the unit names of all units in the dataset that match unit_type

    Parameters
    ----------
    dat : blechpy.dataset or str
        Can either be a dataset object or the str path to the recording
        directory containing that data .h5 object
    unit_type : str, {'single', 'pyramidal', 'interneuron', 'all'}
        determines whether to return 'single' units, 'pyramidal' (regular
        spiking single) units, 'interneuron' (fast spiking single) units, or
        'all' units

    Returns
    -------
        list of str : unit_names
    '''
    if isinstance(dat, str):
        units = h5io.get_unit_table(dat)
    else:
        units = dat.get_unit_table()

    u_str = unit_type.lower()
    q_str = ''
    if u_str == 'single':
        q_str = 'single_unit == True'
    elif u_str == 'pyramidal':
        q_str = 'single_unit == True and regular_spiking == True'
    elif u_str == 'interneuron':
        q_str = 'single_unit == True and fast_spiking == True'
    elif u_str == 'all':
        return units['unit_name'].tolist()
    else:
        raise ValueError('Invalid unit_type %s. Must be '
                         'single, pyramidal, interneuron or all' % u_str)

    return units.query(q_str)['unit_name'].tolist()
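A usage sketch (the recording directory below is a hypothetical placeholder; it must point at a processed blechpy recording):

rec_dir = '/data/my_recording'                 # hypothetical path
units = query_units(rec_dir, 'pyramidal')      # e.g. ['unit001', 'unit004', ...]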
def read_hmm_from_hdf5(h5_file, hmm_id, rec_dir)
def read_hmm_from_hdf5(h5_file, hmm_id, rec_dir):
    print('Loading HMM %i from hdf5' % hmm_id)
    with tables.open_file(h5_file, 'r') as hf5:
        h_str = 'hmm_%i' % hmm_id
        if h_str not in hf5.root or len(hf5.list_nodes('/'+h_str)) == 0:
            return None

        table = hf5.root.data_overview
        row = list(table.where('hmm_id == id', condvars={'id':hmm_id}))
        if len(row) == 0:
            raise ValueError('Parameters not found for hmm %i' % hmm_id)
        elif len(row) > 1:
            raise ValueError('Multiple parameters found for hmm %i' % hmm_id)

        row = row[0]
        units = query_units(rec_dir, row['unit_type'].decode('utf-8'))
        spikes, dt, time = get_spike_data(rec_dir, units, row['channel'],
                                          dt=row['dt'],
                                          time_start=row['time_start'],
                                          time_end=row['time_end'])
        tmp = hf5.root[h_str]
        mats = {'initial_distribution': tmp['initial_distribution'][:],
                'transition': tmp['transition'][:],
                'emission': tmp['emission'][:],
                'fitted': row['fitted']}
        hmm = PoissonHMM(row['n_states'], spikes, dt, set_data=mats)

    return hmm
def rebin_spike_array(spikes, dt, time, new_dt)
def rebin_spike_array(spikes, dt, time, new_dt):
    if spikes.ndim == 2:
        spikes = np.expand_dims(spikes,0)

    n_trials, n_cells, n_steps = spikes.shape
    n_bins = int(new_dt/dt)
    new_time = np.arange(time[0], time[-1], new_dt)
    new_spikes = np.zeros((n_trials, n_cells, len(new_time)))
    for i, w in enumerate(new_time):
        idx = np.where((time >= w) & (time < w+new_dt))[0]
        new_spikes[:,:,i] = np.sum(spikes[:,:,idx], axis=-1)

    return new_spikes, new_time
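A small sketch collapsing 1 ms bins into 10 ms bins; note that dt and new_dt are interpreted in the same units as time (here ms), and all values below are made up:

import numpy as np

time = np.arange(-1000, 2500)                           # ms
spikes = np.random.poisson(0.01, (15, 8, len(time)))    # trials x cells x bins
new_spikes, new_time = rebin_spike_array(spikes, 1, time, 10)
print(new_spikes.shape)                                 # (15, 8, 350)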
def wrap_baum_welch(trial_id, trial_dat, dt, PI, A, B)
def wrap_baum_welch(trial_id, trial_dat, dt, PI, A, B):
    alpha, norms = forward(trial_dat, dt, PI, A, B)
    beta = backward(trial_dat, dt, A, B, norms)
    tmp_gamma, tmp_epsilons = baum_welch(trial_dat, dt, A, B, alpha, beta)
    return trial_id, tmp_gamma, tmp_epsilons
def write_hmm_to_hdf5(h5_file, hmm_id, hmm)
def write_hmm_to_hdf5(h5_file, hmm_id, hmm):
    h_str = 'hmm_%i' % hmm_id
    print('Writing HMM %i to hdf5 file...' % hmm_id)
    with tables.open_file(h5_file, 'a') as hf5:
        if h_str in hf5.root:
            hf5.remove_node('/', h_str, recursive=True)

        hf5.create_group('/', h_str, 'Data for HMM #%i' % hmm_id)
        hf5.create_array('/'+h_str, 'initial_distribution',
                         hmm.initial_distribution)
        hf5.create_array('/'+h_str, 'transition', hmm.transition)
        hf5.create_array('/'+h_str, 'emission', hmm.emission)

        best_paths, _ = hmm.get_best_paths()
        hf5.create_array('/'+h_str, 'state_sequences', best_paths)

Classes

class HMMFit (rec_dir, params)
class HMMFit(object):
    def __init__(self, rec_dir, params):
        self._rec_dir = rec_dir
        self._params = params

    def run(self, parallel=False):
        params = self._params
        spikes, dt, time = self.get_spike_data()
        hmm = PoissonHMM(params['n_states'], spikes, dt)
        hmm.fit(max_iter=params['max_iter'],
                convergence_thresh=params['threshold'],
                parallel=parallel)
        del spikes, dt, time
        return params['hmm_id'], hmm

    def get_spike_data(self):
        p = self._params
        units = query_units(self._rec_dir, p['unit_type'])
        # Get stored spike array, time is in ms, dt is usually 1 ms
        spike_array, dt, time = get_spike_data(self._rec_dir, units,
                                               p['channel'], dt=p['dt'],
                                               time_start=p['time_start'],
                                               time_end=p['time_end'])
        return spike_array, dt, time

Methods

def get_spike_data(self)
def get_spike_data(self):
    p = self._params
    units = query_units(self._rec_dir, p['unit_type'])
    # Get stored spike array, time is in ms, dt is usually 1 ms
    spike_array, dt, time = get_spike_data(self._rec_dir, units,
                                           p['channel'], dt=p['dt'],
                                           time_start=p['time_start'],
                                           time_end=p['time_end'])
    return spike_array, dt, time
def run(self, parallel=False)
def run(self, parallel=False):
    params = self._params
    spikes, dt, time = self.get_spike_data()
    hmm = PoissonHMM(params['n_states'], spikes, dt)
    hmm.fit(max_iter=params['max_iter'],
            convergence_thresh=params['threshold'],
            parallel=parallel)
    del spikes, dt, time
    return params['hmm_id'], hmm
class HmmHandler (dat, params=None, save_dir=None)

Takes a blechpy dataset object and fits HMMs for each tastant

Parameters

dat : blechpy.dataset
 
params : dict or list of dicts
each dict must have fields:
    time_window: list of int, time window to cut around stimuli in ms
    convergence_thresh: float
    max_iter: int
    n_repeats: int
    unit_type: str, {'single', 'pyramidal', 'interneuron', 'all'}
    bin_size: time bin for spike array when fitting, in seconds
    n_states: predicted number of states to fit
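A minimal call-pattern sketch. Note that the fitting code below actually reads the keys dt, time_start, time_end and threshold (rather than the bin_size / time_window / convergence_thresh names listed above), so a dict like the following is assumed; the recording path is a placeholder:

params = {'n_states': 3, 'dt': 0.001, 'time_start': 0, 'time_end': 2000,
          'max_iter': 1000, 'threshold': 1e-4, 'unit_type': 'single',
          'n_repeats': 3}
handler = HmmHandler('/data/my_recording', params=params)   # hypothetical rec dir
handler.run(parallel=True)        # fits every taste, saves models and plots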
class HmmHandler(object):
    def __init__(self, dat, params=None, save_dir=None):
        '''Takes a blechpy dataset object and fits HMMs for each tastant

        Parameters
        ----------
        dat: blechpy.dataset
        params: dict or list of dicts
            each dict must have fields:
                time_window: list of int, time window to cut around stimuli in ms
                convergence_thresh: float
                max_iter: int
                n_repeats: int
                unit_type: str, {'single', 'pyramidal', 'interneuron', 'all'}
                bin_size: time bin for spike array when fitting in seconds
                n_states: predicted number of states to fit
        '''
        if isinstance(params, dict):
            params = [params]

        if isinstance(dat, str):
            dat = load_dataset(dat)
            if dat is None:
                raise FileNotFoundError('No dataset.p file found given directory')

        if save_dir is None:
            save_dir = os.path.join(dat.root_dir,
                                    '%s_analysis' % dat.data_name)

        self._dataset = dat
        self.root_dir = dat.root_dir
        self.save_dir = save_dir
        self.h5_file = os.path.join(save_dir, '%s_HMM_Analysis.hdf5' % dat.data_name)
        dim = dat.dig_in_mapping.query('exclude==False')
        tastes = dim['name'].tolist()
        if params is None:
            # Load params and fitted models
            self.load_data()
        else:
            self.init_params(params)

        self.params = params

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        self.plot_dir = os.path.join(save_dir, 'HMM_Plots')
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        self._setup_hdf5()

    def init_params(self, params):
        dat = self._dataset
        dim = dat.dig_in_mapping.query('exclude == False')
        tastes = dim['name'].tolist()
        dim = dim.set_index('name')
        if not hasattr(dat, 'dig_in_trials'):
            dat.create_trial_list()

        trials = dat.dig_in_trials
        data_params = []
        fit_objs = []
        fit_params = []
        for i, X in enumerate(it.product(params,tastes)):
            p = X[0].copy()
            t = X[1]
            p['hmm_id'] = i
            p['taste'] = t
            p['channel'] = dim.loc[t, 'channel']
            unit_names = query_units(dat, p['unit_type'])
            p['n_cells'] = len(unit_names)
            p['n_trials'] = len(trials.query('name == @t'))

            data_params.append(p)
            # Make fit object for each repeat
            # During fitting compare HMM as ones with the same ID are returned
            for i in range(p['n_repeats']):
                hmmFit = HMMFit(dat.root_dir, p)
                fit_objs.append(hmmFit)
                fit_params.append(p)

        self._fit_objects = fit_objs
        self._data_params = data_params
        self._fit_params = fit_params
        self._fitted_models = dict.fromkeys([x['hmm_id'] for x in data_params])
        self.write_overview_to_hdf5()

    def load_data(self):
        h5_file = self.h5_file
        if not os.path.isfile(h5_file):
            raise ValueError('No params to load')

        rec_dir = self._dataset.root_dir
        params = []
        fit_objs = []
        fit_params = []
        fitted_models = {}
        with tables.open_file(h5_file, 'r') as hf5:
            table = hf5.root.data_overview
            col_names = table.colnames
            for row in table[:]:
                p = {}
                for k in col_names:
                    if table.coltypes[k] == 'string':
                        p[k] = row[k].decode('utf-8')
                    else:
                        p[k] = row[k]

                params.append(p)
                for i in range(p['n_repeats']):
                    hmmFit = HMMFit(rec_dir, p)
                    fit_objs.append(hmmFit)
                    fit_params.append(p)

        for p in params:
            hmm_id = p['hmm_id']
            fitted_models[hmm_id] = read_hmm_from_hdf5(h5_file, hmm_id, rec_dir)

        self._data_params = params
        self._fit_objects = fit_objs
        self._fitted_models = fitted_models
        self._fit_params = fit_params


    def write_overview_to_hdf5(self):
        params = self._data_params
        h5_file = self.h5_file
        if hasattr(self, '_fitted_models'):
            models = self._fitted_models
        else:
            models = dict.fromkeys([x['hmm_id'] for x in params])
            self._fitted_models = models


        if not os.path.isfile(h5_file):
            self._setup_hdf5()

        print('Writing data overview table to hdf5...')
        with tables.open_file(h5_file, 'a') as hf5:
            table = hf5.root.data_overview
            # Clear old table
            table.remove_rows(start=0)

            # Add new rows
            for p in params:
                row = table.row
                for k, v in p.items():
                    row[k] = v

                if models[p['hmm_id']] is not None:
                   hmm = models[p['hmm_id']]
                   row['n_iterations'] =  hmm.iterations
                   row['BIC'] = hmm.BIC
                   row['cost'] = hmm.cost
                   row['converged'] = hmm.isConverged(p['threshold'])
                   row['fitted'] = hmm.fitted

                row.append()

            table.flush()
            hf5.flush()

        print('Done!')

    def _setup_hdf5(self):
        h5_file = self.h5_file

        with tables.open_file(h5_file, 'a') as hf5:
            # Taste -> PI, A, B, BIC, state_sequences, nStates, nCells, dt
            if not 'data_overview' in hf5.root:
                # Contains taste, channel, n_cells, n_trials, n_states, dt, BIC
                table = hf5.create_table('/', 'data_overview', HMMInfoParticle,
                                         'Basic info for each digital_input')
                table.flush()


            if hasattr(self, '_data_params') and self._data_params is not None:
                for p in self._data_params:
                    hmm_str = 'hmm_%i' % p['hmm_id']
                    if hmm_str not in hf5.root:
                        hf5.create_group('/', hmm_str, 'Data for HMM #%i' % p['hmm_id'])

            hf5.flush()

    def run(self, parallel=True):
        self.write_overview_to_hdf5()
        h5_file = self.h5_file
        rec_dir = self._dataset.root_dir
        fit_objs = self._fit_objects
        fit_params = self._fit_params
        self._fitted_models = dict.fromkeys([x['hmm_id'] for x in self._data_params])
        errors = []

        # def update(ans):
        #     hmm_id = ans[0]
        #     hmm = ans[1]
        #     if self._fitted_models[hmm_id] is not None:
        #         best_hmm = pick_best_hmm([HMMs[hmm_id], hmm])
        #         self._fitted_models[hmm_id] = best_hmm
        #         write_hmm_to_hdf5(h5_file, hmm_id, best_hmm)
        #         del hmm, best_hmm
        #     else:
        #         # Check history for lowest BIC
        #         self._fitted_models[hmm_id] = hmm.set_to_lowest_BIC()
        #         write_hmm_to_hdf5(h5_file, hmm_id, hmm)
        #         del hmm

        # def error_call(e):
        #     errors.append(e)

        # if parallel:
        #     n_cpu = np.min((mp.cpu_count()-1, len(fit_objs)))
        #     if n_cpu > 10:
        #         pool = mp.get_context('spawn').Pool(n_cpu)
        #     else:
        #         pool = mp.Pool(n_cpu)

        #     for f in fit_objs:
        #         pool.apply_async(f.run, callback=update, error_callback=error_call)

        #     pool.close()
        #     pool.join()
        # else:
        #     for f in fit_objs:
        #         try:
        #             ans = f.run()
        #             update(ans)
        #         except Exception as e:
        #             raise Exception(e)
        #             error_call(e)
        print('Running fittings')
        if parallel:
            n_cpu = np.min((mp.cpu_count()-1, len(fit_params)))
        else:
            n_cpu = 1

        results = Parallel(n_jobs=n_cpu, verbose=20)(delayed(hmm_fit_mp)(rec_dir, p) for p in fit_params)
        for hmm_id, hmm in results:
            if self._fitted_models[hmm_id] is None:
                self._fitted_models[hmm_id] = hmm
            else:
                new_hmm = pick_best_hmm([hmm, self._fitted_models[hmm_id]])
                self._fitted_models[hmm_id] = new_hmm

        self.write_overview_to_hdf5()
        self.save_fitted_models()
        # if len(errors) > 0:
        #     print('Encountered errors: ')
        #     for e in errors:
        #         print(e)

    def save_fitted_models(self):
        models = self._fitted_models
        for k, v in models.items():
            write_hmm_to_hdf5(self.h5_file, k, v)
            plot_dir = os.path.join(self.plot_dir, 'HMM_%i' % k)
            if not os.path.isdir(plot_dir):
                os.makedirs(plot_dir)

            ids = [x['hmm_id'] for x in self._data_params]
            idx = ids.index(k)
            params = self._data_params[idx]
            time_window = [params['time_start'], params['time_end']]
            hmmplt.plot_hmm_figures(v, time_window, save_dir=plot_dir)

Methods

def init_params(self, params)
def init_params(self, params):
    dat = self._dataset
    dim = dat.dig_in_mapping.query('exclude == False')
    tastes = dim['name'].tolist()
    dim = dim.set_index('name')
    if not hasattr(dat, 'dig_in_trials'):
        dat.create_trial_list()

    trials = dat.dig_in_trials
    data_params = []
    fit_objs = []
    fit_params = []
    for i, X in enumerate(it.product(params,tastes)):
        p = X[0].copy()
        t = X[1]
        p['hmm_id'] = i
        p['taste'] = t
        p['channel'] = dim.loc[t, 'channel']
        unit_names = query_units(dat, p['unit_type'])
        p['n_cells'] = len(unit_names)
        p['n_trials'] = len(trials.query('name == @t'))

        data_params.append(p)
        # Make fit object for each repeat
        # During fitting compare HMM as ones with the same ID are returned
        for i in range(p['n_repeats']):
            hmmFit = HMMFit(dat.root_dir, p)
            fit_objs.append(hmmFit)
            fit_params.append(p)

    self._fit_objects = fit_objs
    self._data_params = data_params
    self._fit_params = fit_params
    self._fitted_models = dict.fromkeys([x['hmm_id'] for x in data_params])
    self.write_overview_to_hdf5()
def load_data(self)
def load_data(self):
    h5_file = self.h5_file
    if not os.path.isfile(h5_file):
        raise ValueError('No params to load')

    rec_dir = self._dataset.root_dir
    params = []
    fit_objs = []
    fit_params = []
    fitted_models = {}
    with tables.open_file(h5_file, 'r') as hf5:
        table = hf5.root.data_overview
        col_names = table.colnames
        for row in table[:]:
            p = {}
            for k in col_names:
                if table.coltypes[k] == 'string':
                    p[k] = row[k].decode('utf-8')
                else:
                    p[k] = row[k]

            params.append(p)
            for i in range(p['n_repeats']):
                hmmFit = HMMFit(rec_dir, p)
                fit_objs.append(hmmFit)
                fit_params.append(p)

    for p in params:
        hmm_id = p['hmm_id']
        fitted_models[hmm_id] = read_hmm_from_hdf5(h5_file, hmm_id, rec_dir)

    self._data_params = params
    self._fit_objects = fit_objs
    self._fitted_models = fitted_models
    self._fit_params = fit_params
def run(self, parallel=True)
def run(self, parallel=True):
    self.write_overview_to_hdf5()
    h5_file = self.h5_file
    rec_dir = self._dataset.root_dir
    fit_objs = self._fit_objects
    fit_params = self._fit_params
    self._fitted_models = dict.fromkeys([x['hmm_id'] for x in self._data_params])
    errors = []

    # def update(ans):
    #     hmm_id = ans[0]
    #     hmm = ans[1]
    #     if self._fitted_models[hmm_id] is not None:
    #         best_hmm = pick_best_hmm([HMMs[hmm_id], hmm])
    #         self._fitted_models[hmm_id] = best_hmm
    #         write_hmm_to_hdf5(h5_file, hmm_id, best_hmm)
    #         del hmm, best_hmm
    #     else:
    #         # Check history for lowest BIC
    #         self._fitted_models[hmm_id] = hmm.set_to_lowest_BIC()
    #         write_hmm_to_hdf5(h5_file, hmm_id, hmm)
    #         del hmm

    # def error_call(e):
    #     errors.append(e)

    # if parallel:
    #     n_cpu = np.min((mp.cpu_count()-1, len(fit_objs)))
    #     if n_cpu > 10:
    #         pool = mp.get_context('spawn').Pool(n_cpu)
    #     else:
    #         pool = mp.Pool(n_cpu)

    #     for f in fit_objs:
    #         pool.apply_async(f.run, callback=update, error_callback=error_call)

    #     pool.close()
    #     pool.join()
    # else:
    #     for f in fit_objs:
    #         try:
    #             ans = f.run()
    #             update(ans)
    #         except Exception as e:
    #             raise Exception(e)
    #             error_call(e)
    print('Running fittings')
    if parallel:
        n_cpu = np.min((mp.cpu_count()-1, len(fit_params)))
    else:
        n_cpu = 1

    results = Parallel(n_jobs=n_cpu, verbose=20)(delayed(hmm_fit_mp)(rec_dir, p) for p in fit_params)
    for hmm_id, hmm in results:
        if self._fitted_models[hmm_id] is None:
            self._fitted_models[hmm_id] = hmm
        else:
            new_hmm = pick_best_hmm([hmm, self._fitted_models[hmm_id]])
            self._fitted_models[hmm_id] = new_hmm

    self.write_overview_to_hdf5()
    self.save_fitted_models()
def save_fitted_models(self)
def save_fitted_models(self):
    models = self._fitted_models
    for k, v in models.items():
        write_hmm_to_hdf5(self.h5_file, k, v)
        plot_dir = os.path.join(self.plot_dir, 'HMM_%i' % k)
        if not os.path.isdir(plot_dir):
            os.makedirs(plot_dir)

        ids = [x['hmm_id'] for x in self._data_params]
        idx = ids.index(k)
        params = self._data_params[idx]
        time_window = [params['time_start'], params['time_end']]
        hmmplt.plot_hmm_figures(v, time_window, save_dir=plot_dir)
def write_overview_to_hdf5(self)
def write_overview_to_hdf5(self):
    params = self._data_params
    h5_file = self.h5_file
    if hasattr(self, '_fitted_models'):
        models = self._fitted_models
    else:
        models = dict.fromkeys([x['hmm_id'] for x in params])
        self._fitted_models = models


    if not os.path.isfile(h5_file):
        self._setup_hdf5()

    print('Writing data overview table to hdf5...')
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        # Clear old table
        table.remove_rows(start=0)

        # Add new rows
        for p in params:
            row = table.row
            for k, v in p.items():
                row[k] = v

            if models[p['hmm_id']] is not None:
               hmm = models[p['hmm_id']]
               row['n_iterations'] =  hmm.iterations
               row['BIC'] = hmm.BIC
               row['cost'] = hmm.cost
               row['converged'] = hmm.isConverged(p['threshold'])
               row['fitted'] = hmm.fitted

            row.append()

        table.flush()
        hf5.flush()

    print('Done!')
class PoissonHMM (n_predicted_states, spikes, dt, max_history=500, cost_window=0.25, set_data=None)

Poisson implementation of a Hidden Markov Model for fitting spike data from a neuronal population. Author: Roshan Nanu. Adapted from code by Ben Ballintyn.

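A minimal fitting sketch on synthetic spikes (sizes, rates and thresholds are made up; real use passes the spike array returned by get_spike_data()):

import numpy as np

spikes = np.random.poisson(0.02, (15, 10, 2000)).astype('int32')  # trials x cells x 1 ms bins
hmm = PoissonHMM(3, spikes, 0.001)                 # predict 3 states, dt = 1 ms
hmm.fit(max_iter=500, convergence_thresh=1e-4)
paths, max_log_prob = hmm.get_best_paths()         # Viterbi decode each trial
BIC, _, _ = hmm.get_BIC()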
class PoissonHMM(object):
    '''Poisson implementation of Hidden Markov Model for fitting spike data
    from a neuronal population
    Author: Roshan Nanu
    Adapted from code by Ben Ballintyn
    '''
    def __init__(self, n_predicted_states, spikes, dt,
                 max_history=500, cost_window=0.25, set_data=None):
        if len(spikes.shape) == 2:
            spikes = np.array([spikes])

        self.data = spikes.astype('int32')
        self.dt = dt
        self._rate_data = None
        self.n_states = n_predicted_states
        self._cost_window = cost_window
        self._max_history = max_history
        self.cost = None
        self.BIC = None
        self.best_sequences = None
        self.max_log_prob = None
        self._rate_data = None
        self.history = None
        self._compute_data_rate_array()
        if set_data is None:
            self.randomize()
        else:
            self.fitted = set_data['fitted']
            self.initial_distribution = set_data['initial_distribution']
            self.transition = set_data['transition']
            self.emission = set_data['emission']
            self.iteration = 0
            self._update_cost()


    def randomize(self):
        nStates = self.n_states
        spikes = self.data
        dt = self.dt
        n_trials, n_cells, n_steps = spikes.shape
        total_time = n_steps * dt

        # Initialize transition matrix with high stay probability
        print('Randomizing')
        diag = np.abs(np.random.normal(.99, .01, nStates))
        A = np.abs(np.random.normal(0.01/(nStates-1), 0.01, (nStates, nStates)))
        for i in range(nStates):
            A[i, i] = diag[i]
            A[i,:] = A[i,:] / np.sum(A[i,:])

        # Initialize rate matrix ("Emission" matrix)
        spike_counts = np.sum(spikes, axis=2) / total_time
        mean_rates = np.mean(spike_counts, axis=0)
        std_rates = np.std(spike_counts, axis=0)
        B = np.vstack([np.abs(np.random.normal(x, y, nStates))
                       for x,y in zip(mean_rates, std_rates)])
        # B = np.random.rand(nCells, nStates)

        self.transition = A
        self.emission = B
        self.initial_distribution = np.ones((nStates,)) / nStates
        self.iteration = 0
        self.fitted = False
        self.history = None
        self._update_cost()

    def fit(self, spikes=None, dt=None, max_iter = 1000, convergence_thresh = 1e-4,
            parallel=False):
        '''Using parallel processing over trials actually seems to slow down
        fitting (with 15 trials). Might still be useful if there is a very
        large number of trials
        '''
        if self.fitted:
            return

        if spikes is not None:
            spikes = spikes.astype('int32')
            self.data = spikes
            self.dt = dt
        else:
            spikes = self.data
            dt = self.dt

        while (not self.isConverged(convergence_thresh) and
               (self.iteration < max_iter)):
            self._step(spikes, dt, parallel=parallel)
            print('Iter #%i complete.' % self.iteration)

        self.fitted = True

    def _step(self, spikes, dt, parallel=False):
        if len(spikes.shape) == 2:
            spikes = np.array([spikes])

        nTrials, nCells, nTimeSteps = spikes.shape

        A = self.transition
        B = self.emission
        PI = self.initial_distribution
        nStates = self.n_states

        # For multiple trials need to compute gamma and epsilon for every trial
        # and then update
        gammas = np.zeros((nTrials, nStates, nTimeSteps))
        epsilons = np.zeros((nTrials, nStates, nStates, nTimeSteps-1))
        if parallel:
            def update(ans):
                idx = ans[0]
                gammas[idx, :, :] = ans[1]
                epsilons[idx, :, :, :] = ans[2]

            def error(ans):
                raise RuntimeError(ans)

            n_cores = mp.cpu_count() - 1
            pool = mp.get_context('spawn').Pool(n_cores)
            for i, trial in enumerate(spikes):
                pool.apply_async(wrap_baum_welch,
                                 (i, trial, dt, PI, A, B),
                                 callback=update, error_callback=error)

            pool.close()
            pool.join()
        else:
            for i, trial in enumerate(spikes):
                _, tmp_gamma, tmp_epsilons = wrap_baum_welch(i, trial, dt, PI, A, B)
                gammas[i, :, :] = tmp_gamma
                epsilons[i, :, :, :] = tmp_epsilons

        # Store old parameters for convergence check
        self.update_history()

        PI, A, B = compute_new_matrices(spikes, dt, gammas, epsilons)
        self.transition = A
        self.emission = B
        self.initial_distribution = PI
        self.iteration += 1
        self._update_cost()

    def update_history(self):
        A = self.transition
        B = self.emission
        PI = self.initial_distribution
        BIC = self.BIC
        cost = self.cost
        iteration = self.iteration

        if self.history is None:
            self.history = {}
            self.history['A'] = [A]
            self.history['B'] = [B]
            self.history['PI'] = [PI]
            self.history['iterations'] = [iteration]
            self.history['cost'] = [cost]
            self.history['BIC'] = [BIC]
        else:
            if iteration in self.history['iterations']:
                return self.history

            self.history['A'].append(A)
            self.history['B'].append(B)
            self.history['PI'].append(PI)
            self.history['iterations'].append(iteration)
            self.history['cost'].append(cost)
            self.history['BIC'].append(BIC)

        if len(self.history['iterations']) > self._max_history:
            nmax = self._max_history
            for k, v in self.history.items():
                self.history[k] = v[-nmax:]

        return self.history

    def isConverged(self, thresh):
        if self.history is None:
            return False

        oldPI = self.history['PI'][-1]
        oldA = self.history['A'][-1]
        oldB = self.history['B'][-1]
        oldCost = self.history['cost'][-1]

        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        cost = self.cost

        dPI = np.sqrt(np.sum(np.power(oldPI - PI, 2)))
        dA = np.sqrt(np.sum(np.power(oldA - A, 2)))
        dB = np.sqrt(np.sum(np.power(oldB - B, 2)))
        dCost = cost-oldCost
        print('dPI = %f,  dA = %f,  dB = %f, dCost = %f, cost = %f'
              % (dPI, dA, dB, dCost, cost))

        # TODO: determine if this is reasonable
        # dB takes waaaaay longer to converge than the rest, i'm going to
        # double the thresh just for that
        dB = dB/2

        if not all([x < thresh for x in [dPI, dA, dB]]):
            return False
        else:
            return True

    def get_best_paths(self):
        if self.best_sequences is not None:
            return self.best_sequences, self.max_log_prob

        spikes = self.data
        dt = self.dt
        PI = self.initial_distribution
        A = self.transition
        B = self.emission

        bestPaths, pathProbs = compute_best_paths(spikes, dt, PI, A, B)
        self.best_sequences = bestPaths
        self.max_log_prob = np.max(pathProbs)
        return bestPaths, self.max_log_prob

    def get_forward_probabilities(self):
        alphas = []
        for trial in self.data:
            tmp, _ = forward(trial, self.dt, self.initial_distribution,
                             self.transition, self.emission)
            alphas.append(tmp)

        return np.array(alphas)

    def get_backward_probabilities(self):
        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        betas = []
        for trial in self.data:
            alpha, norms = forward(trial, self.dt, PI, A, B)
            tmp = backward(trial, self.dt, A, B, norms)
            betas.append(tmp)

        return np.array(betas)

    def get_gamma_probabilities(self):
        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        gammas = []
        for i, trial in enumerate(self.data):
            _, tmp, _ = wrap_baum_welch(i, trial, self.dt, PI, A, B)
            gammas.append(tmp)

        return np.array(gammas)

    def get_BIC(self):
        if self.BIC is not None:
            return self.BIC

        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        BIC, bestPaths, max_log_prob = compute_BIC(self.data, self.dt, PI, A, B)
        self.BIC = BIC
        self.best_sequences = bestPaths
        self.max_log_prob = max_log_prob
        return BIC, bestPaths, max_log_prob

    def _compute_data_rate_array(self):
        if self._rate_data is not None:
            return self._rate_data

        win_size = self._cost_window
        rate_array = convert_spikes_to_rates(self.data, self.dt,
                                             win_size, step_size=win_size)
        self._rate_data = rate_array
        return self._rate_data

    def _compute_predicted_rate_array(self):
        B = self.emission
        bestPaths, _ = self.get_best_paths()
        bestPaths = bestPaths.astype('int32')
        win_size = self._cost_window
        dt = self.dt
        mean_rates = generate_rate_array_from_state_seq(bestPaths, B,
                                                        dt, win_size,
                                                        step_size=win_size)
        return mean_rates

    def set_to_lowest_cost(self):
        hist = self.update_history()
        idx = np.argmin(hist['cost'])
        iteration = hist['iterations'][idx]
        self.roll_back(iteration)

    def set_to_lowest_BIC(self):
        hist = self.update_history()
        idx = np.argmin(hist['BIC'])
        iteration = hist['iterations'][idx]
        self.roll_back(iteration)

    def find_best_in_history(self):
        hist = self.update_history()
        PIs = hist['PI']
        As = hist['A']
        Bs = hist['B']
        iters = hist['iterations']
        BICs = hist['BIC']
        idx = np.argmin(BICs)
        out = {'PI': PIs[idx], 'A': As[idx], 'B': Bs[idx]}
        return out, iters[idx], BICs

    def roll_back(self, iteration):
        hist = self.history
        try:
            idx = hist['iterations'].index(iteration)
        except ValueError:
            raise ValueError('Iteration %i not found in history' % iteration)

        self.initial_distribution = hist['PI'][idx]
        self.transition = hist['A'][idx]
        self.emission = hist['B'][idx]
        self.iteration = iteration
        self._update_cost()

    def set_matrices(self, new_mats):
        self.initial_distribution = new_mats['PI']
        self.transition = new_mats['A']
        self.emission = new_mats['B']
        if 'iteration' in new_mats:
            self.iteration = new_mats['iteration']
        self._update_cost()

    def set_data(self, new_data, dt):
        self.data = new_data
        self.dt = dt
        self._compute_data_rate_array()
        self._update_cost()

    def plot_state_raster(self, ax=None, state_map=None):
        bestPaths, _ = self.get_best_paths()
        if state_map is not None:
            bestPaths = convert_path_state_numbers(bestPaths, state_map)

        data = self.data
        fig, ax = plot_state_raster(data, bestPaths, self.dt, ax=ax)
        return fig, ax

    def plot_state_rates(self, ax=None, state_map=None):
        rates = self.emission
        if state_map:
            idx = [state_map[k] for k in sorted(state_map.keys())]
            maxState = np.max(list(state_map.values()))
            newRates = np.zeros((rates.shape[0], maxState+1))
            for k, v in state_map.items():
                newRates[:, v] = rates[:, k]

            rates = newRates

        fig, ax = plot_state_rates(rates, ax=ax)
        return fig, ax

    def reorder_states(self, state_map):
        idx = [state_map[k] for k in sorted(state_map.keys())]
        PI = self.initial_distribution
        A = self.transition
        B = self.emission

        newPI = PI[idx]
        newB = B[:, idx]
        newA = np.zeros(A.shape)
        for x in range(A.shape[0]):
            for y in range(A.shape[1]):
                i = state_map[x]
                j = state_map[y]
                newA[i,j] = A[x,y]

        self.initial_distribution = newPI
        self.transition = newA
        self.emission = newB
        self._update_cost()

    def _update_cost(self):
        spikes = self.data
        win_size = self._cost_window
        dt = self.dt
        PI = self.initial_distribution
        A = self.transition
        B = self.emission
        true_rates = self._rate_data
        cost, BIC, bestPaths, maxLogProb = compute_hmm_cost(spikes, dt, PI, A, B,
                                                            win_size=win_size,
                                                            true_rates=true_rates)
        self.cost = cost
        self.BIC = BIC
        self.best_sequences = bestPaths
        self.max_log_prob = maxLogProb

    def get_cost(self):
        if self.cost is None:
            self._update_cost()

        return self.cost
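
The E-step loop in _step() collects per-trial gammas (posterior state probabilities) and epsilons (posterior transition probabilities) and hands them to compute_new_matrices() for the M-step. The sketch below outlines the standard Baum-Welch re-estimation for a Poisson HMM on a single trial; it is illustrative only, the array layout (gamma as nStates x T, epsilon as nStates x nStates x T-1) is an assumption, and the module's compute_new_matrices() additionally averages across trials.

import numpy as np

def m_step_single_trial(spikes, dt, gamma, epsilon):
    '''Illustrative single-trial M-step; not the module's implementation.'''
    PI = gamma[:, 0]                                    # initial-state posterior
    A = epsilon.sum(axis=2) / gamma[:, :-1].sum(axis=1, keepdims=True)
    B = (spikes @ gamma.T) / (dt * gamma.sum(axis=1))   # nCells x nStates rate matrix
    return PI, A, B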

Methods

def find_best_in_history(self)
Expand source code
def find_best_in_history(self):
    hist = self.update_history()
    PIs = hist['PI']
    As = hist['A']
    Bs = hist['B']
    iters = hist['iterations']
    BICs = hist['BIC']
    idx = np.argmin(BICs)
    out = {'PI': PIs[idx], 'A': As[idx], 'B': Bs[idx]}
    return out, iters[idx], BICs
def fit(self, spikes=None, dt=None, max_iter=1000, convergence_thresh=0.0001, parallel=False)

Using parallel processing across trials actually seems to slow things down (with 15 trials). It might still be useful if there is a very large number of trials.

Expand source code
def fit(self, spikes=None, dt=None, max_iter=1000, convergence_thresh=1e-4,
        parallel=False):
    '''Using parallel processing across trials actually seems to slow things
    down (with 15 trials). It might still be useful if there is a very
    large number of trials.
    '''
    if self.fitted:
        return

    if spikes is not None:
        spikes = spikes.astype('int32')
        self.data = spikes
        self.dt = dt
    else:
        spikes = self.data
        dt = self.dt

    while (not self.isConverged(convergence_thresh) and
           (self.iteration < max_iter)):
        self._step(spikes, dt, parallel=parallel)
        print('Iter #%i complete.' % self.iteration)

    self.fitted = True
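
A minimal usage sketch, assuming `hmm` is an already-constructed instance of this HMM class with its matrices initialized (e.g. via randomize()); the constructor signature is not shown in this listing:

td = TestData()                                   # simulated spikes with TEST_PARAMS defaults
spikes = td.get_spike_trains().astype('int32')    # trials x cells x time
hmm.fit(spikes=spikes, dt=td.params['dt'],
        max_iter=500, convergence_thresh=1e-4)
best_paths, max_log_prob = hmm.get_best_paths()
hmm.set_to_lowest_BIC()                           # optionally roll back to the best iteration
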
def get_BIC(self)
Expand source code
def get_BIC(self):
    if self.BIC is not None:
        return self.BIC

    PI = self.initial_distribution
    A = self.transition
    B = self.emission
    BIC, bestPaths, max_log_prob = compute_BIC(self.data, self.dt, PI, A, B)
    self.BIC = BIC
    self.best_sequences = bestPaths
    self.max_log_prob = max_log_prob
    return BIC, bestPaths, max_log_prob
def get_backward_probabilities(self)
Expand source code
def get_backward_probabilities(self):
    PI = self.initial_distribution
    A = self.transition
    B = self.emission
    betas = []
    for trial in self.data:
        alpha, norms = forward(trial, self.dt, PI, A, B)
        tmp = backward(trial, self.dt, A, B, norms)
        betas.append(tmp)

    return np.array(betas)
def get_best_paths(self)
Expand source code
def get_best_paths(self):
    if self.best_sequences is not None:
        return self.best_sequences, self.max_log_prob

    spikes = self.data
    dt = self.dt
    PI = self.initial_distribution
    A = self.transition
    B = self.emission

    bestPaths, pathProbs = compute_best_paths(spikes, dt, PI, A, B)
    self.best_sequences = bestPaths
    self.max_log_prob = np.max(pathProbs)
    return bestPaths, self.max_log_prob
def get_cost(self)
Expand source code
def get_cost(self):
    if self.cost is None:
        self._update_cost()

    return self.cost
def get_forward_probabilities(self)
Expand source code
def get_forward_probabilities(self):
    alphas = []
    for trial in self.data:
        tmp, _ = forward(trial, self.dt, self.initial_distribution,
                         self.transition, self.emission)
        alphas.append(tmp)

    return np.array(alphas)
def get_gamma_probabilities(self)
Expand source code
def get_gamma_probabilities(self):
    PI = self.initial_distribution
    A = self.transition
    B = self.emission
    gammas = []
    for i, trial in enumerate(self.data):
        _, tmp, _ = wrap_baum_welch(i, trial, self.dt, PI, A, B)
        gammas.append(tmp)

    return np.array(gammas)
def isConverged(self, thresh)
Expand source code
def isConverged(self, thresh):
    if self.history is None:
        return False

    oldPI = self.history['PI'][-1]
    oldA = self.history['A'][-1]
    oldB = self.history['B'][-1]
    oldCost = self.history['cost'][-1]

    PI = self.initial_distribution
    A = self.transition
    B = self.emission
    cost = self.cost

    dPI = np.sqrt(np.sum(np.power(oldPI - PI, 2)))
    dA = np.sqrt(np.sum(np.power(oldA - A, 2)))
    dB = np.sqrt(np.sum(np.power(oldB - B, 2)))
    dCost = cost-oldCost
    print('dPI = %f,  dA = %f,  dB = %f, dCost = %f, cost = %f'
          % (dPI, dA, dB, dCost, cost))

    # TODO: determine if this is reasonable
    # dB takes much longer to converge than the rest, so halve it here,
    # which effectively doubles the threshold for the emission matrix
    dB = dB/2

    if not all([x < thresh for x in [dPI, dA, dB]]):
        return False
    else:
        return True
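
Convergence is declared when the Euclidean norms of the changes in PI, A, and B all fall below the threshold, with the emission-rate change halved as the comment notes. An equivalent standalone check (illustrative helper, not part of the module):

import numpy as np

def params_converged(old, new, thresh):
    # old/new are dicts holding 'PI', 'A', 'B'; mirrors the logic of isConverged()
    dPI = np.linalg.norm(new['PI'] - old['PI'])
    dA = np.linalg.norm(new['A'] - old['A'])
    dB = np.linalg.norm(new['B'] - old['B']) / 2   # emission change is halved
    return all(d < thresh for d in (dPI, dA, dB))
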
def plot_state_raster(self, ax=None, state_map=None)
Expand source code
def plot_state_raster(self, ax=None, state_map=None):
    bestPaths, _ = self.get_best_paths()
    if state_map is not None:
        bestPaths = convert_path_state_numbers(bestPaths, state_map)

    data = self.data
    fig, ax = plot_state_raster(data, bestPaths, self.dt, ax=ax)
    return fig, ax
def plot_state_rates(self, ax=None, state_map=None)
Expand source code
def plot_state_rates(self, ax=None, state_map=None):
    rates = self.emission
    if state_map:
        idx = [state_map[k] for k in sorted(state_map.keys())]
        maxState = np.max(list(state_map.values()))
        newRates = np.zeros((rates.shape[0], maxState+1))
        for k, v in state_map.items():
            newRates[:, v] = rates[:, k]

        rates = newRates

    fig, ax = plot_state_rates(rates, ax=ax)
    return fig, ax
def randomize(self)
Expand source code
def randomize(self):
    nStates = self.n_states
    spikes = self.data
    dt = self.dt
    n_trials, n_cells, n_steps = spikes.shape
    total_time = n_steps * dt

    # Initialize transition matrix with high stay probability
    print('Randomizing')
    diag = np.abs(np.random.normal(.99, .01, nStates))
    A = np.abs(np.random.normal(0.01/(nStates-1), 0.01, (nStates, nStates)))
    for i in range(nStates):
        A[i, i] = diag[i]
        A[i,:] = A[i,:] / np.sum(A[i,:])

    # Initialize rate matrix ("Emission" matrix)
    spike_counts = np.sum(spikes, axis=2) / total_time
    mean_rates = np.mean(spike_counts, axis=0)
    std_rates = np.std(spike_counts, axis=0)
    B = np.vstack([np.abs(np.random.normal(x, y, nStates))
                   for x,y in zip(mean_rates, std_rates)])
    # B = np.random.rand(nCells, nStates)

    self.transition = A
    self.emission = B
    self.initial_distribution = np.ones((nStates,)) / nStates
    self.iteration = 0
    self.fitted = False
    self.history = None
    self._update_cost()
def reorder_states(self, state_map)
Expand source code
def reorder_states(self, state_map):
    idx = [state_map[k] for k in sorted(state_map.keys())]
    PI = self.initial_distribution
    A = self.transition
    B = self.emission

    newPI = PI[idx]
    newB = B[:, idx]
    newA = np.zeros(A.shape)
    for x in range(A.shape[0]):
        for y in range(A.shape[1]):
            i = state_map[x]
            j = state_map[y]
            newA[i,j] = A[x,y]

    self.initial_distribution = newPI
    self.transition = newA
    self.emission = newB
    self._update_cost()
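
state_map re-labels the hidden states: the loop above moves entry (x, y) of the transition matrix to (state_map[x], state_map[y]), and the initial distribution and emission columns are reindexed with the same map. A swap, for example exchanging states 0 and 2 of a three-state model, reads the same in either direction since it is its own inverse (illustrative values, assuming an instance `hmm`):

state_map = {0: 2, 1: 1, 2: 0}     # old index <-> new index
hmm.reorder_states(state_map)
fig, ax = hmm.plot_state_rates()   # rates are now plotted in the new state order
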
def roll_back(self, iteration)
Expand source code
def roll_back(self, iteration):
    hist = self.history
    try:
        idx = hist['iterations'].index(iteration)
    except ValueError:
        raise ValueError('Iteration %i not found in history' % iteration)

    self.initial_distribution = hist['PI'][idx]
    self.transition = hist['A'][idx]
    self.emission = hist['B'][idx]
    self.iteration = iteration
    self._update_cost()
def set_data(self, new_data, dt)
Expand source code
def set_data(self, new_data, dt):
    self.data = new_data
    self.dt = dt
    self._compute_data_rate_array()
    self._update_cost()
def set_matrices(self, new_mats)
Expand source code
def set_matrices(self, new_mats):
    self.initial_distribution = new_mats['PI']
    self.transition = new_mats['A']
    self.emission = new_mats['B']
    if 'iteration' in new_mats:
        self.iteration = new_mats['iteration']
    self._update_cost()
def set_to_lowest_BIC(self)
Expand source code
def set_to_lowest_BIC(self):
    hist = self.update_history()
    idx = np.argmin(hist['BIC'])
    iteration = hist['iterations'][idx]
    self.roll_back(iteration)
def set_to_lowest_cost(self)
Expand source code
def set_to_lowest_cost(self):
    hist = self.update_history()
    idx = np.argmin(hist['cost'])
    iteration = hist['iterations'][idx]
    self.roll_back(iteration)
def update_history(self)
Expand source code
def update_history(self):
    A = self.transition
    B = self.emission
    PI = self.initial_distribution
    BIC = self.BIC
    cost = self.cost
    iteration = self.iteration

    if self.history is None:
        self.history = {}
        self.history['A'] = [A]
        self.history['B'] = [B]
        self.history['PI'] = [PI]
        self.history['iterations'] = [iteration]
        self.history['cost'] = [cost]
        self.history['BIC'] = [BIC]
    else:
        if iteration in self.history['iterations']:
            return self.history

        self.history['A'].append(A)
        self.history['B'].append(B)
        self.history['PI'].append(PI)
        self.history['iterations'].append(iteration)
        self.history['cost'].append(cost)
        self.history['BIC'].append(BIC)

    if len(self.history['iterations']) > self._max_history:
        nmax = self._max_history
        for k, v in self.history.items():
            self.history[k] = v[-nmax:]

    return self.history
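
The history dict keeps at most _max_history snapshots of (PI, A, B, cost, BIC); roll_back() and the set_to_lowest_* helpers restore an earlier parameter set from it. For example, again assuming an instance `hmm` that has been stepped or fitted:

hist = hmm.update_history()
print(hist['iterations'][-3:], hist['BIC'][-3:])    # most recent snapshots
hmm.set_to_lowest_BIC()                             # restore the iteration with minimal BIC
mats, best_iter, bics = hmm.find_best_in_history()  # best matrices without restoring them
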
class TestData (params=None)
Expand source code
class TestData(object):
    def __init__(self, params=None):
        if params is None:
            params = TEST_PARAMS.copy()
            param_str = '\t'+'\n\t'.join(repr(params)[1:-1].split(', '))
            print('Using default parameters:\n%s' % param_str)

        self.params = params.copy()
        self.generate()

    def generate(self, params=None):
        print('-'*80)
        print('Simulating Data')
        print('-'*80)
        if params is not None:
            self.params.update(params)

        params = self.params
        param_str = '\t'+'\n\t'.join(repr(params)[1:-1].split(', '))
        print('Parameters:\n%s' % param_str)

        self._generate_ground_truth()
        self._generate_spike_trains()

    def _generate_ground_truth(self):
        print('Generating ground truth state sequence...')
        params = self.params
        nStates = params['n_states']
        seqLen = params['state_seq_length']
        minSeqDur = params['min_state_dur']
        baseline_dur = params['baseline_dur']
        maxFR = params['max_rate']
        nCells = params['n_cells']
        trialTime = params['trial_time']
        nTrials = params['n_trials']
        dt = params['dt']
        nTimeSteps = int(trialTime/dt)

        T = trialTime
        # Figure out a random state sequence and state durations
        stateSeq = np.array([0, *np.random.randint(0, nStates, seqLen-1)])
        stateDurs = np.zeros((nTrials, seqLen))
        for i in range(nTrials):
            tmp = np.abs(np.random.rand(seqLen-1))
            tmp = tmp * ((trialTime - baseline_dur) / np.sum(tmp))
            stateDurs[i, :] = np.array([baseline_dur, *tmp])

        # Make vector of state at each time point
        stateVec = np.zeros((nTrials, nTimeSteps))
        for trial in range(nTrials):
            t0 = 0
            for state, dur in zip(stateSeq, stateDurs[trial]):
                tn = int(dur/dt)
                stateVec[trial, t0:t0+tn] = state
                t0 += tn

        # Determine firing rates per neuron per state
        # For each neuron generate a mean firing rate and then draw the state
        # firing rates from a normal distribution around that mean with a
        # standard deviation of half the mean rate
        mean_rates = np.random.rand(nCells, 1) * maxFR
        stateRates = np.zeros((nCells, nStates))
        for i, r in enumerate(mean_rates):
            stateRates[i, :] = np.array([r, *np.abs(np.random.normal(r, .5*r, nStates-1))])

        self.ground_truth = {'state_sequence': stateSeq,
                             'state_durations': stateDurs,
                             'firing_rates': stateRates,
                             'state_vectors': stateVec}

    def _generate_spike_trains(self):
        print('Generating new spike trains...')
        params = self.params
        nCells = params['n_cells']
        trialTime = params['trial_time']
        dt = params['dt']
        nTrials = params['n_trials']
        noise = params['noise']
        nTimeSteps = int(trialTime/dt)

        stateRates = self.ground_truth['firing_rates']
        stateVec = self.ground_truth['state_vectors']


        # Make spike arrays
        # Trial x Neuron x Time
        random_nums = np.abs(np.random.rand(nTrials, nCells, nTimeSteps))
        rate_arr = np.zeros((nTrials, nCells, nTimeSteps))
        for trial, cell, t in it.product(range(nTrials), range(nCells), range(nTimeSteps)):
            state = int(stateVec[trial, t])
            mean_rate = stateRates[cell, state]
            # draw noisy rates from normal distrib with mean rate from ground
            # truth and width as noise*mean_rate
            r = np.random.normal(mean_rate, mean_rate*noise)
            rate_arr[trial, cell, t] = r

        spikes = (random_nums <= rate_arr * dt).astype('int')

        self.spike_trains = spikes

    def get_spike_trains(self):
        if not hasattr(self, 'spike_trains'):
            self._generate_spike_trains()

        return self.spike_trains

    def get_ground_truth(self):
        if not hasattr(self, 'ground_truth'):
            self._generate_ground_truth()

        return self.ground_truth

    def plot_state_rates(self, ax=None):
        fig, ax = plot_state_rates(self.ground_truth['firing_rates'], ax=ax)
        return fig, ax

    def plot_state_raster(self, ax=None):
        fig, ax = plot_state_raster(self.spike_trains,
                                    self.ground_truth['state_vectors'],
                                    self.params['dt'], ax=ax)
        return fig, ax

Methods

def generate(self, params=None)
Expand source code
def generate(self, params=None):
    print('-'*80)
    print('Simulating Data')
    print('-'*80)
    if params is not None:
        self.params.update(params)

    params = self.params
    param_str = '\t'+'\n\t'.join(repr(params)[1:-1].split(', '))
    print('Parameters:\n%s' % param_str)

    self._generate_ground_truth()
    self._generate_spike_trains()
def get_ground_truth(self)
Expand source code
def get_ground_truth(self):
    if not hasattr(self, 'ground_truth'):
        self._generate_ground_truth()

    return self.ground_truth
def get_spike_trains(self)
Expand source code
def get_spike_trains(self):
    if not hasattr(self, 'spike_trains'):
        self._generate_spike_trains()

    return self.spike_trains
def plot_state_raster(self, ax=None)
Expand source code
def plot_state_raster(self, ax=None):
    fig, ax = plot_state_raster(self.spike_trains,
                                self.ground_truth['state_vectors'],
                                self.params['dt'], ax=ax)
    return fig, ax
def plot_state_rates(self, ax=None)
Expand source code
def plot_state_rates(self, ax=None):
    fig, ax = plot_state_rates(self.ground_truth['firing_rates'], ax=ax)
    return fig, ax
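
A short usage sketch of TestData with the default TEST_PARAMS:

td = TestData()                    # prints the default parameters and simulates data
spikes = td.get_spike_trains()     # trials x cells x time array of 0/1 spike counts
truth = td.get_ground_truth()      # dict with 'state_sequence', 'state_durations',
                                   # 'firing_rates', and 'state_vectors'
fig, ax = td.plot_state_raster()   # simulated spikes with the ground-truth state sequence
fig, ax = td.plot_state_rates()    # per-state firing rates used to generate the spikes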