aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/generator.py
diff options
context:
space:
mode:
authorNavan Chauhan <navanchauhan@gmail.com>2020-07-31 22:19:38 +0530
committerNavan Chauhan <navanchauhan@gmail.com>2020-07-31 22:19:38 +0530
commit9a253f896fba757778370c8ad6d40daa3b4cdad0 (patch)
tree2257187cdbc4c2085fb14df8bbb2e6ae6679c3e4 /lstm_chem/generator.py
parent61ce4e7b089d68395be2221f64d89040c0b14a34 (diff)
added Curie-Generate BETA
Diffstat (limited to 'lstm_chem/generator.py')
-rwxr-xr-xlstm_chem/generator.py44
1 files changed, 44 insertions, 0 deletions
diff --git a/lstm_chem/generator.py b/lstm_chem/generator.py
new file mode 100755
index 0000000..498f864
--- /dev/null
+++ b/lstm_chem/generator.py
@@ -0,0 +1,44 @@
+from tqdm import tqdm
+import numpy as np
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+
+
class LSTMChemGenerator(object):
    """Sample SMILES strings from a trained LSTM model, one token at a time.

    The generator borrows ``session``, ``model`` and ``config`` from the
    ``modeler`` object it is built from and uses a ``SmilesTokenizer`` to
    convert between strings, tokens and one-hot encodings.
    """

    def __init__(self, modeler):
        """Keep references to the modeler's session, model and config."""
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.st = SmilesTokenizer()

    def _generate(self, sequence):
        """Extend ``sequence`` until an ``'E'`` is produced or it gets too long.

        Args:
            sequence: Seed string; its first character is treated as the
                start marker and dropped from the result.

        Returns:
            The generated string without its first character and without
            any trailing ``'E'`` characters.
        """
        max_length = self.config.smiles_max_length
        while sequence[-1] != 'E':
            # Tokenize once per step and reuse the result for both the
            # length check and the model input.
            tokens = self.st.tokenize(sequence)
            if len(tokens) > max_length:
                break
            x = self.st.one_hot_encode(tokens)
            # Only the prediction for the last position of the first (and
            # only) batch entry is used to pick the next token.
            preds = self.model.predict_on_batch(x)[0][-1]
            next_idx = self.sample_with_temp(preds)
            sequence += self.st.table[next_idx]

        return sequence[1:].rstrip('E')

    def sample_with_temp(self, preds):
        """Draw one index from ``preds`` rescaled by ``config.sampling_temp``.

        Args:
            preds: 1-D array-like of non-negative weights over the token
                table (presumably the model's softmax output).

        Returns:
            The sampled index into ``preds``.
        """
        # float64 keeps the renormalised probabilities within the tolerance
        # np.random.choice enforces on their sum.
        probs = np.asarray(preds, dtype=np.float64)
        # log(0) == -inf is expected for impossible tokens; it maps back to
        # probability 0 below, so the divide warning is just noise.
        with np.errstate(divide='ignore'):
            stretched = np.log(probs) / self.config.sampling_temp
        # Shift by the maximum so the largest term is exp(0) == 1. Without
        # this, a small temperature makes every exp() underflow to 0 and the
        # normalisation turns into 0/0 (NaN), which np.random.choice rejects.
        stretched = stretched - np.max(stretched)
        weights = np.exp(stretched)
        stretched_probs = weights / np.sum(weights)
        return np.random.choice(len(stretched_probs), p=stretched_probs)

    def sample(self, num=1, start='G'):
        """Generate ``num`` strings, each seeded with ``start``.

        When ``self.session`` is ``'generate'`` the raw generated strings are
        returned as-is. For any other session value, only strings that RDKit
        can parse are kept, and they are returned in RDKit's canonical SMILES
        form; generation continues until ``num`` such strings are collected.

        Args:
            num: Number of strings to return.
            start: Seed passed to every generation run.

        Returns:
            A list of ``num`` strings.
        """
        if self.session == 'generate':
            return [self._generate(start) for _ in tqdm(range(num))]

        # Imported lazily so RDKit is only required when validity filtering
        # is actually requested.
        from rdkit import Chem, RDLogger
        RDLogger.DisableLog('rdApp.*')

        sampled = []
        while len(sampled) < num:
            mol = Chem.MolFromSmiles(self._generate(start))
            if mol is not None:
                sampled.append(Chem.MolToSmiles(mol))
        return sampled