import os
import time

from tensorflow.keras import Sequential
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.initializers import RandomNormal

from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer


class LSTMChem(object):
    """Two-layer LSTM network over the SMILES tokenizer's table.

    For ``session='train'`` the network is built and compiled from scratch.
    For ``'generate'`` and ``'finetune'`` it is restored from a saved JSON
    architecture file plus a weights checkpoint.
    """

    def __init__(self, config, session='train'):
        """Create or restore the model for the requested session.

        Args:
            config: Configuration object. When training, this class reads
                ``seed``, ``units``, ``optimizer`` and ``exp_dir`` (and sets
                ``model_arch_filename``). Otherwise it reads
                ``model_arch_filename`` and ``model_weight_filename``.
            session: One of ``'train'``, ``'generate'`` or ``'finetune'``.

        Raises:
            AssertionError: If ``session`` is not one of the supported values.
        """
        # Explicit check instead of a bare ``assert`` so validation is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers/tests that expect it keep working.
        if session not in ('train', 'generate', 'finetune'):
            raise AssertionError('one of {train, generate, finetune}')
        self.config = config
        self.session = session
        self.model = None

        if self.session == 'train':
            self.build_model()
        else:
            self.model = self.load(self.config.model_arch_filename,
                                   self.config.model_weight_filename)

    def build_model(self):
        """Build, persist and compile the network for a fresh training run.

        The architecture is serialized to ``<exp_dir>/model_arch.json`` and
        that path is stored back on ``config.model_arch_filename`` before the
        model is compiled with categorical cross-entropy loss.
        """
        st = SmilesTokenizer()
        # Input feature width and number of output classes both equal the
        # tokenizer table size.
        n_table = len(st.table)
        weight_init = RandomNormal(mean=0.0, stddev=0.05,
                                   seed=self.config.seed)

        self.model = Sequential()
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.3))
        # NOTE(review): input_shape on a non-first Sequential layer is likely
        # ignored by Keras; kept so the serialized architecture is unchanged.
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.5))
        self.model.add(
            Dense(units=n_table,
                  activation='softmax',
                  kernel_initializer=weight_init))

        arch = self.model.to_json(indent=2)
        self.config.model_arch_filename = os.path.join(self.config.exp_dir,
                                                       'model_arch.json')
        # Explicit encoding so the file does not depend on the platform locale.
        with open(self.config.model_arch_filename, 'w',
                  encoding='utf-8') as f:
            f.write(arch)

        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

    def save(self, checkpoint_path):
        """Write the current model's weights to ``checkpoint_path``.

        Args:
            checkpoint_path: Destination passed to ``Model.save_weights``.

        Raises:
            AssertionError: If no model has been built or loaded yet.
        """
        # ``is None`` rather than truthiness of the model object, and an
        # explicit raise so the check survives ``python -O``.
        if self.model is None:
            raise AssertionError('You have to build the model first.')
        print('Saving model ...')
        self.model.save_weights(checkpoint_path)
        print('model saved.')

    def load(self, model_arch_file, checkpoint_file):
        """Rebuild a model from a JSON architecture file and load its weights.

        Args:
            model_arch_file: Path to a JSON architecture as produced by
                ``Model.to_json`` (e.g. the file written by ``build_model``).
            checkpoint_file: Path to a weights checkpoint accepted by
                ``Model.load_weights``.

        Returns:
            The restored Keras model. This method does not compile it.
        """
        print(f'Loading model architecture from {model_arch_file} ...')
        with open(model_arch_file, encoding='utf-8') as f:
            model = model_from_json(f.read())

        print(f'Loading model checkpoint from {checkpoint_file} ...')
        model.load_weights(checkpoint_file)
        print('Loaded the Model.')
        return model