Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
7d11158
BUG & MAINT: fix `lr_decay` and set default model units
xiaochendu Sep 20, 2024
a466f39
BUG: chgnet latest version fix import
xiaochendu Sep 20, 2024
aa4d509
Merge pull request #26 from learningmatter-mit/vssr_mc
ajhoffman1229 Sep 24, 2024
427427e
Update stress calculation for MACE
sauradeep93 Nov 21, 2024
d1ee75b
Update ase_calcs.py
sauradeep93 Nov 22, 2024
814e9f7
Updated and tested ase_calcs.py
sauradeep93 Nov 22, 2024
54dded8
Further updated and tested
sauradeep93 Nov 22, 2024
071d692
Generalizing for stress calculation
sauradeep93 Nov 22, 2024
1f4b38d
removed extra print statements
sauradeep93 Nov 22, 2024
32fe995
reverted unintentional changes
sauradeep93 Nov 22, 2024
f2c5054
Update constants.py
sauradeep93 Nov 23, 2024
32840f8
Update ase_calcs.py
sauradeep93 Nov 23, 2024
6aab90d
Merge pull request #27 from sauradeep93/master
ajhoffman1229 Dec 6, 2024
a9cb5ad
autopep
steinmig Dec 16, 2024
607ef69
getting existent tests to work (with cpu)
steinmig Dec 16, 2024
d66dff0
move tests to designated folder
steinmig Dec 16, 2024
ea48cb3
some regression tests based on tutorials
steinmig Dec 16, 2024
4bafcbf
simple CI
steinmig Dec 16, 2024
9bdf505
avoid execution of broken code by pytest?
steinmig Dec 17, 2024
e4a3238
Merge branch 'unittests' into formatting
steinmig Dec 17, 2024
5e03767
autoflake and ruff formatting
steinmig Dec 17, 2024
e3fc9df
autopep
steinmig Dec 17, 2024
88e0450
ruff format
steinmig Dec 17, 2024
c0201a4
more extreme autoflake
steinmig Dec 17, 2024
dab2882
ruff format
steinmig Dec 17, 2024
3585265
ignore test files
steinmig Dec 17, 2024
322f4b4
ruff safe fixes
steinmig Dec 17, 2024
c9b5540
manual fixes of analysis, data and io
steinmig Dec 17, 2024
6f9ccb9
update lint config
steinmig Dec 17, 2024
05452f9
remove print
steinmig Dec 17, 2024
e567529
disable excited states test for now
steinmig Dec 17, 2024
38bd679
merge ignore improvements
steinmig Dec 17, 2024
a90b3b5
some minor fixes
steinmig Dec 20, 2024
bbc7278
ruff formatting
steinmig Dec 20, 2024
2ecb6cf
ordering
steinmig Dec 20, 2024
4bfeace
more linting
steinmig Dec 23, 2024
1bcdc1d
generally allow no type annotation on properties and remove noqa
steinmig Dec 23, 2024
81a6aea
more linting
steinmig Dec 23, 2024
1552579
allow characters in comments / docs
steinmig Dec 23, 2024
f541ab6
formatting
steinmig Dec 23, 2024
fb3350f
more linting
steinmig Dec 23, 2024
f7e5532
typo
steinmig Jan 29, 2025
4c98f4b
reduce comments in tests
steinmig Jan 29, 2025
2a08208
limit to torch 2.5
steinmig Jan 29, 2025
41c6c65
correct range syntax
steinmig Jan 29, 2025
2f9d73b
merge
steinmig Jan 29, 2025
81f31c4
overlooked merge
steinmig Jan 29, 2025
84bcc1f
sort
steinmig Jan 29, 2025
abe4f5d
remove self assignment
steinmig Jan 29, 2025
f309fb6
Merge pull request #28 from steinmig/unittests
ajhoffman1229 Jan 29, 2025
ab2940b
flake8 fixes
steinmig Jan 29, 2025
0dfde00
DOC & STY: update CV and NVE files
ajh1229 Feb 3, 2025
167daad
ENH: add tqdm for range for loop
ajh1229 Feb 3, 2025
4505a1f
Enforce numpy 1
steinmig Feb 5, 2025
8842f47
Update pyproject.toml
steinmig Feb 5, 2025
dbbcb0a
Update nff/io/cprop.py
steinmig Feb 6, 2025
9a6b2e3
fix e741
steinmig Feb 6, 2025
fc5eb5b
Update nff/nn/models/schnet.py
steinmig Feb 6, 2025
d67990e
Update nff/reactive_tools/nms.py
steinmig Feb 6, 2025
bca4458
bring back star import
steinmig Feb 6, 2025
d9f9dfd
overwrite from master
steinmig Feb 6, 2025
6661710
Merge remote-tracking branch 'origin/master' into formatting
steinmig Feb 6, 2025
94f9788
pass ruff checks
steinmig Feb 6, 2025
29b76e0
Merge pull request #30 from steinmig/formatting
ajhoffman1229 Feb 6, 2025
18f8fed
Merge pull request #31 from steinmig/dependencies-hotfix
ajhoffman1229 Feb 6, 2025
d0f3cd9
ENH: add catch for ASE version to run dynamics
ajh1229 Apr 19, 2025
679e150
TST: add test for ASE version dynamics run
ajh1229 Apr 19, 2025
511bf37
TST: update model path
ajh1229 Apr 19, 2025
7dc6788
TST: update ethanol model path with pathlib
ajh1229 Apr 19, 2025
4725cc2
TST: skip Langevin test for now
ajh1229 Apr 19, 2025
4e907a6
DOC: remove htvs references
ajhoffman1229 Apr 21, 2025
eb7d0b0
Merge pull request #33 from learningmatter-mit/ase_dyn_update
ajhoffman1229 Apr 21, 2025
df42616
Merge branch 'master' into vssr_pourbaix
xiaochendu Apr 23, 2025
5fe46a2
Apply suggestions from code review
xiaochendu Apr 23, 2025
31fbcfc
MAINT & STY: update based on PR #34 comments and style fixes
xiaochendu Apr 23, 2025
95937c9
MAINT: update shuffle default in `convert_chgnet_structure_data_to_nf…
xiaochendu May 1, 2025
08d142f
Merge pull request #34 from learningmatter-mit/vssr_pourbaix
HojeChun May 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: Test NeuralForceField package

on: [push]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
# python-version: ["pypy3.10", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install basics
run: python -m pip install --upgrade pip setuptools wheel
- name: Install package
run: python -m pip install .
# - name: Install linters
# run: python -m pip install flake8 mypy pylint
# - name: Install documentation requirements
# run: python -m pip install -r docs/requirements.txt
# - name: Test with flake8
# run: flake8 polymethod
# - name: Test with mypy
# run: mypy polymethod
# - name: Test with pylint
# run: pylint polymethod
- name: Test with pytest
run: |
pip install pytest pytest-cov
pytest nff/tests --doctest-modules --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=nff --cov-report=xml --cov-report=html
- name: Upload pytest test results
uses: actions/upload-artifact@v4
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
if: ${{ always() }}
# - name: Test documentation
# run: sphinx-build docs/source docs/build
12 changes: 12 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,17 @@ dist/
sandbox_excited/
build/

# Editor files
# vim
*.swp
*.swo

# pycharm
.idea/

# coverage and tests
junit
.coverage

# required exceptions
!tutorials/models/ammonia/Ammonia.xyz
77 changes: 37 additions & 40 deletions nff/analysis/attribution.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from typing import Dict, List, Optional, Union

import numpy as np
import torch
from ase.io import Trajectory, write
from ase import Atoms
import numpy as np
from ase.io import Trajectory, write
from tqdm import tqdm

from nff.io.ase_calcs import EnsembleNFF
from nff.io.ase import AtomsBatch
from nff.utils.scatter import compute_grad
from nff.io.ase_calcs import EnsembleNFF
from nff.utils.cuda import batch_to
from typing import Union

from tqdm import tqdm
from nff.utils.scatter import compute_grad


def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond", **kwargs) -> list[np.array]:
def get_molecules(
atom: AtomsBatch, bond_length: Optional[Dict[str, float]] = None, mode: str = "bond", **kwargs
) -> List[np.array]:
"""
find molecules in periodic or non-periodic system. bond mode finds molecules within bond length.
Must pass bond_length dict: e.g bond_length=dict()
Expand All @@ -29,7 +31,8 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
give extra cutoff = 6 e.g input

output:
list of array of atom indices in molecules. e.g: if there is a H2O molecule, you will get a list with the atom indices
list of array of atom indices in molecules. e.g: if there is a H2O molecule,
you will get a list with the atom indices

"""
types = list(set(atom.numbers))
Expand All @@ -50,15 +53,18 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
oxy_neighbors = []
if mode == "bond":
for t in types:
if bond_length.get("%s-%s" % (ty, t)) != None:
if bond_length.get(f"{ty}-{t}") is not None:
oxy_neighbors.extend(
list(
np.where(atom.numbers == t)[0][
np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length["%s-%s" % (ty, t)])[0]
np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length[f"{ty}-{t}"])[0]
]
)
)
elif mode == "cutoff":
if "cutoff" not in kwargs:
raise ValueError("Specifying mode 'cutoff' requires passing a cutoff value as a keyword argument")
cutoff = kwargs["cutoff"]
oxy_neighbors.extend(list(np.where(dis_sq[i] <= cutoff)[0])) # cutoff input extra argument
oxy_neighbors = np.array(oxy_neighbors)
if len(oxy_neighbors) == 0:
Expand All @@ -69,10 +75,10 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
elif (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0:
clusters[oxy_neighbors] = mm + 1
clusters[i] = mm + 1
elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] == 0:
elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0:
clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] != 0:
elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] != 0:
tmp = clusters[oxy_neighbors][clusters[oxy_neighbors] != 0][
clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]
!= min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
Expand All @@ -91,17 +97,17 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
return molecules


def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: int = None):
def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: List[np.array], centre: Optional[int] = None):
"""
Function to shift atoms when we create non-periodic system from periodic.
inputs:
atomsobject: Atomsbatch object from NFF
mol_idx: list of array of atom indices in molecules or atoms you want to keep together when changing to non-periodic
system
centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close to the centre which
is by default the first atom index in the array. For reconstructing molecules this is fine. However, for attribution,
we may have to shift a whole molecule to come closer to the atoms with high attribution. In that case, we manually assign
the atom index.
centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close
to the centre which is by default the first atom index in the array. For reconstructing molecules this is fine.
However, for attribution, we may have to shift a whole molecule to come closer to the atoms with high attribution.
In that case, we manually assign the atom index.
"""

sys_xyz = torch.Tensor(atomsobject.get_positions(wrap=True))
Expand All @@ -111,38 +117,34 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre:
mol_xyz = sys_xyz[idx]
if any(atomsobject.pbc):
center = mol_xyz.shape[0] // 2
if centre != None:
if centre is not None:
center = centre # changes the central atom to atom in focus
intra_dmat = (mol_xyz[None, :, ...] - mol_xyz[:, None, ...])[center]
if np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) != 0:
M, N = intra_dmat.shape[0], intra_dmat.shape[1]
M, _ = intra_dmat.shape[0], intra_dmat.shape[1]
f = torch.linalg.solve(torch.Tensor(atomsobject.cell.T), (intra_dmat.view(-1, 3).T)).T
g = f - torch.floor(f + 0.5)
intra_dmat = torch.matmul(g, torch.Tensor(atomsobject.cell))
intra_dmat = intra_dmat.view(M, 3)
offsets = -torch.floor(f + 0.5).view(M, 3)
traj_unwrap = mol_xyz + torch.matmul(offsets, torch.Tensor(atomsobject.cell))
else:
sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len
add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len
(intra_dmat > 0.5 * box_len).to(torch.float) * box_len
(intra_dmat <= -0.5 * box_len).to(torch.float) * box_len
shift = torch.round(torch.divide(intra_dmat, box_len))
offsets = -shift
traj_unwrap = mol_xyz + offsets * box_len
else:
traj_unwrap = mol_xyz
# traj_unwrap=mol_xyz+add-sub
sys_xyz[idx] = traj_unwrap

new_pos = sys_xyz.numpy()

return new_pos


# -


class Attribution:
def __init__(self, ensemble: EnsembleNFF, save_file: str = None):
def __init__(self, ensemble: EnsembleNFF, save_file: Optional[str] = None):
self.ensemble = ensemble
self.save_file = save_file

Expand Down Expand Up @@ -197,17 +199,15 @@ def calc_attribution_file(
step: int = 1,
progress_bar: bool = True,
to_chemiscope: bool = False,
bond_length: dict = None,
bond_length: Optional[dict] = None,
) -> list:
attributions = []
atoms_list = []
energies = []
energy_stds = []
grads = []
grad_stds = []
with tqdm(
range(skip, len(traj), step), disable=True if progress_bar == False else False
) as pbar: # , postfix={"fbest":"?",}) as pbar:
with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar:
# for i in range(skip,len(traj),step):
for i in pbar:
# create atoms batch object
Expand Down Expand Up @@ -269,8 +269,7 @@ def calc_attribution_file(
},
}
return atoms_list, properties
else:
return attributions
return attributions

def activelearning(
self,
Expand All @@ -281,12 +280,10 @@ def activelearning(
skip: int = 0,
step: int = 1,
progress_bar: bool = True,
bond_length: dict = None,
bond_length: Optional[dict] = None,
):
atom_list = []
with tqdm(
range(skip, len(traj), step), disable=True if progress_bar == False else False
) as pbar: # , postfix={"fbest":"?",}) as pbar:
with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar:
# for i in range(skip,len(traj),step):
for i in pbar:
# create atoms batch object
Expand Down Expand Up @@ -337,15 +334,15 @@ def activelearning(
neighs = np.append(neighs, a)
for n in neighs:
atomstocare = np.append(atomstocare, molecules[np.where(balanced_mols == n)[0][0]])
atomstocare = np.array((list(set(atomstocare))))
atomstocare = np.array(list(set(atomstocare)))
atomstocare = np.int64(atomstocare)
atoms1 = atoms[atomstocare]
index = np.where(atoms1.positions == atoms.positions[a])[0][0]
xyz = reconstruct_atoms(atoms1, [np.arange(0, len(atoms1))], centre=index)
atoms1.positions = xyz
is_repeated = False
for Atoms in atom_list:
if atoms1.__eq__(Atoms):
for at in atom_list:
if atoms1 == at:
is_repeated = True
break
if not is_repeated:
Expand Down
Loading