aboutsummaryrefslogtreecommitdiff
path: root/lstm_chem/utils/smiles_tokenizer2.py
diff options
context:
space:
mode:
Diffstat (limited to 'lstm_chem/utils/smiles_tokenizer2.py')
-rw-r--r--lstm_chem/utils/smiles_tokenizer2.py56
1 files changed, 56 insertions, 0 deletions
diff --git a/lstm_chem/utils/smiles_tokenizer2.py b/lstm_chem/utils/smiles_tokenizer2.py
new file mode 100644
index 0000000..29575ba
--- /dev/null
+++ b/lstm_chem/utils/smiles_tokenizer2.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+
+class SmilesTokenizer(object):
+ def __init__(self):
+ atoms = [
+ 'Al', 'As', 'B', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'K', 'Li', 'N',
+ 'Na', 'O', 'P', 'S', 'Se', 'Si', 'Te'
+ ]
+ special = [
+ '(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',
+ '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's'
+ ]
+ padding = ['G', 'A', 'E']
+
+ self.table = sorted(atoms, key=len, reverse=True) + special + padding
+ table_len = len(self.table)
+
+ self.table_2_chars = list(filter(lambda x: len(x) == 2, self.table))
+ self.table_1_chars = list(filter(lambda x: len(x) == 1, self.table))
+
+ self.one_hot_dict = {}
+ for i, symbol in enumerate(self.table):
+ vec = np.zeros(table_len, dtype=np.float32)
+ vec[i] = 1
+ self.one_hot_dict[symbol] = vec
+
+ def tokenize(self, smiles):
+ smiles = smiles + ' '
+ N = len(smiles)
+ token = []
+ i = 0
+ while (i < N):
+ c1 = smiles[i]
+ c2 = smiles[i:i + 2]
+
+ if c2 in self.table_2_chars:
+ token.append(c2)
+ i += 2
+ continue
+
+ if c1 in self.table_1_chars:
+ token.append(c1)
+ i += 1
+ continue
+
+ i += 1
+
+ return token
+
+ def one_hot_encode(self, tokenized_smiles):
+ result = np.array(
+ [self.one_hot_dict[symbol] for symbol in tokenized_smiles],
+ dtype=np.float32)
+ result = result.reshape(1, result.shape[0], result.shape[1])
+ return result