import numpy as np
import os, sys
import importlib
import itertools as itt
from itertools import repeat
from .potential_manager import Potential, Potential_NoMP, NN_Potential
from .imp_samp import *
class ImpSampManager:
    """Import and wrap the user-provided trial wfn and (optionally) its first and second derivatives.

    Parallelized using multiprocessing, which is considered the default for pyvibdmc.
    The worker pool is either borrowed from a ``Potential`` manager or, for
    ``Potential_NoMP`` / ``NN_Potential`` managers, created here when
    ``new_pool_num_cores`` is given.
    """

    def __init__(self,
                 trial_function,
                 trial_directory,
                 python_file,
                 pot_manager,
                 pass_timestep=False,
                 new_pool_num_cores=None,
                 deriv_function=None,
                 trial_kwargs=None,
                 deriv_kwargs=None):
        """Store the configuration and import the trial/derivative functions.

        :param trial_function: Name of the trial wave function inside ``python_file``.
        :param trial_directory: Directory that contains ``python_file``.
        :param python_file: File name of the module to import (text after the first "." is dropped).
        :param pot_manager: Potential manager; its type decides where the pool comes from.
        :param pass_timestep: If True, a ``'timestep'`` counter is stored in both kwargs dicts
            and advanced on every ``call_derivs`` call.
        :param new_pool_num_cores: Size of a new pool, used only with ``Potential_NoMP`` / ``NN_Potential``.
        :param deriv_function: Name of the derivative function in ``python_file``;
            ``None`` selects ``ImpSamp.finite_diff``.
        :param trial_kwargs: Optional dict handed to the trial function as its second positional argument.
        :param deriv_kwargs: Optional dict handed to the derivative function as its second positional argument.
        """
        self.trial_func = trial_function
        self.trial_dir = trial_directory
        self.python_file = python_file
        self.deriv_func = deriv_function
        self.trial_kwargs = trial_kwargs
        self.deriv_kwargs = deriv_kwargs
        self.pot_manager = pot_manager
        self.pass_timestep = pass_timestep
        # Only when one wants to do multiprocessing importance sampling
        # with a noMP potential (like NN-DMC).
        self.nomp_pool_cores = new_pool_num_cores
        if self.pass_timestep:
            self.ct = 0
            # The kwargs default to None; create the dicts on demand so that
            # requesting the timestep does not fail with a TypeError.
            if self.trial_kwargs is None:
                self.trial_kwargs = {}
            if self.deriv_kwargs is None:
                self.deriv_kwargs = {}
            self.trial_kwargs['timestep'] = 0
            self.deriv_kwargs['timestep'] = 0
        if isinstance(self.pot_manager, Potential):
            # Reuse the pool that the potential manager already owns.
            self.pool = self.pot_manager.pool
            self.num_cores = self.pot_manager.num_cores
            self._reinit_pool()
        elif (isinstance(self.pot_manager, (Potential_NoMP, NN_Potential))
              and self.nomp_pool_cores is not None):
            # Really only for NN_Potential using multi-core imp samp.
            from multiprocessing import Pool
            self.pool = Pool(self.nomp_pool_cores)
            self.num_cores = self.nomp_pool_cores
            self._reinit_pool()

    def __getstate__(self):
        """Return the instance state without the pool and the potential manager.

        Since pool is a variable inside this class, the object cannot be pickled and
        used for multiprocessing as-is. The solution is __getstate__/__setstate__,
        which drop the pool and pot_manager from the pickled state.
        """
        state = self.__dict__.copy()
        # pop() instead of del: 'pool' is absent when no pool was ever set up.
        state.pop('pool', None)
        state.pop('pot_manager', None)
        return state

    def __setstate__(self, state):
        """Restore the attributes produced by ``__getstate__``."""
        self.__dict__.update(state)

    def _init_wfn_mp(self, chdir=False):
        """Import the user's python functions, in the main process or on a pool worker.

        For when you have a Potential object. For simplicity, efficiency, and
        restrictiveness, the imp samp stuff should be in the same directory as the
        potential energy callers.

        :param chdir: If True (main process), temporarily change into ``trial_dir``
            for the import and change back afterwards, even if the import fails.
        """
        prev_dir = os.getcwd() if chdir else None
        try:
            if chdir:
                # For main process
                os.chdir(self.trial_dir)
            sys.path.insert(0, os.getcwd())
            module_name = self.python_file.split(".")[0]
            module = importlib.import_module(module_name)
            self.trial_wfn = getattr(module, self.trial_func)
            if self.deriv_func is None:
                # bool for pyvibdmc sim code to do both derivs at once.
                self.all_finite = True
                self.derivs = ImpSamp.finite_diff
            else:
                # Supplied derivatives, just import them
                self.all_finite = False
                self.derivs = getattr(module, self.deriv_func)
        finally:
            if chdir:
                # For main process
                os.chdir(prev_dir)

    def _reinit_pool(self):
        """Import the trial/derivative module in the main process and on the pool."""
        no_args = [() for _ in range(self.num_cores)]
        self._init_wfn_mp(chdir=True)
        # Use self.pool rather than self.pot_manager.pool: for a Potential manager
        # they are the same object, and in the other branch the pool is only
        # stored on this instance.
        # NOTE(review): chunksize=1 presumably aims at one task per worker;
        # the pool does not guarantee that distribution.
        self.pool.starmap(self._init_wfn_mp, no_args, chunksize=1)

    def call_trial(self, cds):
        """Get trial wave function using multiprocessing.

        :param cds: Array that is split along its first axis into ``num_cores`` chunks.
        :return: The per-chunk results concatenated back together.
        """
        # self.num_cores (as in call_derivs) is set in both pool branches.
        chunks = np.array_split(cds, self.num_cores)
        if self.trial_kwargs is not None:
            res = self.pool.starmap(self.trial_wfn,
                                    zip(chunks, repeat(self.trial_kwargs, len(chunks))))
        else:
            res = self.pool.map(self.trial_wfn, chunks)
        return np.concatenate(res)

    def call_trial_no_mp(self, cds):
        """For call_derivs (finite diff), get trial wave function.

        Still used in the mp.pool context, just doesn't call pool itself.
        """
        if self.trial_kwargs is None:
            return self.trial_wfn(cds)
        return self.trial_wfn(cds, self.trial_kwargs)

    def call_derivs(self, cds):
        """Evaluate first and second derivatives on the pool.

        When derivatives are not supplied, the finite difference function is called
        and its results are divided by the trial wave function it returns. Supplied
        derivative functions are returned as given. This is still parallelized.

        :param cds: Array that is split along its first axis into ``num_cores`` chunks.
        :return: Tuple ``(derivz, sderivz)`` of concatenated arrays.
        """
        chunks = np.array_split(cds, self.num_cores)
        if self.all_finite:
            # Divide by trial wfn if finite difference
            derivz, sderivz, trial_wfn = zip(*self.pool.starmap(
                self.derivs, zip(chunks, repeat(self.call_trial_no_mp, len(chunks)))))
            # One value per entry along axis 0, broadcast over the two trailing axes.
            psi = np.concatenate(trial_wfn)[:, np.newaxis, np.newaxis]
            derivz = np.concatenate(derivz) / psi
            sderivz = np.concatenate(sderivz) / psi
        else:
            if self.deriv_kwargs is None:
                derivz, sderivz = zip(*self.pool.map(self.derivs, chunks))
            else:
                derivz, sderivz = zip(*self.pool.starmap(
                    self.derivs, zip(chunks, repeat(self.deriv_kwargs, len(chunks)))))
            derivz = np.concatenate(derivz)
            sderivz = np.concatenate(sderivz)
        if self.pass_timestep:
            # Advance the counter that later calls receive through the kwargs dicts.
            self.ct += 1
            self.trial_kwargs['timestep'] = self.ct
            self.deriv_kwargs['timestep'] = self.ct
        return derivz, sderivz
