from tqdm import tqdm
import numpy as np
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer


class LSTMChemGenerator(object):
    """Sample SMILES strings token by token from a trained model.

    ``modeler`` must expose ``session``, ``model`` and ``config``. The
    config is read for ``smiles_max_length`` (generation length cap) and
    ``sampling_temp`` (softmax temperature).
    """

    def __init__(self, modeler):
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.st = SmilesTokenizer()

    def _generate(self, sequence):
        """Extend ``sequence`` until it ends in 'E' or exceeds the length cap.

        Each step one-hot encodes the current tokens, asks the model for the
        next-token distribution and appends one sampled token.

        Args:
            sequence: seed string; its first character is treated as the
                start token and is removed from the result.

        Returns:
            The generated string without its first character and without
            the trailing 'E' end marker (if one was produced).
        """
        # Tokenize once per step; the original re-tokenized twice per loop.
        tokens = self.st.tokenize(sequence)
        while sequence[-1] != 'E' and len(tokens) <= self.config.smiles_max_length:
            x = self.st.one_hot_encode(tokens)
            # NOTE(review): assumes predict_on_batch returns
            # (batch, time, vocab); [0][-1] takes the last time step of the
            # single batch item -- confirm against the model definition.
            preds = self.model.predict_on_batch(x)[0][-1]
            next_idx = self.sample_with_temp(preds)
            sequence += self.st.table[next_idx]
            tokens = self.st.tokenize(sequence)

        return sequence[1:].rstrip('E')

    def sample_with_temp(self, preds):
        """Draw one index from ``preds`` after temperature scaling.

        Samples from softmax(log(preds) / config.sampling_temp). The
        softmax is computed with the max logit subtracted first, so that low
        temperatures do not underflow every weight to zero.

        Args:
            preds: 1-D array-like of non-negative next-token probabilities.

        Returns:
            The sampled index into ``preds``.
        """
        probs_in = np.asarray(preds, dtype=np.float64)
        with np.errstate(divide='ignore'):
            # log(0) -> -inf, which maps to a sampling weight of exactly 0.
            logits = np.log(probs_in) / self.config.sampling_temp
        # Shifting by the max leaves the distribution unchanged but keeps the
        # largest weight at exp(0) == 1. Without it, small temperatures give
        # all-zero weights, a NaN ``p`` and a ValueError from np.random.choice.
        logits -= np.max(logits)
        weights = np.exp(logits)
        probs = weights / weights.sum()
        return np.random.choice(len(probs), p=probs)

    def sample(self, num=1, start='G'):
        """Generate ``num`` sequences seeded with ``start``.

        When ``self.session == 'generate'`` the raw generated strings are
        returned (with a tqdm progress bar). For any other session value,
        only strings that RDKit can parse are kept, and they are returned in
        RDKit's canonical SMILES form.

        NOTE: in the validating mode the loop has no attempt limit; it never
        terminates if the model does not produce ``num`` parseable strings.

        Args:
            num: number of sequences to return.
            start: seed string passed to ``_generate``.

        Returns:
            A list of ``num`` strings.
        """
        if self.session == 'generate':
            return [self._generate(start) for _ in tqdm(range(num))]

        # rdkit is imported only on this branch; the 'generate' path does not
        # touch it.
        from rdkit import Chem, RDLogger
        # Silences all RDKit log output (global side effect) so parse
        # failures of invalid samples are not printed.
        RDLogger.DisableLog('rdApp.*')

        sampled = []
        while len(sampled) < num:
            mol = Chem.MolFromSmiles(self._generate(start))
            if mol is not None:
                sampled.append(Chem.MolToSmiles(mol))
        return sampled