Module blechpy.analysis.clustering
Expand source code
import numpy as np
from scipy.signal import butter
from scipy.signal import filtfilt
from scipy.interpolate import interp1d
def get_filtered_electrode(data, freq=(300.0, 3000.0), sampling_rate=30000.0, order=2):
    '''Band-pass filter a signal with a zero-phase Butterworth filter.

    Parameters
    ----------
    data : array-like
        Signal to filter; filtering runs along the last axis (filtfilt default).
    freq : sequence of two floats, optional
        (low, high) cutoff frequencies in Hz. Default (300.0, 3000.0).
    sampling_rate : float, optional
        Sampling rate of ``data`` in Hz. Default 30000.0.
    order : int, optional
        Order passed to ``scipy.signal.butter``. Default 2, the value that
        was previously hard-coded.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``data``.
    '''
    nyquist = sampling_rate / 2.0
    # butter() expects cutoffs normalised to the Nyquist frequency
    b, a = butter(order, [freq[0] / nyquist, freq[1] / nyquist], btype='bandpass')
    # filtfilt applies the filter forward and backward, so there is no phase shift
    return filtfilt(b, a, data)
def dejitter(slices, spike_times, spike_snapshot=(0.5, 1.0), sampling_rate=30000.0):
    '''Upsample spike waveforms 10x and re-align them on their minima.

    Each waveform is linearly interpolated onto a grid with 10 points per
    original sample. A window of ``spike_snapshot[0]`` ms before and
    ``spike_snapshot[1]`` ms after the interpolated minimum is then cut out.

    A spike is discarded when any of these holds:
      * its waveform contains NaN or inf;
      * the interpolated minimum lies more than 0.1 ms from the minimum of
        the original samples;
      * the window would run past either end of the upsampled waveform.

    Parameters
    ----------
    slices : array-like, shape (n_spikes, n_samples)
        One waveform per row. All rows have the same length (n_samples >= 2).
    spike_times : array-like
        Time stamp for each row of ``slices``. Values for kept spikes are
        returned unchanged.
    spike_snapshot : sequence of two floats, optional
        Milliseconds to keep before and after the minimum. Default (0.5, 1.0).
    sampling_rate : float, optional
        Sampling rate of ``slices`` in Hz. Default 30000.0.

    Returns
    -------
    slices_dejittered : numpy.ndarray, shape (n_kept, 10 * (before + after))
        Upsampled, minimum-aligned waveforms. ``before`` and ``after`` are
        the snapshot lengths converted to whole (truncated) original samples.
    spike_times_dejittered : numpy.ndarray, shape (n_kept,)
        The entries of ``spike_times`` for the kept spikes, in input order.
    '''
    samples_per_ms = sampling_rate / 1000.0
    # Snapshot length in original samples, then in upsampled (10x) samples
    before = int(samples_per_ms * spike_snapshot[0])
    after = int(samples_per_ms * spike_snapshot[1])
    pre, post = before * 10, after * 10

    slices = np.asarray(slices)
    spike_times = np.asarray(spike_times)
    if len(slices) == 0:
        # Nothing to align; keep output shapes consistent with the normal case
        return np.empty((0, pre + post)), spike_times[:0]

    n_samples = len(slices[0])
    x = np.arange(0, n_samples, 1)
    # 10 points per original sample, stopping before the last sample so
    # every point stays inside interp1d's interpolation range
    xnew = np.arange(0, n_samples - 1, 0.1)

    kept_windows = []
    kept_idx = []
    for i, waveform in enumerate(slices):
        # A NaN/inf waveform has no meaningful minimum to align on
        if not np.all(np.isfinite(waveform)):
            continue
        upsampled = interp1d(x, waveform)(xnew)
        # argmin returns the first occurrence of the minimum
        orig_min = np.argmin(waveform)
        new_min = np.argmin(upsampled)
        # Reject spikes whose interpolated minimum moved by more than 0.1 ms
        # (3 samples at 30 kHz, i.e. 30 upsampled samples)
        shift_ms = np.abs(xnew[new_min] / samples_per_ms - x[orig_min] / samples_per_ms)
        if shift_ms > 0.1:
            continue
        start, stop = new_min - pre, new_min + post
        # The slice end is exclusive, so stop == len(upsampled) is still a full window
        if start < 0 or stop > len(upsampled):
            continue
        kept_windows.append(upsampled[start:stop])
        kept_idx.append(i)

    if not kept_idx:
        return np.empty((0, pre + post)), spike_times[:0]
    return np.array(kept_windows), spike_times[kept_idx]
