diff options
Diffstat (limited to 'lstm_chem/model.py')
-rw-r--r-- | lstm_chem/model.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/lstm_chem/model.py b/lstm_chem/model.py new file mode 100644 index 0000000..368a834 --- /dev/null +++ b/lstm_chem/model.py @@ -0,0 +1,73 @@ +import os +import time +from tensorflow.keras import Sequential +from tensorflow.keras.models import model_from_json +from tensorflow.keras.layers import LSTM, Dense +from tensorflow.keras.initializers import RandomNormal +from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer + + +class LSTMChem(object): + def __init__(self, config, session='train'): + assert session in ['train', 'generate', 'finetune'], \ + 'one of {train, generate, finetune}' + + self.config = config + self.session = session + self.model = None + + if self.session == 'train': + self.build_model() + else: + self.model = self.load(self.config.model_arch_filename, + self.config.model_weight_filename) + + def build_model(self): + st = SmilesTokenizer() + n_table = len(st.table) + weight_init = RandomNormal(mean=0.0, + stddev=0.05, + seed=self.config.seed) + + self.model = Sequential() + self.model.add( + LSTM(units=self.config.units, + input_shape=(None, n_table), + return_sequences=True, + kernel_initializer=weight_init, + dropout=0.3)) + self.model.add( + LSTM(units=self.config.units, + input_shape=(None, n_table), + return_sequences=True, + kernel_initializer=weight_init, + dropout=0.5)) + self.model.add( + Dense(units=n_table, + activation='softmax', + kernel_initializer=weight_init)) + + arch = self.model.to_json(indent=2) + self.config.model_arch_filename = os.path.join(self.config.exp_dir, + 'model_arch.json') + with open(self.config.model_arch_filename, 'w') as f: + f.write(arch) + + self.model.compile(optimizer=self.config.optimizer, + loss='categorical_crossentropy') + + def save(self, checkpoint_path): + assert self.model, 'You have to build the model first.' + + print('Saving model ...') + self.model.save_weights(checkpoint_path) + print('model saved.') + + def load(self, model_arch_file, checkpoint_file): + print(f'Loading model architecture from {model_arch_file} ...') + with open(model_arch_file) as f: + model = model_from_json(f.read()) + print(f'Loading model checkpoint from {checkpoint_file} ...') + model.load_weights(checkpoint_file) + print('Loaded the Model.') + return model |