Module blechpy.dio.hmmIO

Expand source code
import glob
import os
from copy import deepcopy

import numpy as np
import pandas as pd
import tables

from blechpy.utils.particles import HMMInfoParticle

def fix_hmm_overview(h5_file):
    '''Back-fill missing log-likelihoods in the data_overview table.

    Every ``/data_overview`` row whose ``log_likelihood`` is exactly 0 is
    set to the last value of that HMM's stored ``/hmm_<id>/fit_LL`` array.
    (Historically this helper held one-off schema migrations, e.g. adding
    the ``area`` and ``hmm_class`` columns.)

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store. If the file does not exist, nothing is
        done.

    Returns
    -------
    None
    '''
    if not os.path.isfile(h5_file):
        return

    print('Fixing data overview table in %s' % h5_file)
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        for row in table.where('log_likelihood == 0.'):
            hmm_id = row['hmm_id']
            h_str = 'hmm_%s' % hmm_id
            # Skip rows whose HMM group or fit_LL history is missing instead
            # of aborting the whole repair on the first such row.
            if h_str not in hf5.root or 'fit_LL' not in hf5.root[h_str]:
                print('No fit_LL data found for %s, skipping.' % h_str)
                continue

            row['log_likelihood'] = hf5.root[h_str]['fit_LL'][-1]
            row.update()

        table.flush()
        hf5.flush()


def setup_hmm_hdf5(h5_file, infoParticle=HMMInfoParticle):
    '''Create a new HMM hdf5 store holding an empty data_overview table.

    Nothing happens if ``h5_file`` already exists.

    Parameters
    ----------
    h5_file : str
        Path of the hdf5 file to create.
    infoParticle : optional
        Table description passed to ``create_table`` for ``/data_overview``.
    '''
    already_exists = os.path.isfile(h5_file)
    if already_exists:
        return

    print('Initializing hdf5 store: %s' % h5_file)
    with tables.open_file(h5_file, 'a') as store:
        has_overview = 'data_overview' in store.root
        if not has_overview:
            print('Initializing data_overview table in hdf5 store...')
            overview = store.create_table(
                '/', 'data_overview', infoParticle,
                'Parameters and goodness-of-fit info for HMMs in file')
            overview.flush()


def read_hmm_from_hdf5(h5_file, hmm_id):
    '''Load a stored HMM and its overview parameters from an hdf5 store.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store.
    hmm_id : int
        ID of the HMM; its arrays live in the group ``/hmm_<hmm_id>``.

    Returns
    -------
    tuple or None
        ``(PI, A, B, stat_arrays, params)``:
        - PI, A and B are the stored ``initial_distribution``, ``transition``
          and ``emission`` arrays.
        - ``stat_arrays`` is a dict of every other array in the group.
        - ``params`` is a dict built from the matching ``/data_overview`` row.
        Returns None if the HMM group is missing or empty.

    Raises
    ------
    ValueError
        If the group exists but no ``/data_overview`` row has this hmm_id.
    '''
    with tables.open_file(h5_file, 'r') as hf5:
        h_str = 'hmm_%s' % hmm_id
        if h_str not in hf5.root or len(hf5.list_nodes('/'+h_str)) == 0:
            return None

        # Read every array in the HMM's group into memory.
        stat_arrays = {node._v_name: node[:]
                       for node in hf5.list_nodes('/'+h_str)}

        PI = stat_arrays.pop('initial_distribution')
        A = stat_arrays.pop('transition')
        B = stat_arrays.pop('emission')

        # row_id is read back as bytes; decode to str, keeping its shape.
        # Guarded because an HMM's stat_arrays need not include it.
        if 'row_id' in stat_arrays:
            raw = stat_arrays['row_id']
            decoded = np.array([x.decode('utf-8') for x in raw.ravel()])
            stat_arrays['row_id'] = decoded.reshape(raw.shape)

        table = hf5.root.data_overview
        for row in table.where('hmm_id == id', condvars={'id': hmm_id}):
            params = {}
            for k in table.colnames:
                if table.coltypes[k] == 'string':
                    val = row[k].decode('utf-8')
                    # List-valued params are stored '..'-joined.
                    params[k] = val.split('..') if '..' in val else val
                else:
                    params[k] = row[k]

            # NOTE(review): channel is decoded only when taste is a list,
            # while write_hmm_to_hdf5 hashes channel whenever it is a list;
            # confirm these two conditions always coincide.
            if isinstance(params['taste'], list):
                params['channel'] = list_channel_hash(params['channel'])

            return PI, A, B, stat_arrays, params

    raise ValueError('Parameters not found for hmm %i' % hmm_id)


