aboutsummaryrefslogtreecommitdiff
path: root/app/lstm_chem/trainer.py
diff options
context:
space:
mode:
authorNavan Chauhan <navanchauhan@gmail.com>2020-07-31 21:23:03 +0530
committerNavan Chauhan <navanchauhan@gmail.com>2020-07-31 21:23:03 +0530
commit61ce4e7b089d68395be2221f64d89040c0b14a34 (patch)
tree6bd60c1fd4f8bcd1c503914c61272ed382fff1da /app/lstm_chem/trainer.py
parent376f04d1df2692a8874de686372d18b9ab07950a (diff)
added AI model
Diffstat (limited to 'app/lstm_chem/trainer.py')
-rwxr-xr-xapp/lstm_chem/trainer.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/app/lstm_chem/trainer.py b/app/lstm_chem/trainer.py
new file mode 100755
index 0000000..4e8057e
--- /dev/null
+++ b/app/lstm_chem/trainer.py
@@ -0,0 +1,56 @@
+from glob import glob
+import os
+from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
+
+
+class LSTMChemTrainer(object):
+ def __init__(self, modeler, train_data_loader, valid_data_loader):
+ self.model = modeler.model
+ self.config = modeler.config
+ self.train_data_loader = train_data_loader
+ self.valid_data_loader = valid_data_loader
+ self.callbacks = []
+ self.init_callbacks()
+
+ def init_callbacks(self):
+ self.callbacks.append(
+ ModelCheckpoint(
+ filepath=os.path.join(
+ self.config.checkpoint_dir,
+ '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
+ self.config.exp_name),
+ monitor=self.config.checkpoint_monitor,
+ mode=self.config.checkpoint_mode,
+ save_best_only=self.config.checkpoint_save_best_only,
+ save_weights_only=self.config.checkpoint_save_weights_only,
+ verbose=self.config.checkpoint_verbose,
+ ))
+ self.callbacks.append(
+ TensorBoard(
+ log_dir=self.config.tensorboard_log_dir,
+ write_graph=self.config.tensorboard_write_graph,
+ ))
+
+ def train(self):
+ history = self.model.fit_generator(
+ self.train_data_loader,
+ steps_per_epoch=self.train_data_loader.__len__(),
+ epochs=self.config.num_epochs,
+ verbose=self.config.verbose_training,
+ validation_data=self.valid_data_loader,
+ validation_steps=self.valid_data_loader.__len__(),
+ use_multiprocessing=True,
+ shuffle=True,
+ callbacks=self.callbacks)
+
+ last_weight_file = glob(
+ os.path.join(
+ f'{self.config.checkpoint_dir}',
+ f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5')
+ )[0]
+
+ assert os.path.exists(last_weight_file)
+ self.config.model_weight_filename = last_weight_file
+
+ with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f:
+ f.write(self.config.toJSON(indent=2))