class ImpSampManager_NoMP:
    """Version of the manager that does not use any multiprocessing.

    If we ever evaluate the trial wfns with GPUs this could be useful. Could also be
    useful if multiprocessing is incompatible with your workflow.
    """

    def __init__(self,
                 trial_function,
                 trial_directory,
                 python_file,
                 chdir=False,
                 pass_timestep=False,
                 deriv_function=None,
                 trial_kwargs=None,
                 deriv_kwargs=None):
        """Store the configuration and import the trial/derivative functions.

        :param trial_function: Name of the trial wave function inside ``python_file``.
        :param trial_directory: Directory that contains ``python_file``.
        :param python_file: File name of the module to import (text after the first "." is dropped).
        :param chdir: If True, every call into the user's functions runs from ``trial_directory``.
        :param pass_timestep: If True, a ``'timestep'`` counter is stored in both kwargs dicts
            and advanced on every ``call_derivs`` call.
        :param deriv_function: Name of the derivative function in ``python_file``;
            ``None`` selects ``ImpSamp.finite_diff``.
        :param trial_kwargs: Optional dict handed to the trial function as its second positional argument.
        :param deriv_kwargs: Optional dict handed to the derivative function as its second positional argument.
        """
        # 'trial_fuc' is the historical (misspelled) attribute name and is kept for
        # existing users; 'trial_func' matches the naming used by ImpSampManager.
        self.trial_fuc = trial_function
        self.trial_func = trial_function
        self.trial_dir = trial_directory
        self.python_file = python_file
        self.pass_timestep = pass_timestep
        self.deriv_func = deriv_function
        self.trial_kwargs = trial_kwargs
        self.deriv_kwargs = deriv_kwargs
        self.chdir = chdir
        if self.pass_timestep:
            self.ct = 0
            # The kwargs default to None; create the dicts on demand so that
            # requesting the timestep does not fail with a TypeError.
            if self.trial_kwargs is None:
                self.trial_kwargs = {}
            if self.deriv_kwargs is None:
                self.deriv_kwargs = {}
            self.trial_kwargs['timestep'] = 0
            self.deriv_kwargs['timestep'] = 0
        self._import_modz()

    def _import_modz(self):
        """Import the trial (and optionally derivative) function from ``trial_dir``.

        The working directory at the time of this call is remembered in
        ``self._curdir`` and restored afterwards, even if the import fails.
        """
        self._curdir = os.getcwd()
        try:
            os.chdir(self.trial_dir)
            sys.path.insert(0, os.getcwd())
            module_name = self.python_file.split(".")[0]
            module = importlib.import_module(module_name)
            self.trial = getattr(module, self.trial_fuc)
            if self.deriv_func is None:
                self.all_finite = True
                self.derivs = ImpSamp.finite_diff
            else:
                # Supplied derivatives, just import them
                self.all_finite = False
                self.derivs = getattr(module, self.deriv_func)
        finally:
            os.chdir(self._curdir)

    def call_imp_func(self, func, cds, func_kwargs=None):
        """Convenience function for trial, deriv, and sderiv so there are no triplicates of code.

        :param func: Callable to evaluate.
        :param cds: First positional argument for ``func``.
        :param func_kwargs: If not None, passed to ``func`` as its second positional argument.
        :return: Whatever ``func`` returns.
        """
        if self.chdir:
            os.chdir(self.trial_dir)
        try:
            if func_kwargs is None:
                return func(cds)
            return func(cds, func_kwargs)
        finally:
            # Go back to the directory recorded at import time, even when func raises.
            if self.chdir:
                os.chdir(self._curdir)

    def call_trial(self, cds):
        """Call trial wave function."""
        return self.call_imp_func(self.trial, cds, self.trial_kwargs)

    def call_derivs(self, cds):
        """Evaluate first and second derivatives.

        When derivatives are not supplied, the finite difference function is called
        and its results are divided by the trial wave function it returns. Supplied
        derivative functions are returned as given.

        :param cds: Argument forwarded to the derivative function.
        :return: Tuple ``(derivz, sderivz)``.
        """
        if self.all_finite:
            derivz, sderivz, trial_wfn = self.derivs(cds, trial_func=self.call_trial)
            # One value per entry along axis 0, broadcast over the two trailing axes.
            psi = trial_wfn[:, np.newaxis, np.newaxis]
            derivz = derivz / psi
            sderivz = sderivz / psi
        else:
            derivz, sderivz = self.call_imp_func(self.derivs, cds, self.deriv_kwargs)
        if self.pass_timestep:
            # Advance the counter that later calls receive through the kwargs dicts.
            self.ct += 1
            self.trial_kwargs['timestep'] = self.ct
            self.deriv_kwargs['timestep'] = self.ct
        return derivz, sderivz