def _fill_fit_stats(row, hmm):
    '''Copy the goodness-of-fit attributes of ``hmm`` into a table row.'''
    row['BIC'] = hmm.BIC
    row['cost'] = hmm.cost
    row['converged'] = hmm.converged
    row['fitted'] = hmm.fitted
    row['max_log_prob'] = hmm.max_log_prob
    row['log_likelihood'] = hmm.fit_LL
    row['n_iterations'] = hmm.iteration


def _next_hmm_id(ids):
    '''Return the ID for a new HMM given the IDs already in the table.

    Uses the first gap after an existing ID, or max + 1 if there is no gap.
    The IDs are de-duplicated and sorted first. Otherwise an unsorted column
    (e.g. [0, 5, 1]) would produce an ID that is already taken.
    '''
    ids = np.unique(ids)
    if len(ids) == 0:
        return 0

    gaps = np.where(np.diff(ids) > 1)[0]
    if len(gaps) == 0:
        return ids[-1] + 1

    return ids[gaps[0]] + 1


def write_hmm_to_hdf5(h5_file, hmm, params):
    '''Write (or overwrite) an HMM and its parameters to an hdf5 store.

    The HMM's arrays are written to the group ``/hmm_<id>``, replacing any
    existing group. Its parameters and fit statistics go to
    ``/data_overview``:
    - If a row with this ID already exists, only its fit statistics are
      updated.
    - Otherwise a new row is created from ``params``.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store; created with setup_hmm_hdf5 if missing.
    hmm : object
        HMM exposing:
        - hmm_id (set by this function if it was None)
        - initial_distribution, transition, emission and stat_arrays
        - fit statistics BIC, cost, converged, fitted, max_log_prob,
          fit_LL and iteration
    params : dict
        Overview parameters for the HMM. The caller's dict is not modified
        (a deep copy is used). When a new row is created:
        - list values are stored '..'-joined for string columns;
        - a ``channel`` list is stored via hash_channel_list.

    Raises
    ------
    ValueError
        If ``hmm.hmm_id`` and ``params['hmm_id']`` disagree, or a list
        value cannot be stored in its column.
    '''
    params = deepcopy(params)
    hmm_id = hmm.hmm_id
    if 'hmm_id' in params:
        if hmm_id is None:
            hmm.hmm_id = hmm_id = params['hmm_id']
        elif hmm_id != params['hmm_id']:
            raise ValueError('ID of HMM %i does not match ID in params %i'
                             % (hmm_id, params['hmm_id']))

    if not os.path.isfile(h5_file):
        setup_hmm_hdf5(h5_file)

    print('\n' + '='*80)
    print('Writing HMM %s to hdf5 file...' % hmm_id)
    print(params)
    print('PI: %s' % repr(hmm.initial_distribution.shape))
    print('A: %s' % repr(hmm.transition.shape))
    print('B: %s' % repr(hmm.emission.shape))
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        if hmm_id is None:
            hmm_id = _next_hmm_id(table.col('hmm_id'))
            hmm.hmm_id = hmm_id
            print('HMM assigned id #%i' % hmm_id)

        # Always record the ID so a newly created overview row carries it,
        # even when the caller's params omitted 'hmm_id'.
        params['hmm_id'] = hmm_id

        h_str = 'hmm_%s' % hmm_id
        if h_str in hf5.root:
            print('Deleting existing data for %s...' % h_str)
            hf5.remove_node('/', h_str, recursive=True)

        print('Writing new data for %s' % h_str)
        group = '/' + h_str
        hf5.create_group('/', h_str, 'Data for HMM #%i' % hmm_id)
        hf5.create_array(group, 'initial_distribution',
                         hmm.initial_distribution)
        hf5.create_array(group, 'transition', hmm.transition)
        hf5.create_array(group, 'emission', hmm.emission)
        for k, v in hmm.stat_arrays.items():
            hf5.create_array(group, k,
                             v if isinstance(v, np.ndarray) else np.array(v))

        edited_rows = 0
        for row in table.where('hmm_id == id', condvars={'id': hmm_id}):
            print('Editing existing row in data_overview with new values for HMM %s' % hmm_id)
            print('New iteration: %i, fit_LL: %.4E, BIC: %.3E' % (hmm.iteration, hmm.fit_LL, hmm.BIC))
            print('Old iteration: %i, fit_LL: %.4E, BIC: %.3E' % (row['n_iterations'], row['log_likelihood'], row['BIC']))
            _fill_fit_stats(row, hmm)
            row.update()
            edited_rows += 1

        if edited_rows == 0:
            print('Creating new row in data_overview for HMM %s' % hmm_id)
            row = table.row
            for k, v in params.items():
                if table.coltypes[k] == 'string' and isinstance(v, list):
                    row[k] = '..'.join(v)
                elif isinstance(v, list) and k == 'channel':
                    print('channels: %s' % str(v))
                    row[k] = hash_channel_list(v)
                    print('channel hash: %i' % row[k])
                elif not isinstance(v, list):
                    row[k] = v
                else:
                    raise ValueError('Cannot store list value for column %s: %r'
                                     % (k, v))

            _fill_fit_stats(row, hmm)
            row.append()

        table.flush()
        hf5.flush()

    print('='*80+'\n')


