diff --git a/pymc/database/__init__.py b/pymc/database/__init__.py
index 5012e662a0..e271192fd3 100755
--- a/pymc/database/__init__.py
+++ b/pymc/database/__init__.py
@@ -32,7 +32,7 @@
 """
 
 
-__modules__ = ['no_trace', 'txt', 'ram', 'pickle', 'sqlite', 'hdf5', 'hdf52', "__test_import__"]
+__modules__ = ['no_trace', 'txt', 'ram', 'pickle', 'sqlite', 'hdf5', 'hdf5ea', "__test_import__"]
 
 from . import no_trace
 from . import txt
@@ -49,4 +49,9 @@ except ImportError:
     pass
 
 
+try:
+    from . import hdf5ea
+except ImportError:
+    pass
+
 
diff --git a/pymc/database/hdf5ea.py b/pymc/database/hdf5ea.py
new file mode 100644
index 0000000000..75ae14301d
--- /dev/null
+++ b/pymc/database/hdf5ea.py
@@ -0,0 +1,378 @@
+"""HDF5 database module.
+
+Store the traces in an HDF5 array using pytables.
+
+
+Implementation Notes
+--------------------
+
+This version only supports numeric objects, and stores them in
+extendable HDF5 arrays. This allows the implementation to handle very
+large data vectors.
+
+
+Additional Dependencies
+-----------------------
+ * HDF5 version 1.6.5, required by pytables.
+ * pytables version 2 and up.
+
+"""
+
+import os
+import sys
+import traceback
+import warnings
+
+import numpy as np
+import pymc
+import tables
+
+from pymc.database import base, pickle
+from pymc import six
+
+__all__ = ['Trace', 'Database', 'load']
+
+warn_tally = """
+Error tallying %s, will not try to tally it again this chain.
+Did you make all the same variables and step methods tallyable
+as were tallyable last time you used the database file?
+
+Error:
+
+%s"""
+
+
+###############################################################################
+
+class Trace(base.Trace):
+    """HDF5 trace.
+
+    Holds the samples of one tallyable object. The samples are stored in
+    the extendable arrays owned by the parent database, one per chain.
+    """
+
+
+    def tally(self, chain):
+        """Append the current value to the array of chain `chain`."""
+
+        arr = np.asarray(self._getfunc())
+        # Add a leading length-1 axis: samples are appended along axis 0.
+        arr = arr.reshape((1,) + arr.shape)
+        self.db._arrays[chain, self.name].append(arr)
+
+
+    # def __getitem__(self, index):
+    #     """Mimic NumPy indexing for arrays."""
+    #     chain = self._chain
+
+    #     if chain is not None:
+    #         tables = [self.db._gettables()[chain],]
+    #     else:
+    #         tables = self.db._gettables()
+
+    #     out = []
+    #     for table in tables:
+    #         out.append(table.col(self.name))
+
+    #     if np.isscalar(chain):
+    #         return out[0][index]
+    #     else:
+    #         return np.hstack(out)[index]
+
+
+    def gettrace(self, burn=0, thin=1, chain=-1, slicing=None):
+        """Return the trace (last by default).
+
+        :Parameters:
+        burn : integer
+          The number of transient steps to skip.
+        thin : integer
+          Keep one in thin.
+        chain : integer
+          The index of the chain to fetch. The default is to return the
+          last chain. None is currently treated like -1 (last chain);
+          fetching all chains at once is not implemented yet.
+        slicing : slice object
+          A slice overriding burn and thin assignment.
+        """
+
+        # XXX: handle chain == None case properly
+
+        if chain is None:
+            chain = -1
+        chain = self.db.chains[chain]
+
+        arr = self.db._arrays[chain, self.name]
+
+        if slicing is not None:
+            burn, stop, thin = slicing.start, slicing.stop, slicing.step
+
+        # Without an explicit stop, read up to the last stored row.
+        if slicing is None or stop is None:
+            stop = arr.nrows
+        return np.asarray(arr.read(start=burn, stop=stop, step=thin))
+
+    __call__ = gettrace
+
+    # def length(self, chain=-1):
+    #     """Return the length of the trace.
+
+    #     :Parameters:
+    #     chain : int or None
+    #       The chain index. If None, returns the combined length of all chains.
+    #     """
+    #     if chain is not None:
+    #         tables = [self.db._gettables()[chain],]
+    #     else:
+    #         tables = self.db._gettables()
+
+    #     n = np.asarray([table.nrows for table in tables])
+    #     return n.sum()
+
+
+###############################################################################
+
+class Database(pickle.Database):
+    """HDF5 database.
+
+    Create an HDF5 file. Each chain is stored in a group,
+    and the stochastics and deterministics are stored as extendable
+    arrays in each group.
+    """
+
+
+    def __init__(self, dbname, dbmode='a',
+                 dbcomplevel=0, dbcomplib='zlib',
+                 **kwds):
+        """Create an HDF5 database instance, where samples are stored
+        in extendable arrays.
+
+        :Parameters:
+        dbname : string
+          Name of the hdf5 file.
+        dbmode : {'a', 'w', 'r'}
+          File mode: 'a': append, 'w': overwrite, 'r': read-only.
+        dbcomplevel : integer (0-9)
+          Compression level, 0: no compression.
+        dbcomplib : string
+          Compression library (zlib, bzip2, lzo)
+        kwds : dict
+          Extra keyword arguments; accepted but not used here.
+
+        :Notes:
+          * zlib has a good compression ratio, although somewhat slow, and
+            reasonably fast decompression.
+          * lzo is a fast compression library offering however a low compression
+            ratio.
+          * bzip2 has an excellent compression ratio but requires more CPU.
+        """
+
+        self.__name__ = 'hdf5ea'
+        self.__Trace__ = Trace
+
+        self.dbname = dbname
+        self.mode = dbmode
+
+        db_exists = os.path.exists(self.dbname)
+        self._h5file = tables.openFile(self.dbname, self.mode)
+
+        default_filter = tables.Filters(complevel=dbcomplevel, complib=dbcomplib)
+        if self.mode =='r' or (self.mode=='a' and db_exists):
+            self.filter = getattr(self._h5file, 'filters', default_filter)
+        else:
+            self.filter = default_filter
+
+        self.trace_names = []
+        self._traces = {}
+        # self._states = {}
+        self._chains = {}
+        self._arrays = {}
+
+        # load existing data
+        # Visit the chain groups in numeric order: `trace_names` is indexed
+        # by chain number, and a by-name listing would put 'chain10' before
+        # 'chain2'.
+        existing_chains = sorted(
+            (gr for gr in self._h5file.listNodes("/")
+             if gr._v_name.startswith('chain')),
+            key=lambda gr: int(gr._v_name[5:]))
+
+        for chain in existing_chains:
+            nchain = int(chain._v_name[5:])
+            self._chains[nchain] = chain
+
+            names = []
+            for array in chain._f_listNodes():
+                name = array._v_name
+                self._arrays[nchain, name] = array
+
+                if name not in self._traces:
+                    self._traces[name] = Trace(name, db=self)
+
+                names.append(name)
+
+            self.trace_names.append(names)
+
+
+    @property
+    def chains(self):
+        """Indices of the chains stored in this database."""
+        return range(len(self._chains))
+
+
+    @property
+    def nchains(self):
+        """Number of chains stored in this database."""
+        return len(self._chains)
+
+
+    # def connect_model(self, model):
+    #     """Link the Database to the Model instance.
+
+    #     In case a new database is created from scratch, ``connect_model``
+    #     creates Trace objects for all tallyable pymc objects defined in
+    #     `model`.
+
+    #     If the database is being loaded from an existing file, ``connect_model``
+    #     restore the objects trace to their stored value.
+
+    #     :Parameters:
+    #     model : pymc.Model instance
+    #       An instance holding the pymc objects defining a statistical
+    #       model (stochastics, deterministics, data, ...)
+    #     """
+
+    #     # Changed this to allow non-Model models. -AP
+    #     if isinstance(model, pymc.Model):
+    #         self.model = model
+    #     else:
+    #         raise AttributeError('Not a Model instance.')
+
+    #     # Restore the state of the Model from an existing Database.
+    #     # The `load` method will have already created the Trace objects.
+    #     if hasattr(self, '_state_'):
+    #         names = set()
+    #         for morenames in self.trace_names:
+    #             names.update(morenames)
+    #         for name, fun in six.iteritems(model._funs_to_tally):
+    #             if name in self._traces:
+    #                 self._traces[name]._getfunc = fun
+    #                 names.remove(name)
+    #         if len(names) > 0:
+    #             raise RuntimeError("Some objects from the database"
+    #                             + "have not been assigned a getfunc: %s"
+    #                             % ', '.join(names))
+
+
+    def _initialize(self, funs_to_tally, length):
+        """Create a group named ``chain#`` to store all data for this chain.
+
+        :Parameters:
+        funs_to_tally : dict
+          Maps each trace name to the function returning its current value.
+        length : integer
+          Passed on to each trace's ``_initialize``.
+        """
+
+        chain = self.nchains
+        self._chains[chain] = self._h5file.createGroup(
+            '/', 'chain%d' % chain, 'chain #%d' % chain)
+
+        for name, fun in six.iteritems(funs_to_tally):
+
+            arr = np.asarray(fun())
+
+            # Explicit check rather than `assert`, which is stripped under -O.
+            if arr.dtype == np.dtype('object'):
+                raise TypeError(
+                    "hdf5ea only stores numeric objects, but %s has "
+                    "dtype object." % name)
+
+            array = self._h5file.createEArray(
+                self._chains[chain], name,
+                tables.Atom.from_dtype(arr.dtype), (0,) + arr.shape,
+                filters=self.filter)
+
+            self._arrays[chain, name] = array
+            self._traces[name] = Trace(name, getfunc=fun, db=self)
+            self._traces[name]._initialize(self.chains, length)
+
+        # Store a real list: `tally` removes failing names from it, which a
+        # Python 3 dict view does not support.
+        self.trace_names.append(list(funs_to_tally.keys()))
+
+
+    def tally(self, chain=-1):
+        """Append the current value of every trace to chain `chain`.
+
+        A trace whose tally fails triggers a warning and is not tallied
+        again for this chain.
+        """
+        chain = self.chains[chain]
+        # Iterate over a copy: failing names are removed from the live list.
+        for name in list(self.trace_names[chain]):
+            try:
+                self._traces[name].tally(chain)
+                self._arrays[chain, name].flush()
+            except Exception:
+                cls, inst, tb = sys.exc_info()
+                warnings.warn(warn_tally
+                              % (name, ''.join(traceback.format_exception(cls, inst, tb))))
+                self.trace_names[chain].remove(name)
+
+
+
+    # def savestate(self, state, chain=-1):
+    #     """Store a dictionary containing the state of the Model and its
+    #     StepMethods."""
+
+    #     chain = self.chains[chain]
+    #     if chain in self._states:
+    #         self._states[chain] = state
+    #     else:
+    #         s = self._h5file.createVLArray(chain,'_state_',tables.ObjectAtom(),title='The saved state of the sampler',filters=self.filter)
+    #         s.append(state)
+    #     self._h5file.flush()
+
+
+    # def getstate(self, chain=-1):
+    #     if len(self._chains)==0:
+    #         return {}
+    #     elif hasattr(self._chains[chain],'_state_'):
+    #         if len(self._chains[chain]._state_)>0:
+    #             return self._chains[chain]._state_[0]
+    #         else:
+    #             return {}
+    #     else:
+    #         return {}
+
+
+    def _finalize(self, chain=-1):
+        """Flush pending data to the HDF5 file."""
+        self._h5file.flush()
+
+    def close(self):
+        """Close the HDF5 file."""
+        self._h5file.close()
+
+
+
+
+def load(dbname, dbmode='a'):
+    """Load an existing hdf5 database.
+
+    Return a Database instance.
+
+    :Parameters:
+    dbname : string
+      Name of the hdf5 database to open.
+    dbmode : 'a', 'r'
+      File mode : 'a': append, 'r': read-only.
+
+    Raises AttributeError if `dbmode` is 'w'.
+    """
+    if dbmode == 'w':
+        raise AttributeError("dbmode='w' not allowed for load.")
+    db = Database(dbname, dbmode=dbmode)
+
+    return db