blob: 24f26ce3bcebef8b6a787258980a749ff9d48862 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer
from lstm_chem.generator import LSTMChemGenerator
class LSTMChemFinetuner(LSTMChemGenerator):
    """Fine-tune an already-built LSTMChem model on a separate dataset.

    Inherits from ``LSTMChemGenerator``; presumably so the fine-tuned model
    can afterwards be used through the inherited generator methods
    (NOTE(review): confirm against ``lstm_chem.generator``).
    """

    def __init__(self, modeler, finetune_data_loader):
        """Wrap the model held by *modeler* for fine-tuning.

        Args:
            modeler: object exposing ``session``, ``model`` and ``config``
                attributes. The model object is shared with *modeler*
                (referenced, not copied).
            finetune_data_loader: batch source passed directly to
                ``self.model.fit``; must support ``len()``.
        """
        # NOTE(review): LSTMChemGenerator.__init__ is not called; the
        # attributes below are set directly instead. Confirm they cover
        # everything the inherited methods read.
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.finetune_data_loader = finetune_data_loader
        self.st = SmilesTokenizer()

    def finetune(self):
        """Recompile the model and train it on the fine-tuning data.

        Reads ``config.optimizer``, ``config.finetune_epochs`` and
        ``config.verbose_training``. The model is compiled with
        categorical cross-entropy loss before training.

        Returns:
            Whatever ``self.model.fit`` returns (a Keras ``History`` object
            when the model is a Keras model).
        """
        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')
        history = self.model.fit(
            self.finetune_data_loader,
            # Presumably len() is the loader's batch count, so each epoch
            # is one full pass over the fine-tuning data.
            steps_per_epoch=len(self.finetune_data_loader),
            epochs=self.config.finetune_epochs,
            verbose=self.config.verbose_training,
            use_multiprocessing=True,
            # If the loader is a keras.utils.Sequence, Keras shuffles the
            # batch order each epoch (not the samples within a batch).
            shuffle=True,
        )
        return history
|