def delete_hmm_from_hdf5(h5_file, **kwargs):
    '''Delete HMMs whose data_overview rows match all given column values.

    Each keyword names a ``/data_overview`` column and the value it must
    equal. A row is deleted only if it matches every keyword; its
    ``/hmm_<id>`` group is removed as well.

    WARNING: calling with no keywords matches, and deletes, every row.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store.
    **kwargs
        column=value filters, e.g. ``hmm_id=3`` or ``taste='NaCl'``.
    '''
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        data = table[:]
        match = np.ones(len(data), dtype=bool)
        for k, v in kwargs.items():
            col = data[k]
            if isinstance(v, str):
                # String columns are read back as bytes.
                col = np.array([x.decode('utf-8') for x in col])

            # Conjunction of all filters. A value absent from the column
            # matches nothing; previously such a filter was silently
            # ignored, which could delete every row.
            match &= np.asarray(col == v, dtype=bool)

        # Remove from the end so the indices of earlier rows stay valid.
        for x in reversed(np.flatnonzero(match).tolist()):
            hmm_id = data[x]['hmm_id']
            h_str = 'hmm_%s' % hmm_id
            if h_str in hf5.root:
                print('Deleting existing data for %s...' % h_str)
                hf5.remove_node('/', h_str, recursive=True)
            else:
                print('HMM %s not found in hdf5.' % hmm_id)

            table.remove_rows(x, x+1)

        table.flush()
        hf5.flush()


def compare_hmm_params(p1, p2):
    '''Return True if two HMM parameter dicts agree on every setup key.

    The keys compared are taste, unit_type, dt, max_iter, time_start,
    time_end, n_states, n_trials, hmm_class, area and notes. Other keys are
    ignored. Raises KeyError if either dict lacks one of these keys.
    '''
    keys = ('taste', 'unit_type', 'dt', 'max_iter', 'time_start',
            'time_end', 'n_states', 'n_trials', 'hmm_class', 'area', 'notes')
    return not any(p1[key] != p2[key] for key in keys)


def get_hmm_h5(rec_dir):
    '''Find the HMM analysis hdf5 store under a recording directory.

    Searches ``rec_dir`` recursively (including ``rec_dir`` itself) for a
    file whose name ends in ``HMM_Analysis.hdf5``.

    Returns
    -------
    str or None
        Path of the single match, or None if no such file exists.

    Raises
    ------
    ValueError
        If more than one matching file is found.
    '''
    pattern = os.path.join(rec_dir, '**', '*HMM_Analysis.hdf5')
    matches = glob.glob(pattern, recursive=True)
    if len(matches) > 1:
        raise ValueError(str(matches))

    if not matches:
        return None

    return matches[0]


