Module blechpy.plotting.hmm_plot

Expand source code
import os
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter1d
import matplotlib
matplotlib.use('TkAgg')
import pylab as plt
import seaborn as sns



def get_sequence_windows(seq):
    t = 0
    out = []
    while t < len(seq):
        s = seq[t]
        tmp = np.where(seq[t:] != s)[0]
        if len(tmp) == 0:
            tmp = len(seq) - t
        else:
            tmp = tmp[0]

        out.append((t, tmp+t-1, s))
        t += tmp

    return out


def get_threshold_windows(trace, thresh=0.75):
    '''Returns list of tuples with start and stop time for windows where the
    given trace is above threshold. trace can be multiple rows. returns tuples
    in fashion (start_idx, stop_idx, row)
    '''
    out = []
    if len(trace.shape) == 1:
        trace = np.array([trace])

    n_rows, n_steps = trace.shape
    for i, row in enumerate(trace):
        t = 0
        while t < n_steps:
            if row[t] >= thresh:
                tmp = np.where(row[t:] < thresh)[0]
            else:
                tmp = np.where(row[t:] >= thresh)[0]

            if len(tmp) == 0:
                tmp = len(row) - t
            else:
                tmp = tmp[0]

            if row[t] >= thresh:
                out.append((t, tmp+t-1, i))

            t += tmp

    return out


def get_hmm_plot_colors(n_states):
    colors = [plt.cm.tab10(x) for x in np.linspace(0, 1, n_states)]
    return colors


def plot_raster(spikes, time=None, ax=None, y_min=0.05, y_max=0.95):
    '''Plot 2D spike raster

    Parameters
    ----------
    spikes : np.array
        2D matrix M x N where N is the number of time steps and in each bin is
        a 0 or 1, with 1 signifying the presence of a spike
    '''
    if not ax:
        _, ax = plt.gca()

    n_rows, n_steps = spikes.shape
    if time is None:
        time = np.arange(0, n_steps)

    y_steps = np.linspace(y_min, y_max, n_rows)
    for i, row in enumerate(spikes):
        idx = np.where(row == 1)[0]
        if len(idx) == 0:
            continue

        ax.scatter(time[idx], row[idx]*y_steps[i], color='black', marker='|')

    return ax


def make_hmm_raster(spikes, time=None, save_file=None):
    '''Create figure of spikes rasters with each trial on a seperate axis

    Parameters
    ----------
    spikes: np.array, Trials X Cells X Time array with 1s where spikes occur
    time: np.array, 1D time vector
    save_file: str, if provided figure is saved and not returned

    Returns
    -------
    plt.Figure, list of plt.Axes
    '''
    if len(spikes) == 2:
        spikes = np.array([spikes])

    n_trials, n_cells, n_steps = spikes.shape
    if time is None:
        time = np.arange(0, n_steps)

    fig, axes = plt.subplots(nrows=n_trials, figsize=(15, n_trials))
    y_step = np.linspace(0.05, 0.95, n_cells)
    for ax, trial in zip(axes, spikes):
        tmp = plot_raster(trial, time=time, ax=ax)

        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        if time[0] < 0:
            ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    axes[-1].get_xaxis().set_visible(True)
    tmp_ax = fig.add_subplot('111', frameon=False)
    tmp_ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
    tmp_ax.set_ylabel('Trials')
    axes[-1].set_xlabel('Time')
    axes[-1].set_ylabel('Cells', fontsize=11)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes


def plot_sequence(seq, time=None, ax=None, y_min=0, y_max=1, colors=None):
    if ax is None:
        _, ax = plt.gca()

    if time is None:
        time = np.arange(0, len(seq))

    nStates = np.max(seq)+1
    if colors is None:
        colors = [plt.cm.Set2(x) for x in np.linspace(0, 1, nStates)]

    seq_windows = get_sequence_windows(seq)
    handles = {}
    for win in seq_windows:
        t_vec = [time[win[0]], time[win[1]]]
        h = ax.fill_between(t_vec, [y_min, y_min], [y_max, y_max],
                            color=colors[int(win[2])], alpha=0.4)
        if  win[2] not in handles:
            handles[win[2]] = h

    leg_handles = [handles[k] for k in sorted(handles.keys())]
    leg_labels = ['State %i' % k for k in sorted(handles.keys())]
    return ax, leg_handles, leg_labels


