aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/finetuner.py
diff options
context:
space:
mode:
Diffstat (limited to 'lstm_chem/finetuner.py')
-rwxr-xr-xlstm_chem/finetuner.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/lstm_chem/finetuner.py b/lstm_chem/finetuner.py
new file mode 100755
index 0000000..904958b
--- /dev/null
+++ b/lstm_chem/finetuner.py
@@ -0,0 +1,24 @@
+from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
+from lstm_chem.generator import LSTMChemGenerator
+
+
+class LSTMChemFinetuner(LSTMChemGenerator):
+ def __init__(self, modeler, finetune_data_loader):
+ self.session = modeler.session
+ self.model = modeler.model
+ self.config = modeler.config
+ self.finetune_data_loader = finetune_data_loader
+ self.st = SmilesTokenizer()
+
+ def finetune(self):
+ self.model.compile(optimizer=self.config.optimizer,
+ loss='categorical_crossentropy')
+
+ history = self.model.fit_generator(
+ self.finetune_data_loader,
+ steps_per_epoch=self.finetune_data_loader.__len__(),
+ epochs=self.config.finetune_epochs,
+ verbose=self.config.verbose_training,
+ use_multiprocessing=True,
+ shuffle=True)
+ return history