def get_hmm_overview_from_hdf5(h5_file):
    '''Build a DataFrame of the parameters of every HMM in an hdf5 store.

    There is one row per ``/data_overview`` entry: the ``params`` dict
    returned by read_hmm_from_hdf5. Entries whose ``/hmm_<id>`` group is
    missing or empty (read_hmm_from_hdf5 returns None) are skipped rather
    than causing a TypeError on unpacking.
    '''
    with tables.open_file(h5_file, 'r') as hf5:
        ids = hf5.root.data_overview[:]['hmm_id']

    params = []
    for hmm_id in ids:
        result = read_hmm_from_hdf5(h5_file, hmm_id)
        if result is None:
            print('No data found for HMM %s, skipping.' % hmm_id)
            continue

        params.append(result[-1])

    return pd.DataFrame(params)


def hash_channel_list(channels):
    '''Encode a list of channel numbers as a single digit string.

    The channel count is prepended. This records the list length and keeps
    a leading channel 0 from being dropped when the string is stored as an
    integer, e.g. ``[0, 1, 2] -> '3012'``. The input list is not modified.

    NOTE: list_channel_hash decodes by splitting into single digits. The
    round trip is therefore only lossless for channels 0-9 and fewer than
    10 channels.
    '''
    return ''.join(str(x) for x in [len(channels)] + list(channels))

def list_channel_hash(num):
    '''Decode a channel hash produced by hash_channel_list.

    Every digit of ``num`` (int or str) becomes one channel number, except
    the first digit (the channel count), which is dropped. Only single-digit
    channels can be represented this way.
    '''
    return list(map(int, str(num)))[1:]

Functions

def compare_hmm_params(p1, p2)
Expand source code
def compare_hmm_params(p1, p2):
    '''Return True if two HMM parameter dicts agree on every setup key.

    The keys compared are taste, unit_type, dt, max_iter, time_start,
    time_end, n_states, n_trials, hmm_class, area and notes. Other keys are
    ignored. Raises KeyError if either dict lacks one of these keys.
    '''
    keys = ('taste', 'unit_type', 'dt', 'max_iter', 'time_start',
            'time_end', 'n_states', 'n_trials', 'hmm_class', 'area', 'notes')
    return not any(p1[key] != p2[key] for key in keys)
def delete_hmm_from_hdf5(h5_file, **kwargs)
Expand source code
def delete_hmm_from_hdf5(h5_file, **kwargs):
    '''Delete HMMs whose data_overview rows match all given column values.

    Each keyword names a ``/data_overview`` column and the value it must
    equal. A row is deleted only if it matches every keyword; its
    ``/hmm_<id>`` group is removed as well.

    WARNING: calling with no keywords matches, and deletes, every row.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store.
    **kwargs
        column=value filters, e.g. ``hmm_id=3`` or ``taste='NaCl'``.
    '''
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        data = table[:]
        match = np.ones(len(data), dtype=bool)
        for k, v in kwargs.items():
            col = data[k]
            if isinstance(v, str):
                # String columns are read back as bytes.
                col = np.array([x.decode('utf-8') for x in col])

            # Conjunction of all filters. A value absent from the column
            # matches nothing; previously such a filter was silently
            # ignored, which could delete every row.
            match &= np.asarray(col == v, dtype=bool)

        # Remove from the end so the indices of earlier rows stay valid.
        for x in reversed(np.flatnonzero(match).tolist()):
            hmm_id = data[x]['hmm_id']
            h_str = 'hmm_%s' % hmm_id
            if h_str in hf5.root:
                print('Deleting existing data for %s...' % h_str)
                hf5.remove_node('/', h_str, recursive=True)
            else:
                print('HMM %s not found in hdf5.' % hmm_id)

            table.remove_rows(x, x+1)

        table.flush()
        hf5.flush()
def fix_hmm_overview(h5_file)

Back-fills zero log_likelihood values in the data_overview table from each HMM's stored fit_LL history (originally written to add the area and hmm_class columns).

