MDTerp module¶
MDTerp.utils.py – Auxiliary utility functions for MDTerp package.
dominant_feature(all_result_loc, n=0)
¶
Function summarizing MDTerp results for all the transitions present in the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
all_results_loc |
str |
Location to save MDTerp results. |
required |
importance_coverage |
float |
For a specific transition, sets a cutoff for the sum of the most important features in descending order. |
required |
Returns:
Type | Description |
---|---|
dict |
Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values are lists representing feature importance using the length of the list equaling the number of features in the provided dataset. |
Source code in MDTerp/utils.py
def dominant_feature(all_result_loc: str, n: int = 0) -> dict:
"""
Function summarizing MDTerp results for all the transitions present in the dataset.
Args:
all_results_loc (str): Location to save MDTerp results.
importance_coverage (float): For a specific transition, sets a cutoff for the sum of the most important features in descending order.
Returns:
dict: Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values are lists representing feature importance using the length of the list equaling the number of features in the provided dataset.
"""
with open(all_result_loc, 'rb') as f:
loaded_dict = pickle.load(f)
for ii in loaded_dict:
tmp_c = loaded_dict[ii][1]
loaded_dict[ii] = np.argsort(tmp_c)[::-1][n]
return loaded_dict
input_summary(logger, numeric_dict, angle_dict, sin_cos_dict, save_dir, np_data)
¶
Function for summarizing user-provided input data in Python Logger.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logger |
Logger |
Logger object created using Python's built-in logging module. |
required |
numeric_dict |
dict |
Python dictionary, each key represents the name of a numeric feature (non-periodic). Values should be lists with a single element using the index of the corresponding numpy array in np_data. |
required |
angle_dict |
dict |
Python dictionary, each key represents the name of an angular feature in [-pi, pi]. Values should be lists with a single element with the index of the corresponding numpy array in np_data. |
required |
sin_cos_dict |
dict |
Python dictionary, each key represents the name of an angular feature. Values should be lists with two elements using the sine, cosine indices of the corresponding numpy array in np_data. |
required |
save_dir |
str |
Location to save MDTerp results. |
required |
np_data |
np.ndarray |
Numpy 2D array containing training data for the black-box model. Samples along rows and features along columns. |
required |
Returns:
Type | Description |
---|---|
None |
None |
Source code in MDTerp/utils.py
def input_summary(logger: Logger, numeric_dict: dict, angle_dict: dict, sin_cos_dict: dict, save_dir: str, np_data: np.ndarray) -> None:
"""
Function for summarizing user-provided input data in Python Logger.
Args:
logger (Logger): Logger object created using Python's built-in logging module.
numeric_dict (dict): Python dictionary, each key represents the name of a numeric feature (non-periodic). Values should be lists with a single element using the index of the corresponding numpy array in np_data.
angle_dict (dict): Python dictionary, each key represents the name of an angular feature in [-pi, pi]. Values should be lists with a single element with the index of the corresponding numpy array in np_data.
sin_cos_dict (dict): Python dictionary, each key represents the name of an angular feature. Values should be lists with two elements using the sine, cosine indices of the corresponding numpy array in np_data.
save_dir (str): Location to save MDTerp results.
np_data (np.ndarray): Numpy 2D array containing training data for the black-box model. Samples along rows and features along columns.
Returns:
None
"""
logger.info('MDTerp result location >>> ' + save_dir )
logger.info('Defined numeric features >>> ' + str(len(list(numeric_dict.keys()))) )
logger.info('Defined angle features >>> ' + str(len(list(angle_dict.keys()))) )
logger.info('Defined sin_cos features >>> ' + str(len(list(sin_cos_dict.keys()))) )
logger.info('Number of samples in blackbox model training data >>> ' + str(np_data.shape[0]) )
logger.info('Number of columns in blackbox model training data >>> ' + str(np_data.shape[1]) )
if np_data.shape[1] != len(list(numeric_dict.keys())) + len(list(angle_dict.keys())) + len(list(sin_cos_dict.keys()))*2:
logger.error('Assertion failure between provided feature dictionaries and input data!')
raise ValueError('Assertion failure between provided feature dictionaries and input data!')
logger.info(100*'-')
log_maker(save_dir)
¶
Function for creating a logger detailing MDTerp operations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
save_dir |
str |
Location to save MDTerp results. |
required |
Returns:
Type | Description |
---|---|
Logger |
Logger object created using Python's built-in logging module. |
Source code in MDTerp/utils.py
def log_maker(save_dir: str) -> Logger:
"""
Function for creating a logger detailing MDTerp operations.
Args:
save_dir (str): Location to save MDTerp results.
Returns:
Logger: Logger object created using Python's built-in logging module.
"""
fmt = '%(asctime)s %(name)-15s %(levelname)-8s %(message)s'
datefmt='%m-%d-%y %H:%M:%S'
logging.basicConfig(level=logging.INFO,format=fmt,datefmt=datefmt,filename=save_dir+'/MDTerp_summary.log',filemode='w')
logger = logging.getLogger('initialization')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter(fmt,datefmt=datefmt)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
logger.info(100*'-')
logger.info('Starting MDTerp...')
logger.info(100*'-')
return logger
picker_fn(prob, threshold, point_max)
¶
Function for picking points at the transition state ensemble. Uses provided data and metastable state probability from the black-box model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prob |
np.ndarray |
Numpy 2D array containing metastable state prediction probabilities from the black-box model. Rows represent samples, and the number of columns represents the number of states. Each row should sum to 1. |
required |
threshold |
float |
Threshold for identifying if a sample belongs to a transition state predicted by the black-box model. If the metastable state probability > threshold for two different classes for a specific sample, it's suitable for analysis. |
required |
point_max |
int |
If too many suitable points exist for a specific transition (e.g., transition between metastable state 3 and 8), point_max sets the maximum number of points chosen for analysis. Points are chosen from a uniform distribution. |
required |
Returns:
Type | Description |
---|---|
dict |
Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values represent chosen samples/rows in the provided dataset undergoing this transition. |
Source code in MDTerp/utils.py
def picker_fn(prob: np.ndarray, threshold: float, point_max: int) -> dict:
"""
Function for picking points at the transition state ensemble. Uses provided data and metastable state probability from the black-box model.
Args:
prob (np.ndarray): Numpy 2D array containing metastable state prediction probabilities from the black-box model. Rows represent samples, and the number of columns represents the number of states. Each row should sum to 1.
threshold (float): Threshold for identifying if a sample belongs to a transition state predicted by the black-box model. If the metastable state probability > threshold for two different classes for a specific sample, it's suitable for analysis.
point_max (int): If too many suitable points exist for a specific transition (e.g., transition between metastable state 3 and 8), point_max sets the maximum number of points chosen for analysis. Points are chosen from a uniform distribution.
Returns:
dict: Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values represent chosen samples/rows in the provided dataset undergoing this transition.
"""
transition_dict = defaultdict(list)
for i in range(prob.shape[0]):
sorted_ind = np.sort(np.argsort(prob[i, :])[::-1][:2])
sorted_val = np.sort(prob[i, :])[::-1][:2]
if (sorted_val[0]>=threshold) and (sorted_val[1]>=threshold):
transition_dict[str(sorted_ind[0]) + '_' + str(sorted_ind[1])].append(i)
for i in transition_dict.keys():
transition_dict[i] = np.random.choice(transition_dict[i], size = min(point_max, len(transition_dict[i])), replace = False)
return transition_dict
transition_summary(all_result_loc, importance_coverage=0.8)
¶
Function summarizing MDTerp results for all the transitions present in the dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
all_results_loc |
str |
Location to save MDTerp results. |
required |
importance_coverage |
float |
For a specific transition, sets a cutoff for the sum of the most important features in descending order. |
0.8 |
Returns:
Type | Description |
---|---|
dict |
Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values are lists representing mean and standard deviations of the feature importance using the length of the list equaling the number of features in the provided dataset for that transition. |
Source code in MDTerp/utils.py
def transition_summary(all_result_loc: str, importance_coverage: float = 0.8) -> dict:
"""
Function summarizing MDTerp results for all the transitions present in the dataset.
Args:
all_results_loc (str): Location to save MDTerp results.
importance_coverage (float): For a specific transition, sets a cutoff for the sum of the most important features in descending order.
Returns:
dict: Dictionary with keys representing detected transitions. E.g., key '3_8' means transition between index 3 and index 8 according to the prob array. Values are lists representing mean and standard deviations of the feature importance using the length of the list equaling the number of features in the provided dataset for that transition.
"""
with open(all_result_loc, 'rb') as f:
loaded_dict = pickle.load(f)
# Save all the unique transitions
transitions = []
for ii in loaded_dict:
transitions.append(loaded_dict[ii][0])
# Save summary results for each transition
summary_imp = {}
for ii in np.unique(transitions):
summary_imp[ii] = []
for ii in loaded_dict:
summary_imp[loaded_dict[ii][0]].append(loaded_dict[ii][1])
for ii in summary_imp:
tmp_a = np.mean(summary_imp[ii], axis = 0)
# Normalize results for the transition
normalization = np.sum(tmp_a)
tmp_a = tmp_a/normalization
tmp_b = np.std(summary_imp[ii], axis = 0)/normalization
trim_args = np.argsort(tmp_a)[::-1]
trim_vals = np.sort(tmp_a)[::-1]
# Discard irrelevant features for each transition, based on the importance_coverage hyperparameter
cutoff_k = 0
current_coverage = 0
while current_coverage < importance_coverage:
try:
current_coverage += trim_vals[cutoff_k]
cutoff_k += 1
except:
break
tmp_a[trim_args[cutoff_k:]] = 0
tmp_b[trim_args[cutoff_k:]] = 0
summary_imp[ii] = [tmp_a, tmp_b]
return summary_imp