diff options
Diffstat (limited to 'lstm_chem/data_loader.py')
-rwxr-xr-x  lstm_chem/data_loader.py  122
1 files changed, 122 insertions, 0 deletions
diff --git a/lstm_chem/data_loader.py b/lstm_chem/data_loader.py
new file mode 100755
index 0000000..86ddbba
--- /dev/null
+++ b/lstm_chem/data_loader.py
@@ -0,0 +1,122 @@
+import json
+import os
+import numpy as np
+from tqdm import tqdm
+from tensorflow.keras.utils import Sequence
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+
+
+class DataLoader(Sequence):
+    def __init__(self, config, data_type='train'):
+        self.config = config
+        self.data_type = data_type
+        assert self.data_type in ['train', 'valid', 'finetune']
+
+        self.max_len = 0
+
+        if self.data_type == 'train':
+            self.smiles = self._load(self.config.data_filename)
+        elif self.data_type == 'finetune':
+            self.smiles = self._load(self.config.finetune_data_filename)
+        else:
+            pass
+
+        self.st = SmilesTokenizer()
+        self.one_hot_dict = self.st.one_hot_dict
+
+        self.tokenized_smiles = self._tokenize(self.smiles)
+
+        if self.data_type in ['train', 'valid']:
+            self.idx = np.arange(len(self.tokenized_smiles))
+            self.valid_size = int(
+                np.ceil(
+                    len(self.tokenized_smiles) * self.config.validation_split))
+            np.random.seed(self.config.seed)
+            np.random.shuffle(self.idx)
+
+    def _set_data(self):
+        if self.data_type == 'train':
+            ret = [
+                self.tokenized_smiles[self.idx[i]]
+                for i in self.idx[self.valid_size:]
+            ]
+        elif self.data_type == 'valid':
+            ret = [
+                self.tokenized_smiles[self.idx[i]]
+                for i in self.idx[:self.valid_size]
+            ]
+        else:
+            ret = self.tokenized_smiles
+        return ret
+
+    def _load(self, data_filename):
+        length = self.config.data_length
+        print('loading SMILES...')
+        with open(data_filename) as f:
+            smiles = [s.rstrip() for s in f]
+        if length != 0:
+            smiles = smiles[:length]
+        print('done.')
+        return smiles
+
+    def _tokenize(self, smiles):
+        assert isinstance(smiles, list)
+        print('tokenizing SMILES...')
+        tokenized_smiles = [self.st.tokenize(smi) for smi in tqdm(smiles)]
+
+        if self.data_type == 'train':
+            for tokenized_smi in tokenized_smiles:
+                length = len(tokenized_smi)
+                if self.max_len < length:
+                    self.max_len = length
+            self.config.train_smi_max_len = self.max_len
+        print('done.')
+        return tokenized_smiles
+
+    def __len__(self):
+        target_tokenized_smiles = self._set_data()
+        if self.data_type in ['train', 'valid']:
+            ret = int(
+                np.ceil(
+                    len(target_tokenized_smiles) /
+                    float(self.config.batch_size)))
+        else:
+            ret = int(
+                np.ceil(
+                    len(target_tokenized_smiles) /
+                    float(self.config.finetune_batch_size)))
+        return ret
+
+    def __getitem__(self, idx):
+        target_tokenized_smiles = self._set_data()
+        if self.data_type in ['train', 'valid']:
+            data = target_tokenized_smiles[idx *
+                                           self.config.batch_size:(idx + 1) *
+                                           self.config.batch_size]
+        else:
+            data = target_tokenized_smiles[idx *
+                                           self.config.finetune_batch_size:
+                                           (idx + 1) *
+                                           self.config.finetune_batch_size]
+        data = self._padding(data)
+
+        self.X, self.y = [], []
+        for tp_smi in data:
+            X = [self.one_hot_dict[symbol] for symbol in tp_smi[:-1]]
+            self.X.append(X)
+            y = [self.one_hot_dict[symbol] for symbol in tp_smi[1:]]
+            self.y.append(y)
+
+        self.X = np.array(self.X, dtype=np.float32)
+        self.y = np.array(self.y, dtype=np.float32)
+
+        return self.X, self.y
+
+    def _pad(self, tokenized_smi):
+        return ['G'] + tokenized_smi + ['E'] + [
+            'A' for _ in range(self.max_len - len(tokenized_smi))
+        ]
+
+    def _padding(self, data):
+        padded_smiles = [self._pad(t_smi) for t_smi in data]
+        return padded_smiles
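Usage note: DataLoader subclasses tensorflow.keras.utils.Sequence, so a trained-up Keras model can consume it directly via model.fit(). The sketch below is illustrative and not part of the commit; the project presumably builds `config` from its own experiment configuration, so a types.SimpleNamespace carrying the attributes the class actually reads (data_filename, data_length, validation_split, seed, batch_size) stands in here, and the dataset path is hypothetical. As written, constructing the class with data_type='valid' directly would fail, because self.smiles is only loaded for 'train' and 'finetune'.

    from types import SimpleNamespace
    from lstm_chem.data_loader import DataLoader

    # Stand-in config object; attribute names match what DataLoader reads.
    config = SimpleNamespace(
        data_filename='datasets/all_smiles.smi',  # hypothetical path
        data_length=0,        # 0 means "use every SMILES string in the file"
        validation_split=0.1,
        seed=71,
        batch_size=256,
    )

    train_loader = DataLoader(config, data_type='train')
    X, y = train_loader[0]
    # X[i] is a padded, one-hot encoded SMILES sequence starting with 'G';
    # y[i] is the same sequence shifted by one token, so the model is trained
    # on next-character prediction. A Keras model can then be fitted with
    # model.fit(train_loader, ...).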