def plot_viterbi_paths(hmm, spikes, time=None, colors=None, axes=None, legend=True,
                       hmm_id=None, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend


    BIC = hmm.BIC
    paths = hmm.stat_arrays['best_sequences']
    n_trials, n_steps = paths.shape
    n_states = hmm.n_states
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(paths, axes):
        _, tmp_handles, tmp_labels = plot_sequence(trial, time=time, ax=ax, colors=colors)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

        if time[0] != 0:
            ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'Decoded HMM Sequences'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes


def plot_probability_traces(traces, time=None, ax=None, colors=None, thresh=0.75,
                           smoothing=3):
    y_min=0
    y_max=1
    if ax is None:
        _, ax = plt.gca()

    n_states, n_steps = traces.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0, 1, n_states)]

    windows = get_threshold_windows(traces, thresh=thresh)
    handles = {}
    for win in windows:
        t_vec = [time[win[0]], time[win[1]]]
        h = ax.fill_between(t_vec, [y_min, y_min], [y_max, y_max],
                            color=colors[int(win[2])], alpha=0.4)
        if  win[2] not in handles:
            handles[win[2]] = h

    leg_handles = [handles[k] for k in sorted(handles.keys())]
    leg_labels = ['State %i' % k for k in sorted(handles.keys())]

    for line, col in zip(traces, colors):
        tmp = line
        if smoothing:
            tmp = gaussian_filter1d(tmp, smoothing)

        ax.plot(time, tmp, color=col, linewidth=2)

    return ax, leg_handles, leg_labels


