Module blechpy.analysis.held_unit_analysis
Source code
import pandas as pd
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from blechpy.analysis import spike_analysis as sas
from blechpy.datastructures.objects import load_dataset
from blechpy.dio import h5io
from blechpy.utils import print_tools as pt, userIO
import os
def calc_J1(wf_day1, wf_day2):
    # Get the mean PCA waveforms on days 1 and 2
    day1_mean = np.mean(wf_day1, axis=0)
    day2_mean = np.mean(wf_day2, axis=0)

    # Get the Euclidean distances of each day's waveforms from its daily mean
    day1_dists = cdist(wf_day1, day1_mean.reshape((-1, 3)), metric='euclidean')
    day2_dists = cdist(wf_day2, day2_mean.reshape((-1, 3)), metric='euclidean')

    # Sum up the distances to get J1
    J1 = np.sum(day1_dists) + np.sum(day2_dists)
    return J1
def calc_J2(wf_day1, wf_day2):
    # Get the mean PCA waveforms on days 1 and 2
    day1_mean = np.mean(wf_day1, axis=0)
    day2_mean = np.mean(wf_day2, axis=0)

    # Get the overall inter-day mean
    overall_mean = np.mean(np.concatenate((wf_day1, wf_day2), axis=0), axis=0)

    # Get the distances of the daily means from the inter-day mean
    dist1 = cdist(day1_mean.reshape((-1, 3)), overall_mean.reshape((-1, 3)))
    dist2 = cdist(day2_mean.reshape((-1, 3)), overall_mean.reshape((-1, 3)))

    # Weight each distance by the number of points on that day and sum to
    # get J2
    J2 = wf_day1.shape[0]*np.sum(dist1) + wf_day2.shape[0]*np.sum(dist2)
    return J2
def calc_J3(wf_day1, wf_day2):
    '''Calculate J3 value between 2 sets of PCA waveforms

    Parameters
    ----------
    wf_day1 : numpy.array
        PCA waveforms for a single unit from session 1
    wf_day2 : numpy.array
        PCA waveforms for a single unit from session 2

    Returns
    -------
    J3 : float
    '''
    J1 = calc_J1(wf_day1, wf_day2)
    J2 = calc_J2(wf_day1, wf_day2)
    J3 = J2 / J1
    return J3
def get_intra_J3(rec_dirs, raw_waves=False):
    print('\n----------\nComputing Intra J3s\n----------\n')
    # Go through each recording directory and compute the intra_J3 array
    intra_J3 = []
    for rd in rec_dirs:
        print('Processing single units in %s...' % rd)
        unit_names = h5io.get_unit_names(rd)
        for un in unit_names:
            print('    Computing for %s...' % un)
            if raw_waves:
                waves, descrip, fs = h5io.get_raw_unit_waveforms(rd, un)
            else:
                waves, descrip, fs = h5io.get_unit_waveforms(rd, un)

            if descrip['single_unit'] == 1:
                pca = PCA(n_components=3)
                pca.fit(waves)
                pca_waves = pca.transform(waves)
                # Compare the first third of the unit's waveforms against
                # the last third
                idx1 = int(waves.shape[0] * (1.0 / 3.0))
                idx2 = int(waves.shape[0] * (2.0 / 3.0))
                tmp_J3 = calc_J3(pca_waves[:idx1, :], pca_waves[idx2:, :])
                intra_J3.append(tmp_J3)

    print('Done!\n==========')
    return intra_J3
def find_held_units(rec_dirs, percent_criterion=95, rec_names=None, raw_waves=False):
    # TODO: if any rec is 'one file per signal type' create tmp_raw.hdf5 and
    # delete after detection is finished
    userIO.tell_user('Computing intra recording J3 values...', shell=True)
    intra_J3 = get_intra_J3(rec_dirs, raw_waves=raw_waves)

    if rec_names is None:
        rec_names = [os.path.basename(x) for x in rec_dirs]

    rec_labels = {x: y for x, y in zip(rec_names, rec_dirs)}

    print('\n----------\nComputing Inter J3s\n----------\n')
    rec_pairs = [(rec_names[i], rec_names[i+1])
                 for i in range(len(rec_names)-1)]

    held_df = pd.DataFrame(columns=['unit', 'electrode', 'single_unit',
                                    'unit_type', *rec_names, 'J3'])

    # Go through each pair of directories and compute inter_J3 between
    # units. If the inter_J3 value is below the percent_criterion percentile
    # of the intra_J3 array then mark the units as held. Only compare the
    # same type of single units on the same electrode
    inter_J3 = []
    for rec1, rec2 in rec_pairs:
        rd1 = rec_labels.get(rec1)
        rd2 = rec_labels.get(rec2)
        h5_file1 = h5io.get_h5_filename(rd1)
        h5_file2 = h5io.get_h5_filename(rd2)
        print('Comparing %s vs %s' % (rec1, rec2))
        found_cells = []
        unit_names1 = h5io.get_unit_names(rd1)
        unit_names2 = h5io.get_unit_names(rd2)

        for unit1 in unit_names1:
            if raw_waves:
                wf1, descrip1, fs1 = h5io.get_raw_unit_waveforms(rd1, unit1)
            else:
                wf1, descrip1, fs1 = h5io.get_unit_waveforms(rd1, unit1)

            electrode = descrip1['electrode_number']
            single_unit = bool(descrip1['single_unit'])
            unit_type = h5io.read_unit_description(descrip1)

            if descrip1['single_unit'] == 1:
                for unit2 in unit_names2:
                    if raw_waves:
                        wf2, descrip2, fs2 = \
                            h5io.get_raw_unit_waveforms(rd2, unit2,
                                                        required_descrip=descrip1)
                    else:
                        wf2, descrip2, fs2 = \
                            h5io.get_unit_waveforms(rd2, unit2,
                                                    required_descrip=descrip1)

                    if descrip1 == descrip2 and wf2 is not None:
                        print('Comparing %s %s vs %s %s' %
                              (rec1, unit1, rec2, unit2))
                        userIO.tell_user('Comparing %s %s vs %s %s' %
                                         (rec1, unit1, rec2, unit2), shell=True)

                        # Resample the higher-rate waveforms so both sets
                        # share a sampling rate
                        if fs1 > fs2:
                            wf1 = sas.interpolate_waves(wf1, fs1, fs2)
                        elif fs1 < fs2:
                            wf2 = sas.interpolate_waves(wf2, fs2, fs1)

                        # Fit PCA on the pooled waveforms so both sessions
                        # are projected into the same 3D space
                        pca = PCA(n_components=3)
                        pca.fit(np.concatenate((wf1, wf2), axis=0))
                        pca_wf1 = pca.transform(wf1)
                        pca_wf2 = pca.transform(wf2)

                        J3 = calc_J3(pca_wf1, pca_wf2)
                        inter_J3.append(J3)
                        if J3 <= np.percentile(intra_J3, percent_criterion):
                            print('Detected held unit:\n    %s %s and %s %s'
                                  % (rec1, unit1, rec2, unit2))
                            userIO.tell_user('Detected held unit:\n    %s %s and %s %s'
                                             % (rec1, unit1, rec2, unit2), shell=True)
                            found_cells.append((h5io.parse_unit_number(unit1),
                                                h5io.parse_unit_number(unit2),
                                                J3, single_unit, unit_type))

        found_cells = np.array(found_cells)
        userIO.tell_user('\n-----\n%s vs %s\n-----' % (rec1, rec2), shell=True)
        userIO.tell_user(str(found_cells)+'\n', shell=True)
        userIO.tell_user('Resolving duplicates...', shell=True)
        found_cells = resolve_duplicate_matches(found_cells)
        userIO.tell_user('Results:\n%s\n' % str(found_cells), shell=True)

        for row in found_cells:
            # Assign the next letter label to a newly held unit
            if held_df.empty:
                uL = 'A'
            else:
                uL = pt.get_next_letter(held_df['unit'].iloc[-1])

            unit1 = 'unit%03d' % int(row[0])
            unit2 = 'unit%03d' % int(row[1])
            j3 = row[2]
            idx1 = np.where(held_df[rec1] == unit1)[0]
            idx2 = np.where(held_df[rec2] == unit2)[0]
            # found_cells went through np.array, so every field is a string
            single_unit = (row[3] == 'True')

            if idx1.size == 0 and idx2.size == 0:
                # Unit pair not seen before: add a new row
                tmp = {'unit': uL,
                       'single_unit': single_unit,
                       'unit_type': row[4],
                       rec1: unit1,
                       rec2: unit2,
                       'J3': [float(j3)]}
                # DataFrame.append was removed in pandas 2.0; concat is the
                # equivalent replacement
                held_df = pd.concat([held_df, pd.DataFrame([tmp])],
                                    ignore_index=True)
            elif idx1.size != 0 and idx2.size != 0:
                userIO.tell_user('WTF...', shell=True)
                continue
            elif idx1.size != 0:
                # Extend an existing held unit into rec2
                held_df.loc[held_df.index[idx1[0]], rec2] = unit2
                held_df['J3'].iloc[idx1[0]].append(float(j3))
            else:
                # Extend an existing held unit into rec1
                held_df.loc[held_df.index[idx2[0]], rec1] = unit1
                held_df['J3'].iloc[idx2[0]].append(float(j3))

    return held_df, intra_J3, inter_J3
def resolve_duplicate_matches(found_cells):
    if len(found_cells) == 0:
        return found_cells

    # found_cells came through np.array, so the J3 column holds strings;
    # cast to float so argmin compares numerically, not lexicographically.
    # First resolve duplicates in the first column: keep the lowest-J3 match
    unique_units = np.unique(found_cells[:, 0])
    new_found = []
    for unit in unique_units:
        idx = np.where(found_cells[:, 0] == unit)[0]
        if len(idx) == 1:
            new_found.append(found_cells[idx, :])
            continue

        min_j3 = np.argmin(found_cells[idx, 2].astype(float))
        new_found.append(found_cells[idx[min_j3], :])

    found = np.vstack(new_found)

    # Now resolve duplicates in the second column, setting aside the losing
    # matches for another pass
    go_back = []
    new_found = []
    for unit in np.unique(found[:, 1]):
        idx = np.where(found[:, 1] == unit)[0]
        if len(idx) == 1:
            new_found.append(found[idx, :])
            continue

        min_j3 = np.argmin(found[idx, 2].astype(float))
        i = idx[min_j3]
        idx = np.delete(idx, min_j3)
        new_found.append(found[i, :])
        go_back.append(found[idx, :])

    # For each unit that lost its best match, fall back to its next-best
    # pairing among the original candidates
    for row in go_back:
        idx = np.where((found_cells[:, 0] == row[0][0]) &
                       (found_cells[:, 1] != row[0][1]))[0]
        if len(idx) == 0:
            continue
        elif len(idx) == 1:
            new_found.append(found_cells[idx, :])
            continue

        min_j3 = np.argmin(found_cells[idx, 2].astype(float))
        new_found.append(found_cells[idx[min_j3], :])

    out = np.vstack(new_found)

    # Check that both columns are now free of duplicates
    uni = True
    for unit in np.unique(out[:, 0]):
        idx = np.where(out[:, 0] == unit)[0]
        if len(idx) > 1:
            uni = False
            break

    for unit in np.unique(out[:, 1]):
        idx = np.where(out[:, 1] == unit)[0]
        if len(idx) > 1:
            uni = False
            break

    # Sort by unit number in the first recording
    a = [int(x) for x in out[:, 0]]
    idx = np.argsort(a)
    out = out[idx, :]

    if uni:
        return out
    else:
        print('Duplicates still found. Re-running')
        print(out)
        return resolve_duplicate_matches(out)
### Delete after here
def get_response_change(unit_name, rec1, unit1,
                        din1, rec2, unit2, din2,
                        bin_size=250, bin_step=25, norm_func=None):
    '''Uses the spike arrays to compute the change in firing rate of the
    response to the tastant.

    Parameters
    ----------
    unit_name : str, name of held unit
    rec1 : str, path to recording directory 1
    unit1 : str, name of unit in rec1
    din1 : int, number of din to use from rec1
    rec2 : str, path to recording directory 2
    unit2 : str, name of unit in rec2
    din2 : int, number of din to use from rec2
    bin_size : int, default=250
        width of bins in units of the time vector saved in the hf5
        spike_trains, usually ms
    bin_step : int, default=25
        step size from one bin to the next in the same units (usually ms)
    norm_func : function (optional)
        function with which to normalize the firing rates before taking the
        difference; must take inputs (time_vector, firing_rate_array) where
        time_vector is a 1D numpy.array and firing_rate_array is a
        Trial x Time numpy.array, and must return a numpy.array with the
        same shape as the firing rate array

    Returns
    -------
    difference_of_mean : numpy.array
    SEM : numpy.array, standard error of the mean difference
    bin_time : numpy.array, time vector for the binned firing rates
    '''
    # Get metadata
    dat1 = load_dataset(rec1)
    dat2 = load_dataset(rec2)

    # Get data from hf5 files
    time1, spikes1 = h5io.get_spike_data(rec1, unit1, din1)
    time2, spikes2 = h5io.get_spike_data(rec2, unit2, din2)

    # Get firing rates
    bin_time1, fr1 = sas.get_binned_firing_rate(time1, spikes1, bin_size, bin_step)
    bin_time2, fr2 = sas.get_binned_firing_rate(time2, spikes2, bin_size, bin_step)

    if not np.array_equal(bin_time1, bin_time2):
        raise ValueError('Time of spike trains is not aligned')

    # Normalize firing rates
    if norm_func:
        fr1 = norm_func(bin_time1, fr1)
        fr2 = norm_func(bin_time2, fr2)

    difference_of_mean, SEM = sas.get_mean_difference(fr1, fr2, axis=0)
    return difference_of_mean, SEM, bin_time1
Functions
def calc_J1(wf_day1, wf_day2)
Within-day scatter term of the J3 metric: the summed Euclidean distances of each day's PCA waveforms from that day's mean waveform.
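In symbols (my notation, not the module's): writing D_d for the set of day-d PCA waveforms and mu_d for their mean, the function computes

    J_1 = \sum_{x \in D_1} \lVert x - \mu_1 \rVert_2 + \sum_{x \in D_2} \lVert x - \mu_2 \rVert_2

Note that plain Euclidean distances are summed rather than squared ones, matching the cdist call in the source above.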
def calc_J2(wf_day1, wf_day2)
Between-day separation term of the J3 metric: the distances of each day's mean PCA waveform from the pooled inter-day mean, weighted by the number of waveforms on each day.
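In the same notation, with mu the mean of the pooled day-1 and day-2 waveforms and n_d the number of waveforms on day d:

    J_2 = n_1 \lVert \mu_1 - \mu \rVert_2 + n_2 \lVert \mu_2 - \mu \rVert_2

The ratio J3 = J2 / J1 therefore grows when the two daily clusters sit far apart relative to their internal scatter, and shrinks toward zero when they overlap.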
def calc_J3(wf_day1, wf_day2)
Calculate the J3 value between 2 sets of PCA waveforms.

Parameters
----------
wf_day1 : numpy.array
    PCA waveforms for a single unit from session 1
wf_day2 : numpy.array
    PCA waveforms for a single unit from session 2

Returns
-------
J3 : float
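A minimal usage sketch on synthetic data (the array layout is my reading of the code: rows are waveforms, columns are the 3 PCA components):

import numpy as np
from blechpy.analysis.held_unit_analysis import calc_J3

rng = np.random.default_rng(0)
# Two synthetic clusters in 3-component PCA space
same_unit = rng.normal(0.0, 1.0, size=(500, 3))
shifted_unit = rng.normal(2.0, 1.0, size=(400, 3))

# J3 = J2 / J1: low when the clusters overlap, higher when they separate
print(calc_J3(same_unit[:250], same_unit[250:]))  # two halves of one cluster
print(calc_J3(same_unit, shifted_unit))           # displaced cluster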
def find_held_units(rec_dirs, percent_criterion=95, rec_names=None, raw_waves=False)
Find units held across consecutive recordings. A unit pair is marked as held when its inter-recording J3 falls at or below the percent_criterion percentile of the intra-recording J3 distribution; only single units of the same type on the same electrode are compared. Returns the held-unit DataFrame along with the intra- and inter-recording J3 lists.
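A usage sketch with hypothetical directory paths (each directory must contain a blechpy-processed HDF5 file):

rec_dirs = ['/data/RN5/RN5_day1', '/data/RN5/RN5_day2', '/data/RN5/RN5_day3']
held_df, intra_J3, inter_J3 = find_held_units(rec_dirs, percent_criterion=95)

# held_df has one row per held unit: a letter label, the unit name it
# carries in each recording, and the J3 value(s) from the consecutive-day
# comparisons
print(held_df)

Lowering percent_criterion makes detection stricter, since an inter-recording J3 must then fall below a lower percentile of the intra-recording distribution.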
def get_intra_J3(rec_dirs, raw_waves=False)
Compute the intra-recording J3 for every single unit in each recording by comparing the first third of its PCA waveforms against the last third.
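These intra-recording values serve as the null distribution for held-unit detection. A sketch of the thresholding that find_held_units applies (rec_dirs as above, hypothetical paths):

import numpy as np

intra_J3 = get_intra_J3(rec_dirs)
# A unit pair counts as held if its inter-recording J3 is <= this value
threshold = np.percentile(intra_J3, 95)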
def get_response_change(unit_name, rec1, unit1, din1, rec2, unit2, din2, bin_size=250, bin_step=25, norm_func=None)
Uses the spike arrays to compute the change in firing rate of the response to the tastant.

Parameters
----------
unit_name : str
    name of held unit
rec1 : str
    path to recording directory 1
unit1 : str
    name of unit in rec1
din1 : int
    number of din to use from rec1
rec2 : str
    path to recording directory 2
unit2 : str
    name of unit in rec2
din2 : int
    number of din to use from rec2
bin_size : int, default=250
    width of bins in units of the time vector saved in the hf5 spike_trains, usually ms
bin_step : int, default=25
    step size from one bin to the next in the same units (usually ms)
norm_func : function, optional
    function with which to normalize the firing rates before taking the difference; must take inputs (time_vector, firing_rate_array) where time_vector is a 1D numpy.array and firing_rate_array is a Trial x Time numpy.array, and must return a numpy.array with the same shape as the firing rate array

Returns
-------
difference_of_mean : numpy.array
SEM : numpy.array
    standard error of the mean difference
bin_time : numpy.array
    time vector for the binned firing rates
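A sketch of a norm_func, written only to match the documented (time_vector, firing_rate_array) signature; the baseline convention (time < 0 as pre-stimulus) and all paths and unit names here are assumptions for illustration:

def baseline_subtract(time, fr):
    # Subtract each trial's mean pre-stimulus rate (assumes time < 0 is baseline)
    baseline = fr[:, time < 0].mean(axis=1, keepdims=True)
    return fr - baseline

diff, sem, bin_time = get_response_change('A', '/data/RN5/RN5_day1', 'unit000', 0,
                                          '/data/RN5/RN5_day2', 'unit001', 0,
                                          bin_size=250, bin_step=25,
                                          norm_func=baseline_subtract)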
def resolve_duplicate_matches(found_cells)
Resolve duplicate matches so that each unit appears in at most one held pairing, keeping the pairing with the lowest J3 at each conflict and recursing until both columns are duplicate-free.
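A small worked example with hand-made candidates in the (unit1, unit2, J3, single_unit, unit_type) string form that find_held_units produces (values invented for illustration):

import numpy as np

candidates = np.array([('0', '1', '0.8', 'True', 'regular_spiking'),
                       ('0', '2', '0.3', 'True', 'regular_spiking'),
                       ('3', '2', '0.9', 'True', 'regular_spiking')])

# Unit 0 keeps its lower-J3 pairing with unit 2; unit 3 then loses the
# contested unit 2 and has no alternative candidate, so it drops out.
print(resolve_duplicate_matches(candidates))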