MDTerp module¶
MDTerp.base.run - Main class for implementing MDTERP.
Source code in MDTerp/base.py
class run:
    """
    MDTerp.base.run - Main class for implementing MDTERP.

    Instantiating the class runs the whole analysis: all of the work happens
    inside the constructor and the results are written under ``save_dir``.
    """
    def __init__(
        self,
        np_data: np.ndarray,
        model_function_loc: str,
        numeric_dict: dict = {},
        angle_dict: dict = {},
        sin_cos_dict: dict = {},
        save_dir: str = './results/',
        prob_threshold: float = 0.48,
        point_max: int = 50,
        num_samples: int = 10000,
        cutoff: int = 15,
        seed: int = 0,
        unfaithfulness_threshold: float = 0.01,
        periodicity_upper: float = np.pi,
        periodicity_lower: float = -np.pi,
        alpha: float = 1.0,
    ) -> None:
        """
        Constructor for the MDTerp.base.run class.

        Args:
            np_data (np.ndarray): Black-box training data.
            model_function_loc (str): Location of a human-readable file containing two functions called 'load_model()', and 'run_model()'. 'load_model' must not take any arguments and should return the black-box model. 'run_model' must be a function that takes two arguments: the model and data, and returns metastable state probabilities. The file is executed as Python code, so it must come from a trusted source. Go to https://shams-mehdi.github.io/MDTerp/docs/examples/ for example files.
            numeric_dict (dict): Python dictionary, each key represents the name of a numeric feature (non-periodic). Values should be lists with a single element with the index of the corresponding numpy array in np_data.
            angle_dict (dict): Python dictionary, each key represents the name of an angular feature in [-pi, pi]. Values should be lists with a single element with the index of the corresponding numpy array in np_data.
            sin_cos_dict (dict): Python dictionary, each key represents the name of an angular feature. Values should be lists with two elements using the sine, cosine indices of the corresponding numpy array in np_data.
            save_dir (str): Location to save MDTerp results. A trailing path separator is appended when missing.
            prob_threshold (float): Threshold for identifying if a sample belongs to a transition state predicted by the black-box model. If metastable state probability > threshold for two different classes for a specific sample, it's suitable for analysis (Default: 0.48).
            point_max (int): If too many suitable samples exist for a specific transition (e.g., transition between metastable state 3 and 8), point_max sets the maximum number of points chosen for further analysis. Points are chosen from a uniform distribution (Default: 50).
            num_samples (int): Size of the perturbed neighborhood (Default: 10000). Ad hoc rule: should be proportional to the square root of the number of features.
            cutoff (int): Maximum number of features kept for the final round of MDTerp and forward feature selection (use to improve compute time: when too many features are in the dataset and a priori it is known it is unlikely that more features than set by cutoff will be relevant).
            seed (int): Random seed.
            unfaithfulness_threshold (float): Hyperparameter that sets a lower limit on surrogate model unfaithfulness (U). Forward feature selection ends when unfaithfulness drops below this threshold (Default: 0.01).
            periodicity_upper (float): Sets periodicity of the angular features (Default: numpy.pi).
            periodicity_lower (float): Sets periodicity of the angular features (Default: -numpy.pi).
            alpha (float): L2 norm of Ridge regression (Default: 1.0).

        Returns:
            None

        Raises:
            KeyError: If the model file does not define 'load_model' or 'run_model'.
            ValueError: If no transition is detected in np_data.
        """
        # Work on private copies: the `{}` defaults are shared between calls,
        # so they must never be handed to code that could modify them.
        numeric_dict = dict(numeric_dict)
        angle_dict = dict(angle_dict)
        sin_cos_dict = dict(sin_cos_dict)

        # Every path below is built as `save_dir + name`; guarantee a trailing
        # separator so the files land inside the directory created here.
        save_dir = os.path.join(save_dir, '')

        # Initialization
        os.makedirs(save_dir, exist_ok=True)
        logger = log_maker(save_dir)
        try:
            input_summary(logger, numeric_dict, angle_dict, sin_cos_dict, save_dir, np_data)

            # Load black-box model
            logger.info('Loading blackbox model from file >>> %s', model_function_loc)
            with open(model_function_loc, 'r') as file:
                func_code = file.read()
            # SECURITY: this runs arbitrary Python from the given file; only
            # point it at code you trust.
            # A single namespace (seeded with this module's globals) is used so
            # that functions defined in the file can see the file's own
            # top-level imports and helpers. With separate globals/locals the
            # code runs as if inside a class body and those names are invisible
            # from within 'load_model' / 'run_model'.
            model_ns = dict(globals())
            exec(func_code, model_ns)
            missing = [name for name in ("load_model", "run_model") if name not in model_ns]
            if missing:
                raise KeyError(
                    "Model file '{}' does not define: {}".format(model_function_loc, ", ".join(missing))
                )
            run_model = model_ns["run_model"]
            model = model_ns["load_model"]()
            logger.info("Model loaded!")

            # Identify transition states for given/training dataset
            state_probabilities = run_model(model, np_data)
            points = picker_fn(state_probabilities, prob_threshold, point_max)
            num_transitions = len(points)
            logger.info("Number of state transitions detected >>> %s", num_transitions)
            logger.info(
                "Probability threshold, maximum number of points per transition >>> %s, %s",
                prob_threshold, point_max,
            )
            if num_transitions == 0:
                logger.info("No transition detected. Check hyperparamters!")
                raise ValueError("No transition detected. Check hyperparameters!")
            logger.info(100*'-')

            # Loop over all the transitions
            # NOTE(review): results are keyed by sample index only, so a sample
            # picked for more than one transition keeps just its last result.
            importance_master = {}
            for transition in points:
                logger.info("Starting transition >>> %s", transition)
                transition_points = points[transition]
                total = len(transition_points)
                for count, index in enumerate(transition_points, start=1):
                    # Round 1: neighborhood over all features (empty selection).
                    # The .npy files read back below are presumably written by
                    # generate_neighborhood under save_dir/DATA.
                    feature_type_indices, indices_names = generate_neighborhood(save_dir, numeric_dict, angle_dict, sin_cos_dict, np_data, index, seed, num_samples, np.array([]), periodicity_upper, periodicity_lower)
                    state_probabilities2 = run_model(model, np.load(save_dir + 'DATA/make_prediction.npy'))
                    TERP_dat = np.load(save_dir + 'DATA/TERP_dat.npy')
                    selected_features = init_model(TERP_dat, state_probabilities2, cutoff, feature_type_indices, seed, alpha)

                    # Round 2: neighborhood restricted to the selected features
                    # (read back from save_dir/DATA_2).
                    generate_neighborhood(save_dir, numeric_dict, angle_dict, sin_cos_dict, np_data, index, seed, num_samples, selected_features, periodicity_upper, periodicity_lower)
                    state_probabilities3 = run_model(model, np.load(save_dir + 'DATA_2/make_prediction.npy'))
                    TERP_dat = np.load(save_dir + 'DATA_2/TERP_dat.npy')
                    importance_0 = final_model(TERP_dat, state_probabilities3, unfaithfulness_threshold, feature_type_indices, selected_features, seed)
                    importance = make_result(feature_type_indices, indices_names, importance_0)
                    importance_master[index] = [transition, importance]
                    logger.info(
                        "Completed generating %s/%s results! First round features kept >>> %s, Final round features kept >>> %s",
                        count, total, len(selected_features), np.nonzero(importance)[0].shape[0],
                    )
                logger.info(100*'_')

            # Persist results, then remove the intermediate neighborhood data
            np.save(save_dir + 'MDTerp_feature_names.npy', indices_names)
            with open(save_dir + 'MDTerp_results_all.pkl', 'wb') as f:
                pickle.dump(importance_master, f)
            logger.info("Feature names saved at >>> %s", save_dir + 'MDTerp_feature_names.npy')
            logger.info("All results saved at >>> %s", save_dir + 'MDTerp_results_all.pkl')
            shutil.rmtree(save_dir + 'DATA')
            shutil.rmtree(save_dir + 'DATA_2')
            logger.info("Completed!!!")
        finally:
            # Flush and close logger, also when the run fails part-way, so the
            # handlers are never left open.
            for handler in logger.handlers[:]:
                handler.close()
                logger.removeHandler(handler)
