diff options
Diffstat (limited to 'lstm_chem/generator.py')
-rwxr-xr-x | lstm_chem/generator.py | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/lstm_chem/generator.py b/lstm_chem/generator.py new file mode 100755 index 0000000..498f864 --- /dev/null +++ b/lstm_chem/generator.py @@ -0,0 +1,44 @@ +from tqdm import tqdm +import numpy as np +from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer + + +class LSTMChemGenerator(object): + def __init__(self, modeler): + self.session = modeler.session + self.model = modeler.model + self.config = modeler.config + self.st = SmilesTokenizer() + + def _generate(self, sequence): + while (sequence[-1] != 'E') and (len(self.st.tokenize(sequence)) <= + self.config.smiles_max_length): + x = self.st.one_hot_encode(self.st.tokenize(sequence)) + preds = self.model.predict_on_batch(x)[0][-1] + next_idx = self.sample_with_temp(preds) + sequence += self.st.table[next_idx] + + sequence = sequence[1:].rstrip('E') + return sequence + + def sample_with_temp(self, preds): + streched = np.log(preds) / self.config.sampling_temp + streched_probs = np.exp(streched) / np.sum(np.exp(streched)) + return np.random.choice(range(len(streched)), p=streched_probs) + + def sample(self, num=1, start='G'): + sampled = [] + if self.session == 'generate': + for _ in tqdm(range(num)): + sampled.append(self._generate(start)) + return sampled + else: + from rdkit import Chem, RDLogger + RDLogger.DisableLog('rdApp.*') + while len(sampled) < num: + sequence = self._generate(start) + mol = Chem.MolFromSmiles(sequence) + if mol is not None: + canon_smiles = Chem.MolToSmiles(mol) + sampled.append(canon_smiles) + return sampled |