def get_waveforms(el_trace, spike_times, snapshot=(0.5, 1.0),
                  sampling_rate=30000.0, bandpass=(300, 3000), *,
                  return_times=False):
    '''Extract band-pass filtered, dejittered spike waveforms from a trace.

    The trace is filtered with ``get_filtered_electrode``. A window is cut
    around each spike time, and the windows are then upsampled 10x and
    aligned on their minima with ``dejitter``.

    Parameters
    ----------
    el_trace : array-like
        Raw voltage trace of one electrode (presumably 1-D; TODO confirm).
    spike_times : array-like of int
        Spike times in samples, i.e. indices into ``el_trace``.
    snapshot : sequence of two floats, optional
        Milliseconds of waveform to keep before and after each spike minimum.
    sampling_rate : float, optional
        Sampling rate of ``el_trace`` in Hz.
    bandpass : sequence of two floats, optional
        (low, high) band-pass cutoffs in Hz.
    return_times : bool, optional, keyword-only
        If True, also return the spike times of the waveforms that were kept.
        The default (False) keeps the two-element return value.

    Returns
    -------
    slices_dj : numpy.ndarray
        Upsampled, aligned waveforms, one row per kept spike.
    upsampled_rate : float
        Sampling rate of ``slices_dj`` (``sampling_rate * 10``).
    times_dj : numpy.ndarray
        Returned only when ``return_times`` is True: the spike times matching
        the rows of ``slices_dj``. Spikes whose window runs off either end of
        the trace, or that ``dejitter`` rejects, are dropped, so this can be
        shorter than ``spike_times``.
    '''
    filt_el = get_filtered_electrode(el_trace, freq=bandpass,
                                     sampling_rate=sampling_rate)
    # Release this function's reference to the raw trace; only the filtered copy is used below
    del el_trace

    samples_per_ms = sampling_rate / 1000
    # Extra padding (0.1 ms before, 0.2 ms after) gives dejitter room to re-centre each spike
    pre_pts = int((snapshot[0] + 0.1) * samples_per_ms)
    post_pts = int((snapshot[1] + 0.2) * samples_per_ms)

    spike_times = np.asarray(spike_times)
    # Skip spikes whose window would run off either end of the trace, rather
    # than failing on a wrapped-around (negative start) or short slice
    fits = (spike_times - pre_pts >= 0) & (spike_times + post_pts <= len(filt_el))
    spike_times = spike_times[fits]

    slices = np.zeros((len(spike_times), pre_pts + post_pts))
    for i, st in enumerate(spike_times):
        slices[i, :] = filt_el[st - pre_pts:st + post_pts]

    slices_dj, times_dj = dejitter(slices, spike_times, snapshot, sampling_rate)
    upsampled_rate = sampling_rate * 10
    if return_times:
        return slices_dj, upsampled_rate, times_dj
    return slices_dj, upsampled_rate
Functions
def dejitter(slices, spike_times, spike_snapshot=[0.5, 1.0], sampling_rate=30000.0)
-
Upsamples (by 10) and aligns spike waveforms to their minima. Returns the upsampled waveforms and the spike_times of the spikes that were kept.
Expand source code
def dejitter(slices, spike_times, spike_snapshot=(0.5, 1.0), sampling_rate=30000.0):
    '''Upsample spike waveforms 10x and re-align them on their minima.

    Each waveform is linearly interpolated onto a grid with 10 points per
    original sample. A window of ``spike_snapshot[0]`` ms before and
    ``spike_snapshot[1]`` ms after the interpolated minimum is then cut out.

    A spike is discarded when any of these holds:
      * its waveform contains NaN or inf;
      * the interpolated minimum lies more than 0.1 ms from the minimum of
        the original samples;
      * the window would run past either end of the upsampled waveform.

    Parameters
    ----------
    slices : array-like, shape (n_spikes, n_samples)
        One waveform per row. All rows have the same length (n_samples >= 2).
    spike_times : array-like
        Time stamp for each row of ``slices``. Values for kept spikes are
        returned unchanged.
    spike_snapshot : sequence of two floats, optional
        Milliseconds to keep before and after the minimum. Default (0.5, 1.0).
    sampling_rate : float, optional
        Sampling rate of ``slices`` in Hz. Default 30000.0.

    Returns
    -------
    slices_dejittered : numpy.ndarray, shape (n_kept, 10 * (before + after))
        Upsampled, minimum-aligned waveforms. ``before`` and ``after`` are
        the snapshot lengths converted to whole (truncated) original samples.
    spike_times_dejittered : numpy.ndarray, shape (n_kept,)
        The entries of ``spike_times`` for the kept spikes, in input order.
    '''
    samples_per_ms = sampling_rate / 1000.0
    # Snapshot length in original samples, then in upsampled (10x) samples
    before = int(samples_per_ms * spike_snapshot[0])
    after = int(samples_per_ms * spike_snapshot[1])
    pre, post = before * 10, after * 10

    slices = np.asarray(slices)
    spike_times = np.asarray(spike_times)
    if len(slices) == 0:
        # Nothing to align; keep output shapes consistent with the normal case
        return np.empty((0, pre + post)), spike_times[:0]

    n_samples = len(slices[0])
    x = np.arange(0, n_samples, 1)
    # 10 points per original sample, stopping before the last sample so
    # every point stays inside interp1d's interpolation range
    xnew = np.arange(0, n_samples - 1, 0.1)

    kept_windows = []
    kept_idx = []
    for i, waveform in enumerate(slices):
        # A NaN/inf waveform has no meaningful minimum to align on
        if not np.all(np.isfinite(waveform)):
            continue
        upsampled = interp1d(x, waveform)(xnew)
        # argmin returns the first occurrence of the minimum
        orig_min = np.argmin(waveform)
        new_min = np.argmin(upsampled)
        # Reject spikes whose interpolated minimum moved by more than 0.1 ms
        # (3 samples at 30 kHz, i.e. 30 upsampled samples)
        shift_ms = np.abs(xnew[new_min] / samples_per_ms - x[orig_min] / samples_per_ms)
        if shift_ms > 0.1:
            continue
        start, stop = new_min - pre, new_min + post
        # The slice end is exclusive, so stop == len(upsampled) is still a full window
        if start < 0 or stop > len(upsampled):
            continue
        kept_windows.append(upsampled[start:stop])
        kept_idx.append(i)

    if not kept_idx:
        return np.empty((0, pre + post)), spike_times[:0]
    return np.array(kept_windows), spike_times[kept_idx]
