aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/trainer.py
diff options
context:
space:
mode:
Diffstat (limited to 'lstm_chem/trainer.py')
-rw-r--r--lstm_chem/trainer.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/lstm_chem/trainer.py b/lstm_chem/trainer.py
new file mode 100644
index 0000000..0ed608a
--- /dev/null
+++ b/lstm_chem/trainer.py
@@ -0,0 +1,57 @@
+from glob import glob
+import os
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+
+
+class LSTMChemTrainer(object):
+ def __init__(self, modeler, train_data_loader, valid_data_loader):
+ self.model = modeler.model
+ self.config = modeler.config
+ self.train_data_loader = train_data_loader
+ self.valid_data_loader = valid_data_loader
+ self.callbacks = []
+ self.init_callbacks()
+
+ def init_callbacks(self):
+ self.callbacks.append(
+ ModelCheckpoint(
+ filepath=os.path.join(
+ self.config.checkpoint_dir,
+ '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
+ self.config.exp_name),
+ monitor=self.config.checkpoint_monitor,
+ mode=self.config.checkpoint_mode,
+ save_best_only=self.config.checkpoint_save_best_only,
+ save_weights_only=self.config.checkpoint_save_weights_only,
+ verbose=self.config.checkpoint_verbose,
+ ))
+ self.callbacks.append(
+ TensorBoard(
+ log_dir=self.config.tensorboard_log_dir,
+ write_graph=self.config.tensorboard_write_graph,
+ ))
+
+ def train(self):
+# history = self.model.fit_generator(
+ history = self.model.fit(
+ self.train_data_loader,
+ steps_per_epoch=self.train_data_loader.__len__(),
+ epochs=self.config.num_epochs,
+ verbose=self.config.verbose_training,
+ validation_data=self.valid_data_loader,
+ validation_steps=self.valid_data_loader.__len__(),
+ use_multiprocessing=True,
+ shuffle=True,
+ callbacks=self.callbacks)
+
+ last_weight_file = glob(
+ os.path.join(
+ f'{self.config.checkpoint_dir}',
+ f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5')
+ )[0]
+
+ assert os.path.exists(last_weight_file)
+ self.config.model_weight_filename = last_weight_file
+
+ with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f:
+ f.write(self.config.toJSON(indent=2))