Commit ed2b1834 authored by Martin Kurecka

feat(config): add configuration file

parent a03d528f
{
    "learning_rate": 0.002,
    "ae_loss_fn": "Huber",
    "encoder_top": [
        ["selu", 32],
        ["selu", 16],
        ["selu", 8],
        ["linear", -1]
    ],
    "decoder_top": [
        ["selu", 8],
        ["selu", 16],
        ["selu", 32],
        ["linear", -1]
    ],
    "discriminator_top": [
        ["", 32],
        ["", 16],
        ["", 8],
        ["", 1]
    ]
}
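For orientation, each `*_top` entry is an `[activation, units]` pair describing one dense layer of the corresponding network head. The sketch below shows one plausible way such a list could be consumed when building a Keras stack; the helper `build_top`, the treatment of an empty activation string, and the interpretation of `-1` as "dimension chosen by the caller" are assumptions for illustration, not the repository's actual `_build_encoder` code.

``` python
import json
from keras.layers import Dense
from keras.models import Sequential

# Hypothetical reader for GAN_config.json; the real _build_* methods may differ.
with open("GAN_config.json") as f:
    config = json.load(f)

def build_top(layer_specs, output_dim):
    """Turn [activation, units] pairs into a stack of Dense layers.

    Assumptions: an empty activation string means "no/default activation",
    and units == -1 is a placeholder for a dimension chosen by the caller.
    """
    model = Sequential()
    for activation, units in layer_specs:
        if units == -1:
            units = output_dim
        model.add(Dense(units, activation=activation or None))
    return model

encoder_top = build_top(config["encoder_top"], output_dim=2)  # e.g. a 2D latent space
```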
%% Cell type:markdown id:65938515 tags:
# Analysis and Sampling of Molecular Simulations by Adversarial Autoencoders
---
1. [Packages import](#1.-Packages-import)
2. [Internal coordinates computation](#2.-Internal-coordinates-computation)
3. [Execution & visualization](#3.-Execution-&-visualization)
%% Cell type:markdown id:c37df5b6 tags:
## 1. Packages import
%% Cell type:code id:ae4eb330 tags:
``` python
# Import packages
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
from src import asmsa
import src.asmsa.legacy_mol as asmsa_leg
from src.gan import GAN
from src.visualizer import GAN_visualizer
import mdtraj as md
import numpy as np
import nglview as nv
import urllib.request
import torch
import tensorflow as tf
import tf2onnx
import onnx2torch
```
%% Cell type:code id:3f4aa1b9 tags:
``` python
# download trpcage files
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=19RZmGgz9goXEAbreUd0OBCKWr5UAVN5d","index_correct.ndx")
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1G-4HRnc1R-LqAArFh0DTY-e5ULd8Goly","topol_correct.top")
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1ddOJPQxXCw3jY3Yds6bm-EwGps8ClVGQ","trpcage_correct.pdb")
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1FRM3-bCdbesShVcyRfk0cj0ILBL1ycbb","trpcage_red.xtc")
```
%% Cell type:code id:03cd62eb tags:
``` python
# Define input files
# input conformation
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"
# input trajectory
# atom numbering must be consistent with {conf}
#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"
# input topology
# expected to be produced with
# gmx pdb2gmx -f {conf} -p {topol} -n {index}
# Gromacs changes atom numbering, the index file must be generated and used as well
#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'
```
%% Cell type:markdown id:a209ea91 tags:
## 2. Internal coordinates computation
%% Cell type:code id:70105cec tags:
``` python
tr = md.load(traj,top=conf)
idx=tr[0].top.select("name CA")
#idx=tr[0].top.select("element != H")
tr.superpose(tr[0],atom_indices=idx)
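# reorder tr.xyz (frame, atom, xyz) into geom[atom][xyz][frame], the layout expected by Molecule.intcoord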
geom = np.moveaxis(tr.xyz ,0,-1)
```
%% Cell type:code id:d0f0bc1d tags:
``` python
v = nv.show_mdtraj(tr)
v.clear()
v.add_representation("licorice")
v
```
%% Cell type:code id:5c0c7d9e tags:
``` python
geom.shape
```
%% Cell type:code id:30d9e4bb tags:
``` python
# Define sparse and dense feature extensions of IC
density = 2 # integer in [1, n_atoms-1]
dense_dists = asmsa.mol_model.NBDistancesDense(geom.shape[0])
sparse_dists = asmsa.mol_model.NBDistancesSparse(geom.shape[0], density=density)
# mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])
mol = asmsa.Molecule(pdb=conf,fms=[dense_dists],n_atoms=geom.shape[0])
```
%% Cell type:code id:5077906b tags:
``` python
X_train = mol.intcoord(geom).T
X_train.shape
```
%% Cell type:markdown id:b246e354 tags:
## 3. Set visualizer
%% Cell type:code id:93b00c45 tags:
``` python
visualizer = GAN_visualizer(
    nbins=50,
    visualize_freq=100,
    # analysis_files=['visualization/rama_ala_reduced0.txt',
    #                 'visualization/rama_ala_reduced1.txt',
    #                 'visualization/angever0.txt',
    #                 'visualization/angever1.txt',
    #                 'visualization/angever2.txt'],
    figsize=(15,5),
    cmap='hsv'
)
```
%% Cell type:code id:1d123085 tags:
``` python
# download mushroom image as prior
urllib.request.urlretrieve("https://drive.google.com/uc?export=download&id=1I2WP92MMWS5s5vin_4cvmruuV-1W77Hl", "mushroom_bw.png")
```
%% Cell type:code id:52bf72e4 tags:
``` python
output_file = 'lows.txt'
# gan = GAN(X_train,batch_size=256,prior='mushroom_bw.png')
# test = gan.train(epochs=1, out_file=output_file, visualizer=visualizer)
gan = GAN(X_train,batch_size=256,prior='uniform')
test = gan.train(epochs=500, out_file=output_file, visualizer=visualizer)
```
%% Cell type:code id:bdec4ccc tags:
``` python
visualizer = GAN_visualizer(
    lows='lows.txt',
    nbins=50,
    figsize=(15,5),
)
```
%% Cell type:code id:1a713322 tags:
``` python
visualizer.make_visualization()
```
%% Cell type:code id:b745bdae tags:
``` python
# Rgyr color coded in low dim (rough view)
lows = np.loadtxt(output_file)
rg = md.compute_rg(tr)
cmap = plt.get_cmap('rainbow')
plt.figure(figsize=(12,12))
plt.scatter(lows[:,0],lows[:,1],marker='.',c=rg,cmap=cmap)
plt.colorbar(cmap=cmap)
plt.show()
```
%% Cell type:markdown id:59787ae7 tags:
### Export model to `libtorch`
%% Cell type:code id:84d24126 tags:
``` python
tf.keras.Model.save(gan.encoder, 'model_save')

# Convert TF SavedModel -> ONNX
!python -m tf2onnx.convert --saved-model model_save --output model.onnx

# Convert ONNX -> Torch
onnx_model_path = 'model.onnx'
torch_encoder = onnx2torch.convert(onnx_model_path)

mol_model = mol.get_model()

def complete_model(x):
    return torch_encoder(mol_model(x).reshape(-1))

# Save Torch model using TorchScript trace
example_input = torch.randn([geom.shape[0], geom.shape[1], 1])
traced_script_module = torch.jit.trace(complete_model, example_input)
traced_script_module.save("model.pt")
```
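As a quick sanity check that the exported file is self-contained, the archive can be reloaded from disk with `torch.jit.load` before handing it to `libtorch`. A minimal sketch, reusing `geom` from section 2 and the same input shape that was passed to `torch.jit.trace` above:

``` python
import torch

# Reload the TorchScript archive saved above and evaluate it once on random coordinates.
reloaded = torch.jit.load("model.pt")
dummy = torch.randn([geom.shape[0], geom.shape[1], 1])
low = reloaded(dummy)
print(low.shape)
```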
%% Cell type:markdown id:ef2d34f6 tags:
### Verify the traced PyTorch model
%% Cell type:code id:ec1d2d36 tags:
``` python
example_geom = np.random.rand(geom.shape[0], geom.shape[1], 1)
X = mol.intcoord(example_geom).T
tf_low = np.array(gan.encoder(X))

torch_geom = torch.tensor(example_geom.reshape(-1), dtype=torch.float32, requires_grad=True)
torch_low = traced_script_module(torch_geom)

for out in torch_low:
    grad = torch.autograd.grad(out, torch_geom, retain_graph=True)
    print(grad[0].shape)

assert(np.max(np.abs(tf_low - torch_low.detach().numpy())) < 1e-05)
tf_low - torch_low.detach().numpy()
```
%% Cell type:code id:9d83bf44 tags:
``` python
```
@@ -157,10 +157,10 @@ def _match_type(atom,pattern):
class Molecule:
    def __init__(self,pdb = None,top = None, ndx = None,ff = os.path.dirname(os.path.abspath(__file__)) + '/ffbonded.itp',fms=[]):
    def __init__(self,pdb = None,top = None, ndx = None,ff = os.path.dirname(os.path.abspath(__file__)) + '/ffbonded.itp',fms=[], n_atoms=None):
        if not top and not fms:
            raise ValueError("At least one of `top` or `fms` must be provided")
        if not top and not (fms and n_atoms):
            raise ValueError("At least one of `top` or `fms+n_atoms` must be provided")
        if top:
            if ndx:
@@ -180,15 +180,21 @@ class Molecule:
        self.fms = fms

        self.model = MoleculeModel(
            len(self.atypes),
            bonds=self.bonds,
            angles=self.angles,
            angles_th0=self.angles_th0,
            dihed4=self.dihed4,
            dihed9=self.dihed9,
            feature_maps=self.fms
        )
        if top:
            self.model = MoleculeModel(
                len(self.atypes),
                bonds=self.bonds,
                angles=self.angles,
                angles_th0=self.angles_th0,
                dihed4=self.dihed4,
                dihed9=self.dihed9,
                feature_maps=self.fms
            )
        else:
            self.model = MoleculeModel(
                n_atoms,
                feature_maps=self.fms
            )

    def _match_bonds(self,btypes):
        self.bonds_b0 = np.empty(self.bonds.shape[0],dtype=np.float32)
@@ -299,15 +305,6 @@ class Molecule:
    # geoms[atom][xyz][conf]
    def intcoord(self,geoms):
        if not hasattr(self,'atypes'):
            return np.concatenate([fm.ic(geoms) for fm in self.fms],axis=0)
        if geoms.shape[0] != len(self.atypes):
            raise ValueError(f"Number of atoms ({geoms.shape[0]}) does not match topology ({len(self.atypes)})")
        if geoms.shape[1] != 3:
            raise ValueError(f"3D coordinates expected, {geoms.shape[1]} given")
        return np.array(self.model(torch.tensor(geoms)))
@@ -42,6 +42,8 @@ class AnglesModel(torch.nn.Module):
        dot = torch.sum(v1 * v2, axis=1) / (n1 * n2)
        aa = torch.arccos(dot * 0.999999) # numerical stability of arccos
        # Why such a weird map? The input should be normalized anyway.
        # If this map does not correspond to normalization, then it probably worsens the network performance anyway.
        return (aa - .75 * self.angles_th0[:,None]) * self.angles_2rth0[:,None] # map 0.75 a0 -- 1.25 a0 to 0 -- 1
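If the scaling factor `angles_2rth0` is indeed `2 / angles_th0` (an assumption suggested only by its name), the map in the comment is exact: `0.75*a0` goes to 0 and `1.25*a0` goes to 1. A minimal numeric sketch of that check:

``` python
import numpy as np

th0 = np.array([1.9, 2.1])     # example equilibrium angles a0 (radians), made up for illustration
two_r_th0 = 2.0 / th0          # assumed definition of angles_2rth0

def angle_feature(aa, th0, two_r_th0):
    # same expression as the return statement above
    return (aa - 0.75 * th0) * two_r_th0

print(angle_feature(0.75 * th0, th0, two_r_th0))  # -> [0. 0.]
print(angle_feature(1.25 * th0, th0, two_r_th0))  # -> [1. 1.]
```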
@@ -68,7 +70,14 @@ class DihedralModel(torch.nn.Module):
        sp1 = torch.sum(vp1 * vp2, axis=1)
        sp2 = torch.sum(vp3 * vp2, axis=1)
        return (torch.cat([-sp2, sp1], axis = 0) + 1) * 0.5
        """ original:
        # output for i-th dihedral angle
        aa = np.arctan2(sp1,sp2) - np.pi * .5
        return np.sin(aa), np.cos(aa)
        """
        #NOTE: Why adding two variables that determine each other? Is the angle better?
        return torch.nn.functional.normalize(torch.stack([-sp2, sp1]), p=2, dim=0).reshape(2*len(self.atoms), geoms.shape[2])
        """
@@ -6,7 +6,7 @@ from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.callbacks import CSVLogger, EarlyStopping
from keras.models import Sequential, Model
from keras.losses import BinaryCrossentropy, MeanSquaredError
from keras.losses import BinaryCrossentropy, MeanSquaredError, Huber
from keras import backend as kb
from PIL import Image
@@ -16,6 +16,7 @@ import numpy as np
import datetime
import os
import json
class AAEModel(Model):
@@ -39,7 +40,7 @@ class AAEModel(Model):
    def compile(self,
            opt = Adam(0.0002,0.5), # FIXME: justify
            opt = Adam(0.0002, 0.5), # FIXME: justify
            ae_loss_fn = MeanSquaredError(),
            # XXX: logits as in https://keras.io/guides/customizing_what_happens_in_fit/,
            # hope it works as the discriminator output is never used directly
@@ -150,18 +151,45 @@ class GAN():
        self.latent_dim = 2
        self.prior = prior

        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()
        self.discriminator = self._build_discriminator()
        self.batch_size = batch_size
        with open("GAN_config.json", "r") as f:
            self.config = json.load(f)

        if "encoder_top" in self.config:
            self.encoder = self._build_encoder(params=self.config['encoder_top'])
        else:
            self.encoder = self._build_encoder()

        if "decoder_top" in self.config:
            self.decoder = self._build_decoder(params=self.config['decoder_top'])
        else:
            self.decoder = self._build_decoder()

        if "discriminator_top" in self.config:
            self.discriminator = self._build_discriminator(params=self.config['discriminator_top'])
        else:
            self.discriminator = self._build_discriminator()

        self.batch_size = batch_size
        self._compile(verbose)

    def _compile(self, verbose=False):
        self.aae = AAEModel(self.encoder,self.decoder,self.discriminator,self.latent_dim,self.prior, self.batch_size)
        self.aae.compile()
        kwargs = {}
        if 'learning_rate' in self.config:
            kwargs['opt'] = Adam(self.config['learning_rate'], 0.5)
        if 'ae_loss_fn' in self.config:
            if self.config['ae_loss_fn'] == 'MeanSquaredError':
                kwargs['ae_loss_fn'] = MeanSquaredError()
            elif self.config['ae_loss_fn'] == 'Huber':
                kwargs['ae_loss_fn'] = Huber()
            else:
                raise ValueError(f"Unknown loss function name '{self.config['ae_loss_fn']}'")
        self.aae.compile(**kwargs)

        if verbose:
            print(self.encoder.summary(expand_nested=True))
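In the hunk above, `GAN.__init__` opens `GAN_config.json` unconditionally, so the file has to be present in the working directory when the notebook constructs the `GAN`. A small sketch of generating (or tweaking) it programmatically beforehand; the values simply mirror the configuration file added by this commit:

``` python
import json

# Recreate the configuration shipped with this commit; edit values here to experiment,
# e.g. switch "ae_loss_fn" back to "MeanSquaredError".
config = {
    "learning_rate": 0.002,
    "ae_loss_fn": "Huber",
    "encoder_top": [["selu", 32], ["selu", 16], ["selu", 8], ["linear", -1]],
    "decoder_top": [["selu", 8], ["selu", 16], ["selu", 32], ["linear", -1]],
    "discriminator_top": [["", 32], ["", 16], ["", 8], ["", 1]],
}
with open("GAN_config.json", "w") as f:
    json.dump(config, f, indent=4)
```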