aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/finetuner.py
diff options
context:
space:
mode:
authorNavan Chauhan <navanchauhan@gmail.com>2020-07-31 22:19:38 +0530
committerNavan Chauhan <navanchauhan@gmail.com>2020-07-31 22:19:38 +0530
commit9a253f896fba757778370c8ad6d40daa3b4cdad0 (patch)
tree2257187cdbc4c2085fb14df8bbb2e6ae6679c3e4 /lstm_chem/finetuner.py
parent61ce4e7b089d68395be2221f64d89040c0b14a34 (diff)
added Curie-Generate BETA
Diffstat (limited to 'lstm_chem/finetuner.py')
-rwxr-xr-xlstm_chem/finetuner.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/lstm_chem/finetuner.py b/lstm_chem/finetuner.py
new file mode 100755
index 0000000..904958b
--- /dev/null
+++ b/lstm_chem/finetuner.py
@@ -0,0 +1,24 @@
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+from lstm_chem.generator import LSTMChemGenerator
+
+
+class LSTMChemFinetuner(LSTMChemGenerator):
+ def __init__(self, modeler, finetune_data_loader):
+ self.session = modeler.session
+ self.model = modeler.model
+ self.config = modeler.config
+ self.finetune_data_loader = finetune_data_loader
+ self.st = SmilesTokenizer()
+
+ def finetune(self):
+ self.model.compile(optimizer=self.config.optimizer,
+ loss='categorical_crossentropy')
+
+ history = self.model.fit_generator(
+ self.finetune_data_loader,
+ steps_per_epoch=self.finetune_data_loader.__len__(),
+ epochs=self.config.finetune_epochs,
+ verbose=self.config.verbose_training,
+ use_multiprocessing=True,
+ shuffle=True)
+ return history