def get_filtered_electrode(data, freq=[300.0, 3000.0], sampling_rate=30000.0)
-
Expand source code
def get_filtered_electrode(data, freq=(300.0, 3000.0), sampling_rate=30000.0, order=2):
    '''Band-pass filter a signal with a zero-phase Butterworth filter.

    Parameters
    ----------
    data : array-like
        Signal to filter; filtering runs along the last axis (filtfilt default).
    freq : sequence of two floats, optional
        (low, high) cutoff frequencies in Hz. Default (300.0, 3000.0).
    sampling_rate : float, optional
        Sampling rate of ``data`` in Hz. Default 30000.0.
    order : int, optional
        Order passed to ``scipy.signal.butter``. Default 2, the value that
        was previously hard-coded.

    Returns
    -------
    numpy.ndarray
        The filtered signal, same shape as ``data``.
    '''
    nyquist = sampling_rate / 2.0
    # butter() expects cutoffs normalised to the Nyquist frequency
    b, a = butter(order, [freq[0] / nyquist, freq[1] / nyquist], btype='bandpass')
    # filtfilt applies the filter forward and backward, so there is no phase shift
    return filtfilt(b, a, data)
def get_waveforms(el_trace, spike_times, snapshot=[0.5, 1.0], sampling_rate=30000.0, bandpass=[300, 3000])
-
Returns waveform slices based on the given spike_times (in samples)
Expand source code
def get_waveforms(el_trace, spike_times, snapshot=(0.5, 1.0),
                  sampling_rate=30000.0, bandpass=(300, 3000), *,
                  return_times=False):
    '''Extract band-pass filtered, dejittered spike waveforms from a trace.

    The trace is filtered with ``get_filtered_electrode``. A window is cut
    around each spike time, and the windows are then upsampled 10x and
    aligned on their minima with ``dejitter``.

    Parameters
    ----------
    el_trace : array-like
        Raw voltage trace of one electrode (presumably 1-D; TODO confirm).
    spike_times : array-like of int
        Spike times in samples, i.e. indices into ``el_trace``.
    snapshot : sequence of two floats, optional
        Milliseconds of waveform to keep before and after each spike minimum.
    sampling_rate : float, optional
        Sampling rate of ``el_trace`` in Hz.
    bandpass : sequence of two floats, optional
        (low, high) band-pass cutoffs in Hz.
    return_times : bool, optional, keyword-only
        If True, also return the spike times of the waveforms that were kept.
        The default (False) keeps the two-element return value.

    Returns
    -------
    slices_dj : numpy.ndarray
        Upsampled, aligned waveforms, one row per kept spike.
    upsampled_rate : float
        Sampling rate of ``slices_dj`` (``sampling_rate * 10``).
    times_dj : numpy.ndarray
        Returned only when ``return_times`` is True: the spike times matching
        the rows of ``slices_dj``. Spikes whose window runs off either end of
        the trace, or that ``dejitter`` rejects, are dropped, so this can be
        shorter than ``spike_times``.
    '''
    filt_el = get_filtered_electrode(el_trace, freq=bandpass,
                                     sampling_rate=sampling_rate)
    # Release this function's reference to the raw trace; only the filtered copy is used below
    del el_trace

    samples_per_ms = sampling_rate / 1000
    # Extra padding (0.1 ms before, 0.2 ms after) gives dejitter room to re-centre each spike
    pre_pts = int((snapshot[0] + 0.1) * samples_per_ms)
    post_pts = int((snapshot[1] + 0.2) * samples_per_ms)

    spike_times = np.asarray(spike_times)
    # Skip spikes whose window would run off either end of the trace, rather
    # than failing on a wrapped-around (negative start) or short slice
    fits = (spike_times - pre_pts >= 0) & (spike_times + post_pts <= len(filt_el))
    spike_times = spike_times[fits]

    slices = np.zeros((len(spike_times), pre_pts + post_pts))
    for i, st in enumerate(spike_times):
        slices[i, :] = filt_el[st - pre_pts:st + post_pts]

    slices_dj, times_dj = dejitter(slices, spike_times, snapshot, sampling_rate)
    upsampled_rate = sampling_rate * 10
    if return_times:
        return slices_dj, upsampled_rate, times_dj
    return slices_dj, upsampled_rate