diff options
Diffstat (limited to 'lstm_chem/trainer.py')
-rwxr-xr-x | lstm_chem/trainer.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/lstm_chem/trainer.py b/lstm_chem/trainer.py new file mode 100755 index 0000000..4e8057e --- /dev/null +++ b/lstm_chem/trainer.py @@ -0,0 +1,56 @@ +from glob import glob +import os +from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard + + +class LSTMChemTrainer(object): + def __init__(self, modeler, train_data_loader, valid_data_loader): + self.model = modeler.model + self.config = modeler.config + self.train_data_loader = train_data_loader + self.valid_data_loader = valid_data_loader + self.callbacks = [] + self.init_callbacks() + + def init_callbacks(self): + self.callbacks.append( + ModelCheckpoint( + filepath=os.path.join( + self.config.checkpoint_dir, + '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % + self.config.exp_name), + monitor=self.config.checkpoint_monitor, + mode=self.config.checkpoint_mode, + save_best_only=self.config.checkpoint_save_best_only, + save_weights_only=self.config.checkpoint_save_weights_only, + verbose=self.config.checkpoint_verbose, + )) + self.callbacks.append( + TensorBoard( + log_dir=self.config.tensorboard_log_dir, + write_graph=self.config.tensorboard_write_graph, + )) + + def train(self): + history = self.model.fit_generator( + self.train_data_loader, + steps_per_epoch=self.train_data_loader.__len__(), + epochs=self.config.num_epochs, + verbose=self.config.verbose_training, + validation_data=self.valid_data_loader, + validation_steps=self.valid_data_loader.__len__(), + use_multiprocessing=True, + shuffle=True, + callbacks=self.callbacks) + + last_weight_file = glob( + os.path.join( + f'{self.config.checkpoint_dir}', + f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5') + )[0] + + assert os.path.exists(last_weight_file) + self.config.model_weight_filename = last_weight_file + + with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f: + f.write(self.config.toJSON(indent=2)) |