aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lstm_chem/model.py')
-rwxr-xr-xlstm_chem/model.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/lstm_chem/model.py b/lstm_chem/model.py
new file mode 100755
index 0000000..079589a
--- /dev/null
+++ b/lstm_chem/model.py
@@ -0,0 +1,73 @@
+import os
+import time
+from tensorflow.keras import Sequential
+from tensorflow.keras.models import model_from_json
+from tensorflow.keras.layers import LSTM, Dense
+from tensorflow.keras.initializers import RandomNormal
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+
+
+class LSTMChem(object):
+ def __init__(self, config, session='train'):
+ assert session in ['train', 'generate', 'finetune'], \
+ 'one of {train, generate, finetune}'
+
+ self.config = config
+ self.session = session
+ self.model = None
+
+ if self.session == 'train':
+ self.build_model()
+ else:
+ self.model = self.load(self.config.model_arch_filename,
+ self.config.model_weight_filename)
+
+ def build_model(self):
+ st = SmilesTokenizer()
+ n_table = len(st.table)
+ weight_init = RandomNormal(mean=0.0,
+ stddev=0.05,
+ seed=self.config.seed)
+
+ self.model = Sequential()
+ self.model.add(
+ LSTM(units=self.config.units,
+ input_shape=(None, n_table),
+ return_sequences=True,
+ kernel_initializer=weight_init,
+ dropout=0.3))
+ self.model.add(
+ LSTM(units=self.config.units,
+ input_shape=(None, n_table),
+ return_sequences=True,
+ kernel_initializer=weight_init,
+ dropout=0.5))
+ self.model.add(
+ Dense(units=n_table,
+ activation='softmax',
+ kernel_initializer=weight_init))
+
+ arch = self.model.to_json(indent=2)
+ self.config.model_arch_filename = os.path.join(self.config.exp_dir,
+ 'model_arch.json')
+ with open(self.config.model_arch_filename, 'w') as f:
+ f.write(arch)
+
+ self.model.compile(optimizer=self.config.optimizer,
+ loss='categorical_crossentropy')
+
+ def save(self, checkpoint_path):
+ assert self.model, 'You have to build the model first.'
+
+ print('Saving model ...')
+ self.model.save_weights(checkpoint_path)
+ print('model saved.')
+
+ def load(self, model_arch_file, checkpoint_file):
+ print(f'Loading model architecture from {model_arch_file} ...')
+ with open(model_arch_file) as f:
+ model = model_from_json(f.read())
+ print(f'Loading model checkpoint from {checkpoint_file} ...')
+ model.load_weights(checkpoint_file)
+ print('Loaded the Model.')
+ return model