Expand source code
def fix_hmm_overview(h5_file):
    '''Back-fill missing log-likelihoods in the data_overview table.

    Every ``/data_overview`` row whose ``log_likelihood`` is exactly 0 is
    set to the last value of that HMM's stored ``/hmm_<id>/fit_LL`` array.
    (Historically this helper held one-off schema migrations, e.g. adding
    the ``area`` and ``hmm_class`` columns.)

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store. If the file does not exist, nothing is
        done.

    Returns
    -------
    None
    '''
    if not os.path.isfile(h5_file):
        return

    print('Fixing data overview table in %s' % h5_file)
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        for row in table.where('log_likelihood == 0.'):
            hmm_id = row['hmm_id']
            h_str = 'hmm_%s' % hmm_id
            # Skip rows whose HMM group or fit_LL history is missing instead
            # of aborting the whole repair on the first such row.
            if h_str not in hf5.root or 'fit_LL' not in hf5.root[h_str]:
                print('No fit_LL data found for %s, skipping.' % h_str)
                continue

            row['log_likelihood'] = hf5.root[h_str]['fit_LL'][-1]
            row.update()

        table.flush()
        hf5.flush()
def get_hmm_h5(rec_dir)
Expand source code
def get_hmm_h5(rec_dir):
    '''Find the HMM analysis hdf5 store under a recording directory.

    Searches ``rec_dir`` recursively (including ``rec_dir`` itself) for a
    file whose name ends in ``HMM_Analysis.hdf5``.

    Returns
    -------
    str or None
        Path of the single match, or None if no such file exists.

    Raises
    ------
    ValueError
        If more than one matching file is found.
    '''
    pattern = os.path.join(rec_dir, '**', '*HMM_Analysis.hdf5')
    matches = glob.glob(pattern, recursive=True)
    if len(matches) > 1:
        raise ValueError(str(matches))

    if not matches:
        return None

    return matches[0]
def get_hmm_overview_from_hdf5(h5_file)
Expand source code
def get_hmm_overview_from_hdf5(h5_file):
    '''Build a DataFrame of the parameters of every HMM in an hdf5 store.

    There is one row per ``/data_overview`` entry: the ``params`` dict
    returned by read_hmm_from_hdf5. Entries whose ``/hmm_<id>`` group is
    missing or empty (read_hmm_from_hdf5 returns None) are skipped rather
    than causing a TypeError on unpacking.
    '''
    with tables.open_file(h5_file, 'r') as hf5:
        ids = hf5.root.data_overview[:]['hmm_id']

    params = []
    for hmm_id in ids:
        result = read_hmm_from_hdf5(h5_file, hmm_id)
        if result is None:
            print('No data found for HMM %s, skipping.' % hmm_id)
            continue

        params.append(result[-1])

    return pd.DataFrame(params)
def hash_channel_list(channels)
Expand source code
def hash_channel_list(channels):
    '''Encode a list of channel numbers as a single digit string.

    The channel count is prepended. This records the list length and keeps
    a leading channel 0 from being dropped when the string is stored as an
    integer, e.g. ``[0, 1, 2] -> '3012'``. The input list is not modified.

    NOTE: list_channel_hash decodes by splitting into single digits. The
    round trip is therefore only lossless for channels 0-9 and fewer than
    10 channels.
    '''
    return ''.join(str(x) for x in [len(channels)] + list(channels))
def list_channel_hash(num)
Expand source code
def list_channel_hash(num):
    '''Decode a channel hash produced by hash_channel_list.

    Every digit of ``num`` (int or str) becomes one channel number, except
    the first digit (the channel count), which is dropped. Only single-digit
    channels can be represented this way.
    '''
    return list(map(int, str(num)))[1:]
