aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/data_loader.py
diff options
context:
space:
mode:
Diffstat (limited to 'lstm_chem/data_loader.py')
-rwxr-xr-xlstm_chem/data_loader.py122
1 files changed, 122 insertions, 0 deletions
diff --git a/lstm_chem/data_loader.py b/lstm_chem/data_loader.py
new file mode 100755
index 0000000..86ddbba
--- /dev/null
+++ b/lstm_chem/data_loader.py
@@ -0,0 +1,122 @@
+import json
+import os
+import numpy as np
+from tqdm import tqdm
+from tensorflow.keras.utils import Sequence
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+
+
+class DataLoader(Sequence):
+ def __init__(self, config, data_type='train'):
+ self.config = config
+ self.data_type = data_type
+ assert self.data_type in ['train', 'valid', 'finetune']
+
+ self.max_len = 0
+
+ if self.data_type == 'train':
+ self.smiles = self._load(self.config.data_filename)
+ elif self.data_type == 'finetune':
+ self.smiles = self._load(self.config.finetune_data_filename)
+ else:
+ pass
+
+ self.st = SmilesTokenizer()
+ self.one_hot_dict = self.st.one_hot_dict
+
+ self.tokenized_smiles = self._tokenize(self.smiles)
+
+ if self.data_type in ['train', 'valid']:
+ self.idx = np.arange(len(self.tokenized_smiles))
+ self.valid_size = int(
+ np.ceil(
+ len(self.tokenized_smiles) * self.config.validation_split))
+ np.random.seed(self.config.seed)
+ np.random.shuffle(self.idx)
+
+ def _set_data(self):
+ if self.data_type == 'train':
+ ret = [
+ self.tokenized_smiles[self.idx[i]]
+ for i in self.idx[self.valid_size:]
+ ]
+ elif self.data_type == 'valid':
+ ret = [
+ self.tokenized_smiles[self.idx[i]]
+ for i in self.idx[:self.valid_size]
+ ]
+ else:
+ ret = self.tokenized_smiles
+ return ret
+
+ def _load(self, data_filename):
+ length = self.config.data_length
+ print('loading SMILES...')
+ with open(data_filename) as f:
+ smiles = [s.rstrip() for s in f]
+ if length != 0:
+ smiles = smiles[:length]
+ print('done.')
+ return smiles
+
+ def _tokenize(self, smiles):
+ assert isinstance(smiles, list)
+ print('tokenizing SMILES...')
+ tokenized_smiles = [self.st.tokenize(smi) for smi in tqdm(smiles)]
+
+ if self.data_type == 'train':
+ for tokenized_smi in tokenized_smiles:
+ length = len(tokenized_smi)
+ if self.max_len < length:
+ self.max_len = length
+ self.config.train_smi_max_len = self.max_len
+ print('done.')
+ return tokenized_smiles
+
+ def __len__(self):
+ target_tokenized_smiles = self._set_data()
+ if self.data_type in ['train', 'valid']:
+ ret = int(
+ np.ceil(
+ len(target_tokenized_smiles) /
+ float(self.config.batch_size)))
+ else:
+ ret = int(
+ np.ceil(
+ len(target_tokenized_smiles) /
+ float(self.config.finetune_batch_size)))
+ return ret
+
+ def __getitem__(self, idx):
+ target_tokenized_smiles = self._set_data()
+ if self.data_type in ['train', 'valid']:
+ data = target_tokenized_smiles[idx *
+ self.config.batch_size:(idx + 1) *
+ self.config.batch_size]
+ else:
+ data = target_tokenized_smiles[idx *
+ self.config.finetune_batch_size:
+ (idx + 1) *
+ self.config.finetune_batch_size]
+ data = self._padding(data)
+
+ self.X, self.y = [], []
+ for tp_smi in data:
+ X = [self.one_hot_dict[symbol] for symbol in tp_smi[:-1]]
+ self.X.append(X)
+ y = [self.one_hot_dict[symbol] for symbol in tp_smi[1:]]
+ self.y.append(y)
+
+ self.X = np.array(self.X, dtype=np.float32)
+ self.y = np.array(self.y, dtype=np.float32)
+
+ return self.X, self.y
+
+ def _pad(self, tokenized_smi):
+ return ['G'] + tokenized_smi + ['E'] + [
+ 'A' for _ in range(self.max_len - len(tokenized_smi))
+ ]
+
+ def _padding(self, data):
+ padded_smiles = [self._pad(t_smi) for t_smi in data]
+ return padded_smiles