diff options
author | Navan Chauhan <navanchauhan@gmail.com> | 2020-07-31 21:23:03 +0530 |
---|---|---|
committer | Navan Chauhan <navanchauhan@gmail.com> | 2020-07-31 21:23:03 +0530 |
commit | 61ce4e7b089d68395be2221f64d89040c0b14a34 (patch) | |
tree | 6bd60c1fd4f8bcd1c503914c61272ed382fff1da /app/lstm_chem/finetuner.py | |
parent | 376f04d1df2692a8874de686372d18b9ab07950a (diff) |
added AI model
Diffstat (limited to 'app/lstm_chem/finetuner.py')
-rwxr-xr-x | app/lstm_chem/finetuner.py | 24 |
1 file changed, 24 insertions, 0 deletions
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
from lstm_chem.generator import LSTMChemGenerator


class LSTMChemFinetuner(LSTMChemGenerator):
    """Fine-tune a pre-trained LSTM SMILES model on a new dataset.

    Inherits generation utilities from ``LSTMChemGenerator`` while
    replacing its construction: instead of calling ``super().__init__``,
    it lifts ``session``, ``model`` and ``config`` directly off the
    supplied ``modeler`` (NOTE(review): assumes the parent class only
    needs these attributes — confirm against ``LSTMChemGenerator``).
    """

    def __init__(self, modeler, finetune_data_loader):
        # Reuse the already-built model/session/config from the modeler
        # rather than constructing a fresh model.
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        # Keras-style sequence/generator yielding fine-tuning batches.
        self.finetune_data_loader = finetune_data_loader
        self.st = SmilesTokenizer()

    def finetune(self):
        """Compile the model and fit it on the fine-tuning data.

        Returns the Keras ``History`` object produced by training.
        Optimizer name and epoch/verbosity settings come from
        ``self.config``.
        """
        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

        # NOTE(review): ``fit_generator`` is deprecated in tf.keras >= 2.1
        # (plain ``fit`` accepts generators/Sequences); kept here because
        # the Keras version in use is not visible from this file — confirm
        # and migrate to ``fit`` if on TF 2.x.
        history = self.model.fit_generator(
            self.finetune_data_loader,
            # One pass over the loader per epoch; len() instead of the
            # dunder __len__() call.
            steps_per_epoch=len(self.finetune_data_loader),
            epochs=self.config.finetune_epochs,
            verbose=self.config.verbose_training,
            use_multiprocessing=True,
            shuffle=True)
        return history