def read_hmm_from_hdf5(h5_file, hmm_id)
Expand source code
def read_hmm_from_hdf5(h5_file, hmm_id):
    '''Load a stored HMM and its overview parameters from an hdf5 store.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store.
    hmm_id : int
        ID of the HMM; its arrays live in the group ``/hmm_<hmm_id>``.

    Returns
    -------
    tuple or None
        ``(PI, A, B, stat_arrays, params)``:
        - PI, A and B are the stored ``initial_distribution``, ``transition``
          and ``emission`` arrays.
        - ``stat_arrays`` is a dict of every other array in the group.
        - ``params`` is a dict built from the matching ``/data_overview`` row.
        Returns None if the HMM group is missing or empty.

    Raises
    ------
    ValueError
        If the group exists but no ``/data_overview`` row has this hmm_id.
    '''
    with tables.open_file(h5_file, 'r') as hf5:
        h_str = 'hmm_%s' % hmm_id
        if h_str not in hf5.root or len(hf5.list_nodes('/'+h_str)) == 0:
            return None

        # Read every array in the HMM's group into memory.
        stat_arrays = {node._v_name: node[:]
                       for node in hf5.list_nodes('/'+h_str)}

        PI = stat_arrays.pop('initial_distribution')
        A = stat_arrays.pop('transition')
        B = stat_arrays.pop('emission')

        # row_id is read back as bytes; decode to str, keeping its shape.
        # Guarded because an HMM's stat_arrays need not include it.
        if 'row_id' in stat_arrays:
            raw = stat_arrays['row_id']
            decoded = np.array([x.decode('utf-8') for x in raw.ravel()])
            stat_arrays['row_id'] = decoded.reshape(raw.shape)

        table = hf5.root.data_overview
        for row in table.where('hmm_id == id', condvars={'id': hmm_id}):
            params = {}
            for k in table.colnames:
                if table.coltypes[k] == 'string':
                    val = row[k].decode('utf-8')
                    # List-valued params are stored '..'-joined.
                    params[k] = val.split('..') if '..' in val else val
                else:
                    params[k] = row[k]

            # NOTE(review): channel is decoded only when taste is a list,
            # while write_hmm_to_hdf5 hashes channel whenever it is a list;
            # confirm these two conditions always coincide.
            if isinstance(params['taste'], list):
                params['channel'] = list_channel_hash(params['channel'])

            return PI, A, B, stat_arrays, params

    raise ValueError('Parameters not found for hmm %i' % hmm_id)
def setup_hmm_hdf5(h5_file, infoParticle=blechpy.utils.particles.HMMInfoParticle)
Expand source code
def setup_hmm_hdf5(h5_file, infoParticle=HMMInfoParticle):
    '''Create a new HMM hdf5 store holding an empty data_overview table.

    Nothing happens if ``h5_file`` already exists.

    Parameters
    ----------
    h5_file : str
        Path of the hdf5 file to create.
    infoParticle : optional
        Table description passed to ``create_table`` for ``/data_overview``.
    '''
    already_exists = os.path.isfile(h5_file)
    if already_exists:
        return

    print('Initializing hdf5 store: %s' % h5_file)
    with tables.open_file(h5_file, 'a') as store:
        has_overview = 'data_overview' in store.root
        if not has_overview:
            print('Initializing data_overview table in hdf5 store...')
            overview = store.create_table(
                '/', 'data_overview', infoParticle,
                'Parameters and goodness-of-fit info for HMMs in file')
            overview.flush()
def write_hmm_to_hdf5(h5_file, hmm, params)
Expand source code
def _fill_fit_stats(row, hmm):
    '''Copy the goodness-of-fit attributes of ``hmm`` into a table row.'''
    row['BIC'] = hmm.BIC
    row['cost'] = hmm.cost
    row['converged'] = hmm.converged
    row['fitted'] = hmm.fitted
    row['max_log_prob'] = hmm.max_log_prob
    row['log_likelihood'] = hmm.fit_LL
    row['n_iterations'] = hmm.iteration


def _next_hmm_id(ids):
    '''Return the ID for a new HMM given the IDs already in the table.

    Uses the first gap after an existing ID, or max + 1 if there is no gap.
    The IDs are de-duplicated and sorted first. Otherwise an unsorted column
    (e.g. [0, 5, 1]) would produce an ID that is already taken.
    '''
    ids = np.unique(ids)
    if len(ids) == 0:
        return 0

    gaps = np.where(np.diff(ids) > 1)[0]
    if len(gaps) == 0:
        return ids[-1] + 1

    return ids[gaps[0]] + 1


