MFBERTで記述子を計算する

MFBERTとは

MFBERTは、Molecular Fingerprints through Bidirectional Encoder Representations from Transformersの略でSMILES記法から記述子を生成するモデルです。BERTという名前の通り、内部ではBERTが使用されています。こちらが論文になります。

MFBERTで記述子を生成

まずはgithubからクローンしてきます。

git clone https://github.com/GouldGroup/MFBERT.git

MFBERTを動かすのに必要なライブラリをインストールします。

cd MFBERT
pip install -r requirments.txt

MFBERTのリポジトリにはサンプルのデータが含まれています。Data/SAMPLE_500.smiが該当のファイルです。

head Data/SAMPLE_500.smi

> CCC(N)(CC)C(C)N(C)NC=N
> C1CN(CCC21CCCN(C2)CCN(C)C)C(C=3SC(C(F)(F)F)=NC3)=O
> FC1=CC(N2C(NC(C(NC3CCOCCC3)=O)=O)=CC=N2)=CC=C1
> CN1CCN(CC1=O)S(=O)(=O)C2=C(N=CC=C2)C#N
> CC(C)NN=C1N(C)SC(N)=C1C
> CN1C(C(NCC2(COCC2)O)=O)=CC(C#N)=C1
> N1(C2(C(N(CC2)CCN3C=NC=C3)=O)CCC1)C(C4=C(C(=O)C5=CC=CC=C5)C=CC=N4)=O
> C[C@@H]1C[C@@]2([C@@H]3[C@H]4[C@@]1([C@H]5C=C(C(=O)[C@@]5(CC(=C4)COC(=O)CC6=CC(=C(C(=C6)I)O)OC)O)C)O[C@@](O3)(O2)CC7=CC=CC=C7)C(=C)C
> CN(C)C12CCC(=N)NC1C2(C)O
> O=C(N[C@H](c1cncc(c1)F)C)NC1CCOC2(C1)CCCCC2

headコマンドでSAMPLE_500.smiを確認すると各行SMILES記法で表現された分子の羅列であることがわかります。

次に学習済みモデルのダウンロードを行います。学習済みモデルダウンロード用のスクリプトが用意されているので、そちらを使用します。

python Model/download_models.py

Please select which model weight(s) to download (comma separated):

    0: ALL
    1: Pre-trained checkpoint (for fine-tuning)
    2: RDKit Benchmarking platform featurizer
    3: BBBP_featurizer
    4: Clintox_featurizer
    5: HIV_featurizer
    6: tox21_featurizer
    7: Siamese BBBP featurizer/predictor
    8: Siamese Clintox featurizer/predictor
    9: Siamese HIV featurizer/predictor
    10: Lipophilicity featurizer/predictor
    11: ESOL featurizer/predictor
    12: FreeSolv featurizer/predictor

Model/download_models.pyがダウンロード用のスクリプトです。このスクリプトを実行するとダウンロードするものを指定するように指示されます。今回は1のファインチューニング用のチェックポイントをダウンロードしてみます。1を入力してEnterを押すとチェックポイントのダウンロードが始まります。

main.pyは記述子生成のサンプルコードとなっています。main.pyでOUTPUT_DIRとして指定されている出力用ディレクトリを作成すればmain.pyを実行できます。

mkdir Fingerprints
python main.py

main.pyの実行が終わったらFingerprintsディレクトリを見てください。SAMPLE_500_fingerprints.pklができていると思います。これが各分子の記述子が記載されたファイルとなります。ただ、pickleファイルはpythonオブジェクトを保存するために使用されるファイル形式なので、python以外で使用したい時には扱いにくい形式です。そこで、保存するファイル形式をcsvに変更したものを載せておきます。良ければ参考にしてみてください。

import torch
import numpy as np
import pandas as pd
from Tokenizer.MFBERT_Tokenizer import MFBERTTokenizer
from Model.model import MFBERT
from tqdm import tqdm, trange
import os
import pickle

DEVICE = 'cpu'
BATCH_SIZE = 1
DATA_DIR = 'Data/'
TOKENIZER_DIR = 'Tokenizer/'
OUTPUT_DIR = 'Fingerprints/'


def generate_dict_from_results(results):
    smiles_fingerprint_dict = {}
    for batch in results:
        smiles = batch[0]
        res = batch[1]
        for i in range(len(smiles)):
            smiles_fingerprint_dict[smiles[i]]=res[i]
    return smiles_fingerprint_dict


if __name__ == '__main__':
    # Make directory
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    excepted = []
    excepted_counter = 0

    tokenizer = MFBERTTokenizer.from_pretrained(TOKENIZER_DIR+'Model/',
                                                dict_file = TOKENIZER_DIR+'Model/dict.txt')

    model = MFBERT().to(DEVICE)

    for DATA_FILE in tqdm(os.listdir(DATA_DIR)):
        if DATA_FILE.startswith('.'):
            continue

        OUTPUT_FILE = os.path.join(OUTPUT_DIR, f'{DATA_FILE.split(".")[0]}_fingerprints.csv')

        with open(f'{DATA_DIR}/{DATA_FILE}','r') as f:
            data = f.read().splitlines()


        all_res = []
        for batch in trange(0,len(data), BATCH_SIZE):

            smiles_batch = data[batch:batch+BATCH_SIZE]

            # Note the padding tokens will affect the mean embedding
            inputs = tokenizer(smiles_batch, return_tensors='pt', padding=True, truncation=True).to(DEVICE)
            try:
                res = model(inputs).detach().numpy() # numpy tensor of mean embeddings/batch
                all_res.append((smiles_batch,res))
            except:
                excepted.append(smiles_batch)
                excepted_counter+=1

                print('EXCEPTION OCCURRED TOTAL:',excepted_counter)
                
        
        dres = generate_dict_from_results(all_res)
        
        # Create pandas.DataFrame
        df = pd.DataFrame(dres.values(), index=dres.keys())
        df.to_csv(OUTPUT_FILE)