from tqdm import tqdm
import numpy as np
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer


class LSTMChemGenerator(object):
    def __init__(self, modeler):
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.st = SmilesTokenizer()

    def _generate(self, sequence):
        # Extend the sequence one token at a time until the end token 'E' is
        # sampled or the tokenized sequence exceeds the configured maximum length.
        while (sequence[-1] != 'E') and (len(self.st.tokenize(sequence)) <=
                                         self.config.smiles_max_length):
            x = self.st.one_hot_encode(self.st.tokenize(sequence))
            preds = self.model.predict_on_batch(x)[0][-1]
            next_idx = self.sample_with_temp(preds)
            sequence += self.st.table[next_idx]
        # Drop the start token 'G' and any trailing end tokens.
        sequence = sequence[1:].rstrip('E')
        return sequence

    def sample_with_temp(self, preds):
        # Temperature-scaled sampling: rescale the log-probabilities by the
        # sampling temperature, renormalize with a softmax, and draw the next
        # token index from the resulting distribution.
        stretched = np.log(preds) / self.config.sampling_temp
        stretched_probs = np.exp(stretched) / np.sum(np.exp(stretched))
        return np.random.choice(range(len(stretched)), p=stretched_probs)

    def sample(self, num=1, start='G'):
        sampled = []
        if self.session == 'generate':
            # Generation session: return raw sequences without validity checks.
            for _ in tqdm(range(num)):
                sampled.append(self._generate(start))
            return sampled
        else:
            # Other sessions: keep only sequences RDKit can parse and return
            # them as canonical SMILES, sampling until `num` are collected.
            from rdkit import Chem, RDLogger
            RDLogger.DisableLog('rdApp.*')
            while len(sampled) < num:
                sequence = self._generate(start)
                mol = Chem.MolFromSmiles(sequence)
                if mol is not None:
                    canon_smiles = Chem.MolToSmiles(mol)
                    sampled.append(canon_smiles)
            return sampled
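

# Minimal usage sketch (not part of the original module): LSTMChemGenerator only
# needs a modeler object exposing .session, .model, and .config (with
# smiles_max_length and sampling_temp attributes). The LSTMChem class,
# process_config helper, and config path below are assumptions about the
# surrounding project layout and may differ in your copy.
if __name__ == '__main__':
    from lstm_chem.model import LSTMChem            # assumed modeler class
    from lstm_chem.utils.config import process_config  # assumed config loader

    config = process_config('path/to/config.json')  # illustrative path
    modeler = LSTMChem(config, session='generate')
    generator = LSTMChemGenerator(modeler)

    # Sample a handful of sequences starting from the 'G' start token.
    for smiles in generator.sample(num=10):
        print(smiles)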