generated from Quantum-Accelerators/template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
172 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,4 +35,3 @@ def sella_wrapper( | |
) | ||
if traj_file: | ||
traj.close() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
mace: | ||
model_path: 'MACE_model.model' | ||
model_path: 'tests/MACE_model_cpu.model' | ||
newtonnet: | ||
model_path: 'tests/best_model_state.tar' | ||
config_path: 'tests/config0.yml' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
general: | ||
me: /global/home/users/ericyuan/20230310_Transition1x/config0.yml # path to this file | ||
device: [cuda:0, cuda:1] # cpu / cuda:0 / [cuda:0, cuda:1, cuda:2, cuda:3] / list of cuda | ||
driver: /global/home/users/ericyuan/NewtonNet/cli/newtonnet_train # path to the run script | ||
output: [/global/home/users/ericyuan/20230310_Transition1x/output0, 1] # path and iterator for the output directory | ||
|
||
data: | ||
train_path: /global/scratch/users/ericyuan/Transtion1x/train_data.npz # path to the training data | ||
val_path: /global/scratch/users/ericyuan/Transtion1x/val_data.npz # path to the validation data | ||
test_path: /global/scratch/users/ericyuan/Transtion1x/test_data.npz # path to the test data | ||
train_size: -1 # -1 for all | ||
test_size: 1 | ||
val_size: 1000 | ||
cutoff: 5.0 # cutoff radius | ||
random_states: 90 # random seed for data splitting | ||
|
||
model: | ||
pre_trained: /global/home/users/ericyuan/20230310_Transition1x/0.1-iv+ln-cont/training_1/models/best_model_state.tar # path to the previously trained model for warm-up start | ||
activation: swish # activation function: swish, ssp, relu, ... | ||
requires_dr: True # if derivative of the output is required | ||
w_energy: 1.0 # the weight of energy loss in the loss function | ||
w_force: 20.0 # EDITED from 100 # the weight of force loss in the loss function | ||
wf_decay: 0.0 # rate of exponential decay of force wight by training epoch | ||
w_f_mag: 0.0 # the weight of force magnitude loss in the loss function | ||
lambda_l1: 0.0 # the coefficient of L1 regularization | ||
w_f_dir: 1.0 # the weight of force direction loss in the loss function | ||
resolution: 20 # number of basis functions that describe interatomic distances | ||
n_features: 128 # number of features | ||
max_z: 10 # maximum atomic number in the chemical systems | ||
n_interactions: 3 # number of interaction blocks of newtonnet | ||
cutoff_network: poly # the cutoff function: poly (polynomial), cosine | ||
normalize_atomic: True # EDITED from false # if True the atomic energy needs to be inverse normalized, otherwise total energy will be scaled back | ||
shared_interactions: False # if True parameters of interaction blocks will be shared. | ||
normalize_filter: False # | ||
return_latent: False # EDITED from true # if True, the latent space will be returned for the future investigation | ||
double_update_latent: True # EDITED from false | ||
layer_norm: True # EDITED from false normalize hidden layer with a 1D layer_norm function | ||
|
||
training: | ||
epochs: 1000 # number of times the entire training data will be shown to the model | ||
tr_batch_size: 100 # number of training points (snapshots) in a batch of data that is feed to the model | ||
val_batch_size: 100 # number of validation points (snapshots) in a batch of data that is feed to the model | ||
tr_rotations: 0 # number of times the training data needs to be randomly rotated (redundant for NewtonNet model) | ||
val_rotations: 0 # number of times the validation data needs to be randomly rotated (redundant for NewtonNet model) | ||
tr_frz_rot: False # if True, fixed rotations matrix will be used at each epoch | ||
val_frz_rot: False # | ||
tr_keep_original: True # if True, the original orientation of data will be preserved as part of training set (beside other rotations) | ||
val_keep_original: True # | ||
shuffle: True # shuffle training data before each epoch | ||
drop_last: True # if True, drop the left over data points that are less than a full batch size | ||
lr: 1.0e-3 # learning rate | ||
lr_scheduler: [plateau, 15, 30, 0.7, 1.0e-6] # the learning rate decay based on the plateau algorithm: n_epoch_averaging, patience, decay_rate, stop_lr | ||
# lr_scheduler: [decay, 0.05] # the learning rate decay based on exponential decay: the rate of decay | ||
weight_decay: 0 # the l2 norm | ||
dropout: 0.0 # dropout between 0 and 1 | ||
|
||
hooks: | ||
vismolvector3d: False # if the latent force vectors need to be visualized (only works when the return_latent is on) | ||
|
||
checkpoint: | ||
log: 1 # log the results every this many epochs | ||
val: 1 # evaluate the performance on the validation set every this many epochs | ||
test: 1 # evaluate the performance on the test set every this many epochs | ||
model: 10 # save the model every this many epochs | ||
verbose: False # verbosity of the logging | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import pytest | ||
from calcs import load_config, calc_mace, calc | ||
from mace.calculators import MACECalculator | ||
from newtonnet.utils.ase_interface import MLAseCalculator | ||
|
||
|
||
def test_load_config(): | ||
config = load_config('tests/config.yml') | ||
|
||
assert isinstance(config, dict) | ||
|
||
assert 'mace' in config or 'newtonnet' in config, "Config should have either 'mace' or 'newtonnet' key" | ||
|
||
if 'mace' in config: | ||
assert 'model_path' in config['mace'], "'mace' should have a 'model_path' key" | ||
|
||
if 'newtonnet' in config: | ||
assert 'model_path' in config['newtonnet'], "'newtonnet' should have a 'model_path' key" | ||
assert 'config_path' in config['newtonnet'], "'newtonnet' should have a 'config_path' key" | ||
|
||
|
||
def test_calc_mace(): | ||
ml_calculator = calc_mace() | ||
assert isinstance(ml_calculator, MACECalculator) | ||
|
||
|
||
def test_calc(): | ||
ml_calculator = calc() | ||
assert isinstance(ml_calculator, MLAseCalculator) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import pytest | ||
import shutil | ||
from ase import Atoms | ||
from ase.io import write | ||
from setup_images import setup_images | ||
|
||
|
||
@pytest.fixture | ||
def setup_test_environment(tmp_path): | ||
# Create temporary directory | ||
logdir = tmp_path / "log" | ||
os.makedirs(logdir, exist_ok=True) | ||
|
||
# Create a mock XYZ file with reactant and product structures | ||
xyz_r_p = tmp_path / "r_p.xyz" | ||
|
||
reactant = Atoms( | ||
symbols='CCHHCHH', | ||
positions=[ | ||
[1.4835950817281542, -1.0145410211301968, -0.13209027203235943], | ||
[0.8409564131524673, 0.018549610257914483, -0.07338809662321308], | ||
[-0.6399757891931867, 0.01763740851518944, 0.0581573443268891], | ||
[-1.0005576455546672, 1.0430257532387608, 0.22197240310602892], | ||
[1.402180736662139, 0.944112416574632, -0.12179540364365492], | ||
[-1.1216961389434357, -0.3883639833876232, -0.8769102842015071], | ||
[-0.9645026578514683, -0.6204201840686793, 0.9240543090678239] | ||
] | ||
) | ||
|
||
product = Atoms( | ||
symbols='CCHHCHH', | ||
positions=[ | ||
[1.348003553501624, 0.4819311116778978, 0.2752537177143993], | ||
[0.2386618286631742, -0.3433222966734429, 0.37705518940917926], | ||
[-0.9741307940518336, 0.07686022294949588, 0.08710778043683955], | ||
[-1.8314843503320921, -0.5547344604780035, 0.1639037492534953], | ||
[0.3801391040059668, -1.3793340533058087, 0.71035902765307], | ||
[1.9296265384257907, 0.622088341468767, 1.0901733942191298], | ||
[-1.090815880212625, 1.0965111343610956, -0.23791518420660265] | ||
] | ||
) | ||
|
||
write(xyz_r_p, [reactant, product]) | ||
|
||
return logdir, xyz_r_p | ||
|
||
|
||
def test_setup_images(setup_test_environment): | ||
logdir, xyz_r_p = setup_test_environment | ||
|
||
# Call the setup_images function | ||
images = setup_images(logdir=str(logdir), xyz_r_p=str(xyz_r_p), n_intermediate=2) | ||
|
||
# Check that images were returned | ||
assert len(images) > 0, "No images were generated" | ||
|
||
# Verify output files were created | ||
assert os.path.isfile(logdir / 'reactant_opt.traj'), "Reactant optimization file not found" | ||
assert os.path.isfile(logdir / 'product_opt.traj'), "Product optimization file not found" | ||
assert os.path.isfile(logdir / 'r_p.xyz'), "Reactant-Product file not found" | ||
assert os.path.isfile(logdir / 'output.xyz'), "Intermediate images file not found" | ||
assert os.path.isfile(logdir / 'geodesic_path.xyz'), "Geodesic path file not found" | ||
|
||
# Check energies and forces | ||
for image in images: | ||
assert 'energy' in image.info, "Energy not found in image info" | ||
assert 'forces' in image.arrays, "Forces not found in image arrays" | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main() |