def write_hmm_to_hdf5(h5_file, hmm, params):
    '''Write (or overwrite) an HMM and its parameters to an hdf5 store.

    The HMM's arrays are written to the group ``/hmm_<id>``, replacing any
    existing group. Its parameters and fit statistics go to
    ``/data_overview``:
    - If a row with this ID already exists, only its fit statistics are
      updated.
    - Otherwise a new row is created from ``params``.

    Parameters
    ----------
    h5_file : str
        Path to the HMM hdf5 store; created with setup_hmm_hdf5 if missing.
    hmm : object
        HMM exposing:
        - hmm_id (set by this function if it was None)
        - initial_distribution, transition, emission and stat_arrays
        - fit statistics BIC, cost, converged, fitted, max_log_prob,
          fit_LL and iteration
    params : dict
        Overview parameters for the HMM. The caller's dict is not modified
        (a deep copy is used). When a new row is created:
        - list values are stored '..'-joined for string columns;
        - a ``channel`` list is stored via hash_channel_list.

    Raises
    ------
    ValueError
        If ``hmm.hmm_id`` and ``params['hmm_id']`` disagree, or a list
        value cannot be stored in its column.
    '''
    params = deepcopy(params)
    hmm_id = hmm.hmm_id
    if 'hmm_id' in params:
        if hmm_id is None:
            hmm.hmm_id = hmm_id = params['hmm_id']
        elif hmm_id != params['hmm_id']:
            raise ValueError('ID of HMM %i does not match ID in params %i'
                             % (hmm_id, params['hmm_id']))

    if not os.path.isfile(h5_file):
        setup_hmm_hdf5(h5_file)

    print('\n' + '='*80)
    print('Writing HMM %s to hdf5 file...' % hmm_id)
    print(params)
    print('PI: %s' % repr(hmm.initial_distribution.shape))
    print('A: %s' % repr(hmm.transition.shape))
    print('B: %s' % repr(hmm.emission.shape))
    with tables.open_file(h5_file, 'a') as hf5:
        table = hf5.root.data_overview
        if hmm_id is None:
            hmm_id = _next_hmm_id(table.col('hmm_id'))
            hmm.hmm_id = hmm_id
            print('HMM assigned id #%i' % hmm_id)

        # Always record the ID so a newly created overview row carries it,
        # even when the caller's params omitted 'hmm_id'.
        params['hmm_id'] = hmm_id

        h_str = 'hmm_%s' % hmm_id
        if h_str in hf5.root:
            print('Deleting existing data for %s...' % h_str)
            hf5.remove_node('/', h_str, recursive=True)

        print('Writing new data for %s' % h_str)
        group = '/' + h_str
        hf5.create_group('/', h_str, 'Data for HMM #%i' % hmm_id)
        hf5.create_array(group, 'initial_distribution',
                         hmm.initial_distribution)
        hf5.create_array(group, 'transition', hmm.transition)
        hf5.create_array(group, 'emission', hmm.emission)
        for k, v in hmm.stat_arrays.items():
            hf5.create_array(group, k,
                             v if isinstance(v, np.ndarray) else np.array(v))

        edited_rows = 0
        for row in table.where('hmm_id == id', condvars={'id': hmm_id}):
            print('Editing existing row in data_overview with new values for HMM %s' % hmm_id)
            print('New iteration: %i, fit_LL: %.4E, BIC: %.3E' % (hmm.iteration, hmm.fit_LL, hmm.BIC))
            print('Old iteration: %i, fit_LL: %.4E, BIC: %.3E' % (row['n_iterations'], row['log_likelihood'], row['BIC']))
            _fill_fit_stats(row, hmm)
            row.update()
            edited_rows += 1

        if edited_rows == 0:
            print('Creating new row in data_overview for HMM %s' % hmm_id)
            row = table.row
            for k, v in params.items():
                if table.coltypes[k] == 'string' and isinstance(v, list):
                    row[k] = '..'.join(v)
                elif isinstance(v, list) and k == 'channel':
                    print('channels: %s' % str(v))
                    row[k] = hash_channel_list(v)
                    print('channel hash: %i' % row[k])
                elif not isinstance(v, list):
                    row[k] = v
                else:
                    raise ValueError('Cannot store list value for column %s: %r'
                                     % (k, v))

            _fill_fit_stats(row, hmm)
            row.append()

        table.flush()
        hf5.flush()

    print('='*80+'\n')