def plot_forward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True,
                       hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    alphas, norms = hmm.get_forward_probabilities(spikes, dt)
    n_trials, n_states, n_steps = alphas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(alphas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Forward Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes


def plot_backward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True,
                        hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    betas = hmm.get_backward_probabilities(spikes, dt)
    n_trials, n_states, n_steps = betas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(betas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Backward Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes


def plot_gamma_probs(hmm, spikes=None, dt=None, time=None, colors=None, axes=None, legend=True,
                     hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    gammas = hmm.stat_arrays['gamma_probabilities']
    if gammas == []:
        if spikes is None and dt is None:
            raise ValueError('Not enough info to compute gamma probabilities')

        gammas = hmm.get_gamma_probabilities(spikes, dt)

    n_trials, n_states, n_steps = gammas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(gammas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Gamma Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        return
    else:
        return fig, axes


def plot_hmm_rates(rates, axes=None, colors=None):
    '''Make bar plot of spike rates for each cell and state in an HMM emission
    matrix

    Parameters
    ----------
    rates: np.array
        Cell X State matrix of firing rates
    '''
    n_cells, n_states = rates.shape
    if axes is None:
        _, axes = plt.subplot(ncols=n_states)

    if len(axes) < n_states:
        raise ValueError('Must provided enough axes to plot each state')

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    df = pd.DataFrame(rates, columns=['state %i' % i for i in range(n_states)])
    df['cell'] = ['cell %i' % i for i in df.index]
    df = pd.melt(df, 'cell', ['state %i' % i for i in range(n_states)], 'state', 'rate')
    x_max = 0
    for g, ax, col in zip(df.groupby('state'), axes, colors):
        sns.barplot(data=g[1], x='rate', y='cell',
                    color='black', ax=ax)
        ax.set_title(g[0])
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.set_facecolor(col)
        ax.patch.set_alpha(0.5)
        ax.set_yticklabels([])
        ax.tick_params(left=False)
        xl = ax.get_xlim()
        if xl[1] > x_max:
            x_max = xl[1]

        for spine in ax.spines.values():
            spine.set_visible(False)


    for ax in axes:
        ax.set_xlim([0, x_max])

    axes[0].set_yticklabels(['Cell %i' % i for i in range(n_cells)])
    mid = int(n_states/2)
    axes[mid].set_xlabel('Firing Rate (Hz)')
    return axes


def plot_hmm_transition(transition, ax=None):
    if not ax:
        _, ax = plt.gca()

    n_states = transition.shape[0]
    labels = ['State %i' % i for i in range(n_states)]
    sns.heatmap(transition, ax=ax, cmap='plasma', cbar=True, square=True,
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1,
                cbar_kws={'shrink': 0.5})
    ax.set_ylim((0, n_states))
    ax.set_title('Transition Probabilities')
    return ax


def plot_hmm_initial_probs(PI, ax=None):
    if not ax:
        _ , ax = plt.gca()

    n_states = PI.shape[0]
    labels = ['State %i' % i for i in range(n_states)]
    PI = np.expand_dims(PI, 1)
    sns.heatmap(PI, ax=ax, cmap='plasma', cbar=True,
                yticklabels=labels, vmin=0, vmax=1)
    ax.set_ylim((0, n_states))
    ax.set_title('Initial Probabilities')
    return ax


def plot_hmm_overview(hmm, colors=None, hmm_id=None, save_file=None):
    n_states = hmm.n_states
    if not colors:
        colors = get_hmm_plot_colors(n_states)

    PI = hmm.initial_distribution
    A = hmm.transition
    B = hmm.emission
    fig, axes = plt.subplots(nrows=2, ncols=np.max((n_states,2)), figsize=(20, 15))
    if n_states > 2:
        for ax in axes[0,1:-1]:
            ax.axis('off')

    plot_hmm_initial_probs(PI, ax=axes[0,0])
    plot_hmm_transition(A, ax=axes[0,-1])
    plot_hmm_rates(B, axes=axes[1,:], colors=colors)
    mid = int(n_states/2)
    axes[1, mid].set_xlabel('')
    tmp_ax = fig.add_subplot('111', frameon=False)
    tmp_ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
    tmp_ax.set_xlabel('Firing Rate (Hz)')

    fig.subplots_adjust(top=0.9)
    title_str = 'Fitted HMM Parameters'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    fig.suptitle(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes


def plot_hmm_figures(hmm, spikes, dt, time, hmm_id=None, save_dir=None):
    colors = get_hmm_plot_colors(hmm.n_states)
    if hmm_id is None:
        hmm_id = hmm.hmm_id


    fig_names = ['sequences', 'forward_probabilities',
                 'backward_probabilities', 'gamma_probabilities', 'overview']
    if save_dir:
        files = {x : os.path.join(save_dir, '%s.png' % x) for x in fig_names}
    else:
        files = dict.fromkeys(fig_names, None)


    # Plot sequences
    print('Plotting Viterbi Decoded Paths...')
    plot_viterbi_paths(hmm, spikes, time=time, colors=colors,
                       hmm_id=hmm_id, save_file=files['sequences'])

    # Plot alphas
    print('Plotting Forward Probabilities...')
    plot_forward_probs(hmm, spikes, dt, time=time, colors=colors,
                       hmm_id=hmm_id, save_file=files['forward_probabilities'])
    # Plot betas
    print('Plotting Backward Probabilities...')
    plot_backward_probs(hmm, spikes, dt, time=time, colors=colors,
                        hmm_id=hmm_id, save_file=files['backward_probabilities'])

    # Plot gammas
    print('Plotting Gamma Probabilities...')
    plot_gamma_probs(hmm, spikes, dt, time=time, colors=colors,
                     hmm_id=hmm_id, save_file=files['gamma_probabilities'])

    # Plot stats: rate bar plots, transition heat map, initial probabilities
    print('Plotting HMM Overview...')
    plot_hmm_overview(hmm, colors=colors, save_file=files['overview'])

    print('Plotting Complete!')
    plt.close('all')

Functions

def get_hmm_plot_colors(n_states)
Expand source code
def get_hmm_plot_colors(n_states):
    colors = [plt.cm.tab10(x) for x in np.linspace(0, 1, n_states)]
    return colors
def get_sequence_windows(seq)
Expand source code
def get_sequence_windows(seq):
    t = 0
    out = []
    while t < len(seq):
        s = seq[t]
        tmp = np.where(seq[t:] != s)[0]
        if len(tmp) == 0:
            tmp = len(seq) - t
        else:
            tmp = tmp[0]

        out.append((t, tmp+t-1, s))
        t += tmp

    return out
def get_threshold_windows(trace, thresh=0.75)

Returns list of tuples with start and stop time for windows where the given trace is above threshold. trace can be multiple rows. returns tuples in fashion (start_idx, stop_idx, row)

Expand source code
def get_threshold_windows(trace, thresh=0.75):
    '''Returns list of tuples with start and stop time for windows where the
    given trace is above threshold. trace can be multiple rows. returns tuples
    in fashion (start_idx, stop_idx, row)
    '''
    out = []
    if len(trace.shape) == 1:
        trace = np.array([trace])

    n_rows, n_steps = trace.shape
    for i, row in enumerate(trace):
        t = 0
        while t < n_steps:
            if row[t] >= thresh:
                tmp = np.where(row[t:] < thresh)[0]
            else:
                tmp = np.where(row[t:] >= thresh)[0]

            if len(tmp) == 0:
                tmp = len(row) - t
            else:
                tmp = tmp[0]

            if row[t] >= thresh:
                out.append((t, tmp+t-1, i))

            t += tmp

    return out
def make_hmm_raster(spikes, time=None, save_file=None)

Create figure of spikes rasters with each trial on a seperate axis

Parameters

spikes : np.array, Trials X Cells X Time array with 1s where spikes occur
 
time : np.array, 1D time vector
 
save_file : str, if provided figure is saved and not returned
 

Returns

plt.Figure, list of plt.Axes
 
Expand source code
def make_hmm_raster(spikes, time=None, save_file=None):
    '''Create figure of spikes rasters with each trial on a seperate axis

    Parameters
    ----------
    spikes: np.array, Trials X Cells X Time array with 1s where spikes occur
    time: np.array, 1D time vector
    save_file: str, if provided figure is saved and not returned

    Returns
    -------
    plt.Figure, list of plt.Axes
    '''
    if len(spikes) == 2:
        spikes = np.array([spikes])

    n_trials, n_cells, n_steps = spikes.shape
    if time is None:
        time = np.arange(0, n_steps)

    fig, axes = plt.subplots(nrows=n_trials, figsize=(15, n_trials))
    y_step = np.linspace(0.05, 0.95, n_cells)
    for ax, trial in zip(axes, spikes):
        tmp = plot_raster(trial, time=time, ax=ax)

        for spine in ax.spines.values():
            spine.set_visible(False)

        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_visible(False)
        if time[0] < 0:
            ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    axes[-1].get_xaxis().set_visible(True)
    tmp_ax = fig.add_subplot('111', frameon=False)
    tmp_ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
    tmp_ax.set_ylabel('Trials')
    axes[-1].set_xlabel('Time')
    axes[-1].set_ylabel('Cells', fontsize=11)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes
def plot_backward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True, hmm_id=None, thresh=0.75, save_file=None)
Expand source code
def plot_backward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True,
                        hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    betas = hmm.get_backward_probabilities(spikes, dt)
    n_trials, n_states, n_steps = betas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(betas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Backward Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes
def plot_forward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True, hmm_id=None, thresh=0.75, save_file=None)
Expand source code
def plot_forward_probs(hmm, spikes, dt, time=None, colors=None, axes=None, legend=True,
                       hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    alphas, norms = hmm.get_forward_probabilities(spikes, dt)
    n_trials, n_states, n_steps = alphas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(alphas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Forward Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes
def plot_gamma_probs(hmm, spikes=None, dt=None, time=None, colors=None, axes=None, legend=True, hmm_id=None, thresh=0.75, save_file=None)
Expand source code
def plot_gamma_probs(hmm, spikes=None, dt=None, time=None, colors=None, axes=None, legend=True,
                     hmm_id=None, thresh=0.75, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend

    gammas = hmm.stat_arrays['gamma_probabilities']
    if gammas == []:
        if spikes is None and dt is None:
            raise ValueError('Not enough info to compute gamma probabilities')

        gammas = hmm.get_gamma_probabilities(spikes, dt)

    n_trials, n_states, n_steps = gammas.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(gammas, axes):
        _, tmp_handles, tmp_labels = plot_probability_traces(trial,time=time, ax=ax,
                                                             colors=colors, thresh=thresh)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

    if time[0] != 0:
        ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'HMM Gamma Probabilities'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        return
    else:
        return fig, axes
def plot_hmm_figures(hmm, spikes, dt, time, hmm_id=None, save_dir=None)
Expand source code
def plot_hmm_figures(hmm, spikes, dt, time, hmm_id=None, save_dir=None):
    colors = get_hmm_plot_colors(hmm.n_states)
    if hmm_id is None:
        hmm_id = hmm.hmm_id


    fig_names = ['sequences', 'forward_probabilities',
                 'backward_probabilities', 'gamma_probabilities', 'overview']
    if save_dir:
        files = {x : os.path.join(save_dir, '%s.png' % x) for x in fig_names}
    else:
        files = dict.fromkeys(fig_names, None)


    # Plot sequences
    print('Plotting Viterbi Decoded Paths...')
    plot_viterbi_paths(hmm, spikes, time=time, colors=colors,
                       hmm_id=hmm_id, save_file=files['sequences'])

    # Plot alphas
    print('Plotting Forward Probabilities...')
    plot_forward_probs(hmm, spikes, dt, time=time, colors=colors,
                       hmm_id=hmm_id, save_file=files['forward_probabilities'])
    # Plot betas
    print('Plotting Backward Probabilities...')
    plot_backward_probs(hmm, spikes, dt, time=time, colors=colors,
                        hmm_id=hmm_id, save_file=files['backward_probabilities'])

    # Plot gammas
    print('Plotting Gamma Probabilities...')
    plot_gamma_probs(hmm, spikes, dt, time=time, colors=colors,
                     hmm_id=hmm_id, save_file=files['gamma_probabilities'])

    # Plot stats: rate bar plots, transition heat map, initial probabilities
    print('Plotting HMM Overview...')
    plot_hmm_overview(hmm, colors=colors, save_file=files['overview'])

    print('Plotting Complete!')
    plt.close('all')
def plot_hmm_initial_probs(PI, ax=None)
Expand source code
def plot_hmm_initial_probs(PI, ax=None):
    if not ax:
        _ , ax = plt.gca()

    n_states = PI.shape[0]
    labels = ['State %i' % i for i in range(n_states)]
    PI = np.expand_dims(PI, 1)
    sns.heatmap(PI, ax=ax, cmap='plasma', cbar=True,
                yticklabels=labels, vmin=0, vmax=1)
    ax.set_ylim((0, n_states))
    ax.set_title('Initial Probabilities')
    return ax
def plot_hmm_overview(hmm, colors=None, hmm_id=None, save_file=None)
Expand source code
def plot_hmm_overview(hmm, colors=None, hmm_id=None, save_file=None):
    n_states = hmm.n_states
    if not colors:
        colors = get_hmm_plot_colors(n_states)

    PI = hmm.initial_distribution
    A = hmm.transition
    B = hmm.emission
    fig, axes = plt.subplots(nrows=2, ncols=np.max((n_states,2)), figsize=(20, 15))
    if n_states > 2:
        for ax in axes[0,1:-1]:
            ax.axis('off')

    plot_hmm_initial_probs(PI, ax=axes[0,0])
    plot_hmm_transition(A, ax=axes[0,-1])
    plot_hmm_rates(B, axes=axes[1,:], colors=colors)
    mid = int(n_states/2)
    axes[1, mid].set_xlabel('')
    tmp_ax = fig.add_subplot('111', frameon=False)
    tmp_ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
    tmp_ax.set_xlabel('Firing Rate (Hz)')

    fig.subplots_adjust(top=0.9)
    title_str = 'Fitted HMM Parameters'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    fig.suptitle(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes
def plot_hmm_rates(rates, axes=None, colors=None)

Make bar plot of spike rates for each cell and state in an HMM emission matrix

Parameters

rates : np.array
Cell X State matrix of firing rates
Expand source code
def plot_hmm_rates(rates, axes=None, colors=None):
    '''Make bar plot of spike rates for each cell and state in an HMM emission
    matrix

    Parameters
    ----------
    rates: np.array
        Cell X State matrix of firing rates
    '''
    n_cells, n_states = rates.shape
    if axes is None:
        _, axes = plt.subplot(ncols=n_states)

    if len(axes) < n_states:
        raise ValueError('Must provided enough axes to plot each state')

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    df = pd.DataFrame(rates, columns=['state %i' % i for i in range(n_states)])
    df['cell'] = ['cell %i' % i for i in df.index]
    df = pd.melt(df, 'cell', ['state %i' % i for i in range(n_states)], 'state', 'rate')
    x_max = 0
    for g, ax, col in zip(df.groupby('state'), axes, colors):
        sns.barplot(data=g[1], x='rate', y='cell',
                    color='black', ax=ax)
        ax.set_title(g[0])
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.set_facecolor(col)
        ax.patch.set_alpha(0.5)
        ax.set_yticklabels([])
        ax.tick_params(left=False)
        xl = ax.get_xlim()
        if xl[1] > x_max:
            x_max = xl[1]

        for spine in ax.spines.values():
            spine.set_visible(False)


    for ax in axes:
        ax.set_xlim([0, x_max])

    axes[0].set_yticklabels(['Cell %i' % i for i in range(n_cells)])
    mid = int(n_states/2)
    axes[mid].set_xlabel('Firing Rate (Hz)')
    return axes
def plot_hmm_transition(transition, ax=None)
Expand source code
def plot_hmm_transition(transition, ax=None):
    if not ax:
        _, ax = plt.gca()

    n_states = transition.shape[0]
    labels = ['State %i' % i for i in range(n_states)]
    sns.heatmap(transition, ax=ax, cmap='plasma', cbar=True, square=True,
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1,
                cbar_kws={'shrink': 0.5})
    ax.set_ylim((0, n_states))
    ax.set_title('Transition Probabilities')
    return ax
def plot_probability_traces(traces, time=None, ax=None, colors=None, thresh=0.75, smoothing=3)
Expand source code
def plot_probability_traces(traces, time=None, ax=None, colors=None, thresh=0.75,
                           smoothing=3):
    y_min=0
    y_max=1
    if ax is None:
        _, ax = plt.gca()

    n_states, n_steps = traces.shape
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0, 1, n_states)]

    windows = get_threshold_windows(traces, thresh=thresh)
    handles = {}
    for win in windows:
        t_vec = [time[win[0]], time[win[1]]]
        h = ax.fill_between(t_vec, [y_min, y_min], [y_max, y_max],
                            color=colors[int(win[2])], alpha=0.4)
        if  win[2] not in handles:
            handles[win[2]] = h

    leg_handles = [handles[k] for k in sorted(handles.keys())]
    leg_labels = ['State %i' % k for k in sorted(handles.keys())]

    for line, col in zip(traces, colors):
        tmp = line
        if smoothing:
            tmp = gaussian_filter1d(tmp, smoothing)

        ax.plot(time, tmp, color=col, linewidth=2)

    return ax, leg_handles, leg_labels
def plot_raster(spikes, time=None, ax=None, y_min=0.05, y_max=0.95)

Plot 2D spike raster

Parameters

spikes : np.array
2D matrix M x N where N is the number of time steps and in each bin is a 0 or 1, with 1 signifying the presence of a spike
Expand source code
def plot_raster(spikes, time=None, ax=None, y_min=0.05, y_max=0.95):
    '''Plot 2D spike raster

    Parameters
    ----------
    spikes : np.array
        2D matrix M x N where N is the number of time steps and in each bin is
        a 0 or 1, with 1 signifying the presence of a spike
    '''
    if not ax:
        _, ax = plt.gca()

    n_rows, n_steps = spikes.shape
    if time is None:
        time = np.arange(0, n_steps)

    y_steps = np.linspace(y_min, y_max, n_rows)
    for i, row in enumerate(spikes):
        idx = np.where(row == 1)[0]
        if len(idx) == 0:
            continue

        ax.scatter(time[idx], row[idx]*y_steps[i], color='black', marker='|')

    return ax
def plot_sequence(seq, time=None, ax=None, y_min=0, y_max=1, colors=None)
Expand source code
def plot_sequence(seq, time=None, ax=None, y_min=0, y_max=1, colors=None):
    if ax is None:
        _, ax = plt.gca()

    if time is None:
        time = np.arange(0, len(seq))

    nStates = np.max(seq)+1
    if colors is None:
        colors = [plt.cm.Set2(x) for x in np.linspace(0, 1, nStates)]

    seq_windows = get_sequence_windows(seq)
    handles = {}
    for win in seq_windows:
        t_vec = [time[win[0]], time[win[1]]]
        h = ax.fill_between(t_vec, [y_min, y_min], [y_max, y_max],
                            color=colors[int(win[2])], alpha=0.4)
        if  win[2] not in handles:
            handles[win[2]] = h

    leg_handles = [handles[k] for k in sorted(handles.keys())]
    leg_labels = ['State %i' % k for k in sorted(handles.keys())]
    return ax, leg_handles, leg_labels
def plot_viterbi_paths(hmm, spikes, time=None, colors=None, axes=None, legend=True, hmm_id=None, save_file=None)
Expand source code
def plot_viterbi_paths(hmm, spikes, time=None, colors=None, axes=None, legend=True,
                       hmm_id=None, save_file=None):
    if not axes:
        fig, axes = make_hmm_raster(spikes, time=time)
    else:
        fig = axes[0].figure

    if legend:
        fig.subplots_adjust(right=0.9)  # To  make room for legend


    BIC = hmm.BIC
    paths = hmm.stat_arrays['best_sequences']
    n_trials, n_steps = paths.shape
    n_states = hmm.n_states
    if time is None:
        time = np.arange(0, n_steps)

    if not colors:
        colors = [plt.cm.Set2(x) for x in np.linspace(0,1, n_states)]

    handles = []
    labels = []
    for trial, ax in zip(paths, axes):
        _, tmp_handles, tmp_labels = plot_sequence(trial, time=time, ax=ax, colors=colors)
        for l, h in zip(tmp_labels, tmp_handles):
            if l not in labels:
                handles.append(h)
                labels.append(l)

        if time[0] != 0:
            ax.axvline(0, color='red', linestyle='--', linewidth=3, alpha=0.8)

    if legend:
        mid = int(n_trials/2)
        axes[mid].legend(handles, labels, loc='upper center',
                         bbox_to_anchor=(0.8, .5, .5, .5), shadow=True,
                         fontsize=14)

    axes[-1].set_xlabel('Time (ms)')
    title_str = 'Decoded HMM Sequences'
    if hmm_id:
        title_str += '\n%s' % hmm_id

    axes[0].set_title(title_str)
    if save_file:
        fig.savefig(save_file)
        plt.close(fig)
        return
    else:
        return fig, axes