__init__(self, np_data, model_function_loc, numeric_dict={}, angle_dict={}, sin_cos_dict={}, save_dir='./results/', prob_threshold=0.48, point_max=50, num_samples=10000, cutoff=15, seed=0, unfaithfulness_threshold=0.01, periodicity_upper=3.141592653589793, periodicity_lower=-3.141592653589793, alpha=1.0)
special
¶
Constructor for the MDTerp.base.run class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
np_data |
np.ndarray |
Black-box training data. |
required |
model_function_loc |
str |
Location of a human-readable file containing two functions called 'load_model()', and 'run_model()'. 'load_model' must not take any arguments and should return the black-box model. 'run_model' must be a function that takes two arguments: the model and data, and returns metastable state probabilities. Go to https://shams-mehdi.github.io/MDTerp/docs/examples/ for example files. |
required |
numeric_dict |
dict |
Python dictionary, each key represents the name of a numeric feature (non-periodic). Values should be lists with a single element with the index of the corresponding numpy array in np_data. |
{} |
angle_dict |
dict |
Python dictionary, each key represents the name of an angular feature in [-pi, pi]. Values should be lists with a single element with the index of the corresponding numpy array in np_data. |
{} |
sin_cos_dict |
dict |
Python dictionary, each key represents the name of an angular feature. Values should be lists with two elements using the sine, cosine indices of the corresponding numpy array in np_data. |
{} |
save_dir |
str |
Location to save MDTerp results. |
'./results/' |
prob_threshold |
float |
Threshold for identifying if a sample belongs to a transition state predicted by the black-box model. If metastable state probability > threshold for two different classes for a specific sample, it's suitable for analysis (Default: 0.48). |
0.48 |
point_max |
int |
If too many suitable samples exist for a specific transition (e.g., transition between metastable state 3 and 8), point_max sets the maximum number of points chosen for further analysis. Points are chosen from a uniform distribution (Default: 50). |
50 |
num_samples |
int |
Size of the perturbed neighborhood (Default: 10000). Ad hoc rule: should be proportional to the square root of the number of features. |
10000 |
cutoff |
int |
Maximum number of features kept for the final round of MDTerp and forward feature selection (use to improve compute time: when too many features are in the dataset and a priori it is known it is unlikely that more features than set by cutoff will be relevant). |
15 |
seed |
int |
Random seed. |
0 |
unfaithfulness_threshold |
float |
Hyperparameter that sets a lower limit on surrogate model unfaithfulness (U). Forward feature selection ends when unfaithfulness drops below this threshold. |
0.01 |
periodicity_upper |
float |
Sets periodicity of the angular features (Default: numpy.pi). |
3.141592653589793 |
periodicity_lower |
float |
Sets periodicity of the angular features (Default: -numpy.pi). |
-3.141592653589793 |
alpha |
float |
L2 norm of Ridge regression (Default: 1.0). |
1.0 |
Returns:
Type | Description |
---|---|
None |
None |
Source code in MDTerp/base.py
def __init__(
    self,
    np_data: np.ndarray,
    model_function_loc: str,
    numeric_dict: dict = {},
    angle_dict: dict = {},
    sin_cos_dict: dict = {},
    save_dir: str = './results/',
    prob_threshold: float = 0.48,
    point_max: int = 50,
    num_samples: int = 10000,
    cutoff: int = 15,
    seed: int = 0,
    unfaithfulness_threshold: float = 0.01,
    periodicity_upper: float = np.pi,
    periodicity_lower: float = -np.pi,
    alpha: float = 1.0,
) -> None:
    """
    Constructor for the MDTerp.base.run class.

    Args:
        np_data (np.ndarray): Black-box training data.
        model_function_loc (str): Location of a human-readable file containing two functions called 'load_model()', and 'run_model()'. 'load_model' must not take any arguments and should return the black-box model. 'run_model' must be a function that takes two arguments: the model and data, and returns metastable state probabilities. The file is executed as Python code, so it must come from a trusted source. Go to https://shams-mehdi.github.io/MDTerp/docs/examples/ for example files.
        numeric_dict (dict): Python dictionary, each key represents the name of a numeric feature (non-periodic). Values should be lists with a single element with the index of the corresponding numpy array in np_data.
        angle_dict (dict): Python dictionary, each key represents the name of an angular feature in [-pi, pi]. Values should be lists with a single element with the index of the corresponding numpy array in np_data.
        sin_cos_dict (dict): Python dictionary, each key represents the name of an angular feature. Values should be lists with two elements using the sine, cosine indices of the corresponding numpy array in np_data.
        save_dir (str): Location to save MDTerp results. A trailing path separator is appended when missing.
        prob_threshold (float): Threshold for identifying if a sample belongs to a transition state predicted by the black-box model. If metastable state probability > threshold for two different classes for a specific sample, it's suitable for analysis (Default: 0.48).
        point_max (int): If too many suitable samples exist for a specific transition (e.g., transition between metastable state 3 and 8), point_max sets the maximum number of points chosen for further analysis. Points are chosen from a uniform distribution (Default: 50).
        num_samples (int): Size of the perturbed neighborhood (Default: 10000). Ad hoc rule: should be proportional to the square root of the number of features.
        cutoff (int): Maximum number of features kept for the final round of MDTerp and forward feature selection (use to improve compute time: when too many features are in the dataset and a priori it is known it is unlikely that more features than set by cutoff will be relevant).
        seed (int): Random seed.
        unfaithfulness_threshold (float): Hyperparameter that sets a lower limit on surrogate model unfaithfulness (U). Forward feature selection ends when unfaithfulness drops below this threshold (Default: 0.01).
        periodicity_upper (float): Sets periodicity of the angular features (Default: numpy.pi).
        periodicity_lower (float): Sets periodicity of the angular features (Default: -numpy.pi).
        alpha (float): L2 norm of Ridge regression (Default: 1.0).

    Returns:
        None

    Raises:
        KeyError: If the model file does not define 'load_model' or 'run_model'.
        ValueError: If no transition is detected in np_data.
    """
    # Work on private copies: the `{}` defaults are shared between calls,
    # so they must never be handed to code that could modify them.
    numeric_dict = dict(numeric_dict)
    angle_dict = dict(angle_dict)
    sin_cos_dict = dict(sin_cos_dict)

    # Every path below is built as `save_dir + name`; guarantee a trailing
    # separator so the files land inside the directory created here.
    save_dir = os.path.join(save_dir, '')

    # Initialization
    os.makedirs(save_dir, exist_ok=True)
    logger = log_maker(save_dir)
    try:
        input_summary(logger, numeric_dict, angle_dict, sin_cos_dict, save_dir, np_data)

        # Load black-box model
        logger.info('Loading blackbox model from file >>> %s', model_function_loc)
        with open(model_function_loc, 'r') as file:
            func_code = file.read()
        # SECURITY: this runs arbitrary Python from the given file; only
        # point it at code you trust.
        # A single namespace (seeded with this module's globals) is used so
        # that functions defined in the file can see the file's own
        # top-level imports and helpers. With separate globals/locals the
        # code runs as if inside a class body and those names are invisible
        # from within 'load_model' / 'run_model'.
        model_ns = dict(globals())
        exec(func_code, model_ns)
        missing = [name for name in ("load_model", "run_model") if name not in model_ns]
        if missing:
            raise KeyError(
                "Model file '{}' does not define: {}".format(model_function_loc, ", ".join(missing))
            )
        run_model = model_ns["run_model"]
        model = model_ns["load_model"]()
        logger.info("Model loaded!")

        # Identify transition states for given/training dataset
        state_probabilities = run_model(model, np_data)
        points = picker_fn(state_probabilities, prob_threshold, point_max)
        num_transitions = len(points)
        logger.info("Number of state transitions detected >>> %s", num_transitions)
        logger.info(
            "Probability threshold, maximum number of points per transition >>> %s, %s",
            prob_threshold, point_max,
        )
        if num_transitions == 0:
            logger.info("No transition detected. Check hyperparamters!")
            raise ValueError("No transition detected. Check hyperparameters!")
        logger.info(100*'-')

        # Loop over all the transitions
        # NOTE(review): results are keyed by sample index only, so a sample
        # picked for more than one transition keeps just its last result.
        importance_master = {}
        for transition in points:
            logger.info("Starting transition >>> %s", transition)
            transition_points = points[transition]
            total = len(transition_points)
            for count, index in enumerate(transition_points, start=1):
                # Round 1: neighborhood over all features (empty selection).
                # The .npy files read back below are presumably written by
                # generate_neighborhood under save_dir/DATA.
                feature_type_indices, indices_names = generate_neighborhood(save_dir, numeric_dict, angle_dict, sin_cos_dict, np_data, index, seed, num_samples, np.array([]), periodicity_upper, periodicity_lower)
                state_probabilities2 = run_model(model, np.load(save_dir + 'DATA/make_prediction.npy'))
                TERP_dat = np.load(save_dir + 'DATA/TERP_dat.npy')
                selected_features = init_model(TERP_dat, state_probabilities2, cutoff, feature_type_indices, seed, alpha)

                # Round 2: neighborhood restricted to the selected features
                # (read back from save_dir/DATA_2).
                generate_neighborhood(save_dir, numeric_dict, angle_dict, sin_cos_dict, np_data, index, seed, num_samples, selected_features, periodicity_upper, periodicity_lower)
                state_probabilities3 = run_model(model, np.load(save_dir + 'DATA_2/make_prediction.npy'))
                TERP_dat = np.load(save_dir + 'DATA_2/TERP_dat.npy')
                importance_0 = final_model(TERP_dat, state_probabilities3, unfaithfulness_threshold, feature_type_indices, selected_features, seed)
                importance = make_result(feature_type_indices, indices_names, importance_0)
                importance_master[index] = [transition, importance]
                logger.info(
                    "Completed generating %s/%s results! First round features kept >>> %s, Final round features kept >>> %s",
                    count, total, len(selected_features), np.nonzero(importance)[0].shape[0],
                )
            logger.info(100*'_')

        # Persist results, then remove the intermediate neighborhood data
        np.save(save_dir + 'MDTerp_feature_names.npy', indices_names)
        with open(save_dir + 'MDTerp_results_all.pkl', 'wb') as f:
            pickle.dump(importance_master, f)
        logger.info("Feature names saved at >>> %s", save_dir + 'MDTerp_feature_names.npy')
        logger.info("All results saved at >>> %s", save_dir + 'MDTerp_results_all.pkl')
        shutil.rmtree(save_dir + 'DATA')
        shutil.rmtree(save_dir + 'DATA_2')
        logger.info("Completed!!!")
    finally:
        # Flush and close logger, also when the run fails part-way, so the
        # handlers are never left open.
        for handler in logger.handlers[:]:
            handler.close()
            logger.removeHandler(handler)