Machine learning - Protein Chain Classification

In this demo we classify protein chains as either all-alpha or all-beta based on their amino acid sequence. Each sequence is split into overlapping n-grams, and a Word2Vec representation of those n-grams serves as the feature vector.

Word2Vec model

Word2Vec example

Imports

In [17]:
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer
from pyspark.ml.classification import (DecisionTreeClassifier, LogisticRegression,
                                       MultilayerPerceptronClassifier, RandomForestClassifier)

Configure Spark Context

In [18]:
conf = SparkConf() \
            .setMaster("local[*]") \
            .setAppName("MachineLearningDemo")

sc = SparkContext(conf = conf)

Read MMTF files and create a non-redundant set (<=40% sequence identity, <=3.0 Å resolution) of L-protein chains

In [19]:
pdb = mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/', sc) \
                .flatMap(StructureToPolymerChains()) \
                .filter(Pisces(sequenceIdentity=40,resolution=3.0))

Get secondary structure content

In [20]:
data = secondaryStructureExtractor.get_dataset(pdb)
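The extracted dataset holds per-chain `alpha`, `beta` and `coil` fractions together with DSSP assignments in 8-state (`dsspQ8Code`) and 3-state (`dsspQ3Code`) form. The plain-Python sketch below shows the conventional DSSP Q8-to-Q3 reduction and one way to derive the fractions from a Q3 string. The helper names are hypothetical, and excluding unassigned residues (`X`) from the denominator is an assumption, not a documented detail of mmtfPyspark.

```python
# Hypothetical helpers illustrating the DSSP Q8 -> Q3 reduction and
# secondary structure fractions; not the mmtfPyspark implementation.
Q8_TO_Q3 = {
    "H": "H", "G": "H", "I": "H",   # alpha, 3-10 and pi helices -> helix
    "E": "E", "B": "E",             # extended strand and isolated bridge -> strand
    "T": "C", "S": "C", "C": "C",   # turn, bend and coil -> coil
}


def q8_to_q3(dssp_q8):
    """Reduce a DSSP Q8 string to Q3; unknown codes such as 'X' pass through."""
    return "".join(Q8_TO_Q3.get(code, code) for code in dssp_q8)


def ss_fractions(dssp_q3):
    """Return (alpha, beta, coil) fractions over residues with a Q3 assignment.

    Assumption: unassigned residues ('X') are excluded from the denominator.
    """
    observed = [code for code in dssp_q3 if code in "HEC"]
    total = len(observed)
    if total == 0:
        return 0.0, 0.0, 0.0
    return (observed.count("H") / total,
            observed.count("E") / total,
            observed.count("C") / total)


q3 = q8_to_q3("XXCCCSGGGSCCCCHHHHTTC")
print(q3)  # XXCCCCHHHCCCCCHHHHCCC
print(ss_fractions("HHHEEC"))
```

The sample rows shown after the Word2Vec encoding step are consistent with this mapping: for example, `SGGGS` in `dsspQ8Code` appears as `CHHHC` in `dsspQ3Code`.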

Define the add_protein_fold_type function

In [21]:
from pyspark.sql.functions import col, when


def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of four secondary structure classes:
    "alpha", "beta", "alpha+beta", or "other", based on the fraction of
    alpha-helix and beta-strand content.

    Args:
        data (DataFrame): input dataset with "alpha" and "beta" fraction columns
        minThreshold (float): a fraction below this threshold is treated as absent
        maxThreshold (float): a fraction above this threshold is treated as present

    Returns:
        DataFrame: the input dataset with an additional "foldType" column
    '''
    alpha = col("alpha")
    beta = col("beta")
    fold_type = (
        when((alpha > maxThreshold) & (beta < minThreshold), "alpha")
        .when((beta > maxThreshold) & (alpha < minThreshold), "beta")
        .when((alpha > maxThreshold) & (beta > maxThreshold), "alpha+beta")
        .otherwise("other")
    )
    return data.withColumn("foldType", fold_type)

Classify chains by secondary structure type

In [22]:
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
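To make the thresholds concrete, here is a plain-Python mirror of the `when`/`otherwise` chain in `add_protein_fold_type` with `minThreshold=0.05` and `maxThreshold=0.15`. The `(alpha, beta)` pairs are copied from chains that appear in this notebook's outputs; `fold_type` itself is a hypothetical helper for illustration only.

```python
def fold_type(alpha, beta, min_threshold=0.05, max_threshold=0.15):
    """Plain-Python mirror of the when/otherwise chain in add_protein_fold_type."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > max_threshold:
        return "alpha+beta"
    return "other"


# (alpha, beta) fractions of chains shown in this notebook's outputs
examples = {
    "1RCQ.A": (0.316527, 0.240896),
    "1RG8.B": (0.063830, 0.375887),
    "1RI6.A": (0.018018, 0.552553),
    "1RJU.V": (0.166667, 0.000000),
}
for chain_id, (alpha, beta) in examples.items():
    print(chain_id, fold_type(alpha, beta))
```

Chains with a fraction between the two thresholds, such as 1RG8.B with an alpha fraction of about 0.064, fall into "other". Only chains with more than 15% of one element and less than 5% of the other are labeled all-alpha or all-beta.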

Create a Word2Vec representation of the protein sequences

n = 2 # create 2-grams

windowSize = 25 # context window size (in n-grams) for Word2Vec

vectorSize = 50 # dimension of feature vector
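The helpers below are a plain-Python sketch of what these parameters mean. `overlapping_ngrams` splits a sequence into n-grams with a step of one residue, which matches the `ngram` column in the output of the next cell. `context_pairs` illustrates skip-gram style (center, context) pairs; it assumes the encoder trains Spark ML's `Word2Vec` (a skip-gram model) and that `windowSize` counts n-grams on each side of a word. Both helper names are hypothetical.

```python
def overlapping_ngrams(sequence, n=2):
    """Split a sequence into overlapping n-grams with a step of one residue."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]


def context_pairs(words, window_size):
    """Skip-gram style (center, context) pairs within +/- window_size words."""
    pairs = []
    for i, center in enumerate(words):
        start = max(0, i - window_size)
        stop = min(len(words), i + window_size + 1)
        pairs.extend((center, words[j]) for j in range(start, stop) if j != i)
    return pairs


ngrams = overlapping_ngrams("MRPARALIDL", n=2)
print(ngrams)  # ['MR', 'RP', 'PA', 'AR', 'RA', 'AL', 'LI', 'ID', 'DL']
print(context_pairs(ngrams[:3], window_size=1))
# [('MR', 'RP'), ('RP', 'MR'), ('RP', 'PA'), ('PA', 'RP')]
```

With overlapping 2-grams, each n-gram starts one residue after the previous one, so a window of 25 n-grams on each side spans roughly 25 residues in each direction.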

In [23]:
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)
Out[23]:
structureChainId sequence alpha beta coil dsspQ8Code dsspQ3Code foldType ngram features
0 1RCQ.A MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE... 0.316527 0.240896 0.442577 CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH... CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH... alpha+beta [MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A... [0.22282994887529967, -0.20568346063700618, -0...
1 1REG.Y MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY... 0.308333 0.291667 0.400000 CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE... CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE... alpha+beta [MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F... [-0.4225819193534861, -0.0816098772420371, -0....
2 1REQ.B SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR... 0.470113 0.121163 0.408724 XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC... XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC... alpha+beta [SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T... [0.013261343847444785, -0.16321914651542435, -...
3 1RFE.A GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI... 0.312500 0.356250 0.331250 XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE... XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE... alpha+beta [GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E... [-0.001911293541700203, -0.26975917786082126, ...
4 1RG8.B HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ... 0.063830 0.375887 0.560284 XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT... XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEECCCCCC... other [HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G... [-0.3250424425727848, 0.34787580900151155, 0.2...
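Word2Vec learns one vector per n-gram, yet each chain needs a single fixed-length `features` vector. Assuming `overlapping_ngram_word2vec_encode` relies on Spark ML's `Word2VecModel.transform`, each chain's vector is the average of its n-gram vectors. The sketch below mirrors that averaging in plain Python; we assume, as in Spark's implementation, that n-grams missing from the vocabulary add nothing but still count toward the denominator. The function name and the toy 3-dimensional vectors are made up for illustration (the notebook uses `vectorSize=50`).

```python
def document_vector(words, word_vectors, dim):
    """Average word vectors over a document (Word2VecModel.transform style).

    Words missing from word_vectors contribute zeros but still count in the
    denominator; an empty document maps to the zero vector.
    """
    total = [0.0] * dim
    if not words:
        return total
    for word in words:
        vec = word_vectors.get(word)
        if vec is not None:
            for k in range(dim):
                total[k] += vec[k]
    return [value / len(words) for value in total]


# Toy 3-dimensional n-gram vectors, made up for illustration
toy_vectors = {"MR": [1.0, 0.0, 2.0], "RP": [0.0, 1.0, 0.0], "PA": [2.0, 2.0, 1.0]}
print(document_vector(["MR", "RP", "PA"], toy_vectors, dim=3))  # [1.0, 1.0, 1.0]
```

Averaging discards n-gram order beyond what the Word2Vec context window captured during training, which is why chains of very different lengths all map to vectors of the same dimension.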

Keep only a subset of relevant fields for further processing

In [24]:
data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])

Select only chains with the alpha or beta foldType

In [25]:
data = data.where(data.foldType.isin('alpha', 'beta'))  # add 'other' to the list to keep a third class

print(f"Total number of data: {data.count()}")
data.toPandas().head()
Total number of data: 2584
Out[25]:
structureChainId alpha beta coil foldType features
0 1RI6.A 0.018018 0.552553 0.429429 beta [-0.09249431734948217, 0.09015498735141335, -0...
1 1RJU.V 0.166667 0.000000 0.833333 alpha [-0.2477414113602468, 0.4224771835974284, -0.2...
2 1RK8.C 0.000000 0.393939 0.606061 beta [-0.18497365634692342, -0.04471376525205478, -...
3 1RKT.B 0.795000 0.000000 0.205000 alpha [-0.23382538002824374, -0.11802027330679052, -...
4 1RR7.A 0.702128 0.000000 0.297872 alpha [-0.08483699508360587, 0.024782998094451614, 0...

Basic dataset information, settings, and class balancing

In [26]:
label = 'foldType'
testFraction = 0.1
seed = 123

vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count    : {featureCount}")

classCount = int(data.select(label).distinct().count())
print(f"Class count    : {classCount}")

print(f"Dataset size (unbalanced)    : {data.count()}")

data.groupby(label).count().show(classCount)
data = datasetBalancer.downsample(data, label, 1)
print(f"Dataset size (balanced)  : {data.count()}")

data.groupby(label).count().show(classCount)
Feature count    : 50
Class count    : 2
Dataset size (unbalanced)    : 2584
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  660|
|   alpha| 1924|
+--------+-----+

Dataset size (balanced)  : 1342
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  660|
|   alpha|  682|
+--------+-----+
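The balanced counts (682 alpha vs. 660 beta) are close to, but not exactly, the minority class size. That is what Bernoulli sampling would produce if each row of a class is kept with probability minority size / class size. The sketch below assumes `datasetBalancer.downsample` works this way; the `downsample` function here is a hypothetical plain-Python stand-in that operates on a list of labels.

```python
import random
from collections import Counter


def downsample(rows, label_of, seed=1):
    """Bernoulli-downsample each class toward the size of the smallest class.

    Each row of class c is kept with probability min_count / count[c], so the
    minority class is kept whole and the others end up approximately balanced.
    """
    counts = Counter(label_of(row) for row in rows)
    min_count = min(counts.values())
    rng = random.Random(seed)
    return [row for row in rows
            if rng.random() < min_count / counts[label_of(row)]]


# Synthetic labels with the class sizes printed above (660 beta, 1924 alpha)
rows = ["beta"] * 660 + ["alpha"] * 1924
balanced = downsample(rows, label_of=lambda row: row, seed=1)
print(Counter(balanced))
```

Because the majority class is sampled row by row, a different seed gives a slightly different alpha count around 660, while the minority class always keeps all of its rows.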

Decision Tree Classifier

In [27]:
dtc = DecisionTreeClassifier()
mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")

 Class  Train   Test
alpha   607     75
beta    608     52

Sample predictions: DecisionTreeClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0| [13.0,377.0]|[0.03333333333333...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|   [5.0,16.0]|[0.23809523809523...|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0| [13.0,377.0]|[0.03333333333333...|       1.0|          beta|
|          2QF4.A| 0.01764706|  0.5117647|0.47058824|    beta|[-0.2368620687017...|         1.0| [13.0,377.0]|[0.03333333333333...|       1.0|          beta|
|          4OUS.A|        0.0|  0.5681818| 0.4318182|    beta|[0.03226048785301...|         1.0|  [91.0,13.0]|       [0.875,0.125]|       0.0|         alpha|
|          1XAW.A|  0.8224299|        0.0|0.17757009|   alpha|[-0.4501447076875...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          3BRV.C| 0.78571427|        0.0|0.21428572|   alpha|[0.14613234782789...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          4MU6.A|  0.8540146|        0.0| 0.1459854|   alpha|[-0.2809790735148...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          4YK2.A|  0.7383177|0.037383176|0.22429906|   alpha|[-0.3205931102219...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          4I17.A|  0.8035714|        0.0|0.19642857|   alpha|[-0.2930700773053...|         0.0|    [0.0,2.0]|           [0.0,1.0]|       1.0|          beta|
|          5CWG.A| 0.86153847|        0.0|0.13846155|   alpha|[0.11841576919850...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          4ZP0.A| 0.84438777|        0.0|0.15561225|   alpha|[0.22175398589971...|         0.0|  [91.0,13.0]|       [0.875,0.125]|       0.0|         alpha|
|          1T6O.B|  0.8947368|        0.0|0.10526316|   alpha|[0.67486972636298...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
|          5H2F.T|        0.7|        0.0|       0.3|   alpha|[-0.3641495476637...|         0.0|  [12.0,63.0]|         [0.16,0.84]|       1.0|          beta|
|          3S0A.A|  0.6386555|0.016806724|0.34453782|   alpha|[-0.6068946828008...|         0.0|   [31.0,0.0]|           [1.0,0.0]|       0.0|         alpha|
|          5C8G.B| 0.66101694|        0.0|0.33898306|   alpha|[-0.0469646227994...|         0.0|  [320.0,3.0]|[0.99071207430340...|       0.0|         alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+

Total time taken: 3.21268630027771
Method  DecisionTreeClassifier
F       0.8039005137708921
Accuracy        0.8031496062992126
Precision       0.8055535671677405
Recall  0.8031496062992125
False Positive Rase     0.2013547345043408
True Positive Rate      0.8031496062992125

Confusion Matrix
['alpha', 'beta']
DenseMatrix([[61., 14.],
             [11., 41.]])
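The reported scores can be recomputed from the confusion matrix, whose rows are actual classes and columns predicted classes (`alpha`, `beta`). The sketch below computes accuracy plus per-class precision, recall, F1 and false positive rate, averaged with weights proportional to each class's number of true instances. We assume this is the same weighting as Spark's `MulticlassMetrics` weighted metrics; it reproduces the decision tree numbers above. The `weighted_metrics` helper is hypothetical.

```python
def weighted_metrics(confusion):
    """Weighted metrics from a confusion matrix (rows = actual, cols = predicted)."""
    n_classes = len(confusion)
    total = sum(sum(row) for row in confusion)
    actual = [sum(confusion[i]) for i in range(n_classes)]
    predicted = [sum(confusion[i][j] for i in range(n_classes)) for j in range(n_classes)]

    result = {"accuracy": sum(confusion[i][i] for i in range(n_classes)) / total,
              "precision": 0.0, "recall": 0.0, "f1": 0.0, "fpr": 0.0}
    for c in range(n_classes):
        tp = confusion[c][c]
        fp = predicted[c] - tp
        fn = actual[c] - tp
        tn = total - tp - fp - fn
        weight = actual[c] / total  # weight by the number of true instances of class c
        precision = tp / predicted[c] if predicted[c] else 0.0
        recall = tp / actual[c] if actual[c] else 0.0
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        fpr = fp / (fp + tn) if fp + tn else 0.0
        result["precision"] += weight * precision
        result["recall"] += weight * recall
        result["f1"] += weight * f1
        result["fpr"] += weight * fpr
    return result


# Decision tree confusion matrix printed above (classes: alpha, beta)
for name, value in weighted_metrics([[61, 14], [11, 41]]).items():
    print(f"{name}\t{value:.4f}")
```

With this weighting, weighted recall is always equal to accuracy, which is why the Recall and True Positive Rate rows repeat the Accuracy value for every classifier in this notebook.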

Random Forest Classifier

In [28]:
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")

 Class  Train   Test
alpha   607     75
beta    608     52

Sample predictions: RandomForestClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|[0.51578998431281...|[0.02578949921564...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|[13.9982189640870...|[0.69991094820435...|       0.0|         alpha|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0|[2.13188897197973...|[0.10659444859898...|       1.0|          beta|
|          2QF4.A| 0.01764706|  0.5117647|0.47058824|    beta|[-0.2368620687017...|         1.0|[2.75569007720908...|[0.13778450386045...|       1.0|          beta|
|          4OUS.A|        0.0|  0.5681818| 0.4318182|    beta|[0.03226048785301...|         1.0|[9.85487531968060...|[0.49274376598403...|       1.0|          beta|
|          1XAW.A|  0.8224299|        0.0|0.17757009|   alpha|[-0.4501447076875...|         0.0|[18.2154681085900...|[0.91077340542950...|       0.0|         alpha|
|          3BRV.C| 0.78571427|        0.0|0.21428572|   alpha|[0.14613234782789...|         0.0|[18.2037992417524...|[0.91018996208762...|       0.0|         alpha|
|          4MU6.A|  0.8540146|        0.0| 0.1459854|   alpha|[-0.2809790735148...|         0.0|[19.4849307860132...|[0.97424653930066...|       0.0|         alpha|
|          4YK2.A|  0.7383177|0.037383176|0.22429906|   alpha|[-0.3205931102219...|         0.0|[16.8241875964511...|[0.84120937982255...|       0.0|         alpha|
|          4I17.A|  0.8035714|        0.0|0.19642857|   alpha|[-0.2930700773053...|         0.0|[13.1749067714244...|[0.65874533857122...|       0.0|         alpha|
|          5CWG.A| 0.86153847|        0.0|0.13846155|   alpha|[0.11841576919850...|         0.0|[19.4353866362187...|[0.97176933181093...|       0.0|         alpha|
|          4ZP0.A| 0.84438777|        0.0|0.15561225|   alpha|[0.22175398589971...|         0.0|[17.3230275572504...|[0.86615137786252...|       0.0|         alpha|
|          1T6O.B|  0.8947368|        0.0|0.10526316|   alpha|[0.67486972636298...|         0.0|[18.6141301097890...|[0.93070650548945...|       0.0|         alpha|
|          5H2F.T|        0.7|        0.0|       0.3|   alpha|[-0.3641495476637...|         0.0|[8.24401892548523...|[0.41220094627426...|       1.0|          beta|
|          3S0A.A|  0.6386555|0.016806724|0.34453782|   alpha|[-0.6068946828008...|         0.0|[11.5615661468762...|[0.57807830734381...|       0.0|         alpha|
|          5C8G.B| 0.66101694|        0.0|0.33898306|   alpha|[-0.0469646227994...|         0.0|[16.9153051488506...|[0.84576525744253...|       0.0|         alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+

Total time taken: 4.367412090301514
Method  RandomForestClassifier
F       0.8512456189733639
Accuracy        0.8503937007874016
Precision       0.8547451305428928
Recall  0.8503937007874016
False Positive Rase     0.14500908540278618
True Positive Rate      0.8503937007874016

Confusion Matrix
['alpha', 'beta']
DenseMatrix([[63., 12.],
             [ 7., 45.]])
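For both tree models, the `probability` column is the `rawPrediction` vector rescaled to sum to 1. For the decision tree, `rawPrediction` holds the training class counts in the leaf a chain falls into, e.g. `[13.0, 377.0]` gives about 0.0333 for alpha. For the random forest, we assume each tree contributes a normalized class distribution, so `rawPrediction` sums to the number of trees, which is 20 with Spark's default `numTrees`. The helper name below is hypothetical.

```python
def normalize_raw_prediction(raw):
    """Rescale a non-negative rawPrediction vector so that it sums to 1."""
    total = sum(raw)
    if total <= 0:
        raise ValueError("rawPrediction must have a positive sum")
    return [value / total for value in raw]


# Decision tree leaves from the sample predictions: [alpha count, beta count]
for raw in ([13.0, 377.0], [320.0, 3.0], [91.0, 13.0]):
    print(raw, normalize_raw_prediction(raw))

# Random forest: tree votes sum to the number of trees (assumed Spark default of 20),
# so P(alpha) for 5DI0.B is its first rawPrediction entry divided by 20.
NUM_TREES = 20
print(0.51578998431281 / NUM_TREES)
```

The decision tree sends 4I17.A to a leaf with only two training chains (`[0.0, 2.0]`), so its "certain" beta prediction rests on very little data; averaging over many trees is what gives the random forest its smoother probabilities.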

Logistic Regression Classifier

In [29]:
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")

 Class  Train   Test
alpha   607     75
beta    608     52

Sample predictions: LogisticRegression
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|[-5.0623275294473...|[0.00629098022815...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|[-1.1865565882185...|[0.23387535279941...|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0|[-5.2476892464597...|[0.00523213887961...|       1.0|          beta|
|          2QF4.A| 0.01764706|  0.5117647|0.47058824|    beta|[-0.2368620687017...|         1.0|[-4.5674722465786...|[0.01027745233916...|       1.0|          beta|
|          4OUS.A|        0.0|  0.5681818| 0.4318182|    beta|[0.03226048785301...|         1.0|[1.86554161773270...|[0.86594156213690...|       0.0|         alpha|
|          1XAW.A|  0.8224299|        0.0|0.17757009|   alpha|[-0.4501447076875...|         0.0|[2.45521137787059...|[0.92094171451502...|       0.0|         alpha|
|          3BRV.C| 0.78571427|        0.0|0.21428572|   alpha|[0.14613234782789...|         0.0|[3.79818850368450...|[0.97807992495051...|       0.0|         alpha|
|          4MU6.A|  0.8540146|        0.0| 0.1459854|   alpha|[-0.2809790735148...|         0.0|[5.18188108968423...|[0.99441394966369...|       0.0|         alpha|
|          4YK2.A|  0.7383177|0.037383176|0.22429906|   alpha|[-0.3205931102219...|         0.0|[1.90968820922997...|[0.87098411573233...|       0.0|         alpha|
|          4I17.A|  0.8035714|        0.0|0.19642857|   alpha|[-0.2930700773053...|         0.0|[3.98665134253246...|[0.98177649332300...|       0.0|         alpha|
|          5CWG.A| 0.86153847|        0.0|0.13846155|   alpha|[0.11841576919850...|         0.0|[5.79842445508542...|[0.99697683866489...|       0.0|         alpha|
|          4ZP0.A| 0.84438777|        0.0|0.15561225|   alpha|[0.22175398589971...|         0.0|[4.00761355513002...|[0.98214777367986...|       0.0|         alpha|
|          1T6O.B|  0.8947368|        0.0|0.10526316|   alpha|[0.67486972636298...|         0.0|[11.3978061720542...|[0.99998878005311...|       0.0|         alpha|
|          5H2F.T|        0.7|        0.0|       0.3|   alpha|[-0.3641495476637...|         0.0|[-1.9088191172433...|[0.12911357630884...|       1.0|          beta|
|          3S0A.A|  0.6386555|0.016806724|0.34453782|   alpha|[-0.6068946828008...|         0.0|[2.08239502966998...|[0.88918025653597...|       0.0|         alpha|
|          5C8G.B| 0.66101694|        0.0|0.33898306|   alpha|[-0.0469646227994...|         0.0|[2.88630893525889...|[0.94716547404562...|       0.0|         alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+

Total time taken: 7.353407859802246
Method  LogisticRegression
F       0.8825623307684451
Accuracy        0.8818897637795275
Precision       0.8859846466560102
Recall  0.8818897637795275
False Positive Rase     0.11137694326670705
True Positive Rate      0.8818897637795275

Confusion Matrix
['alpha', 'beta']
DenseMatrix([[65., 10.],
             [ 5., 47.]])
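In Spark ML's binary logistic regression, the two `rawPrediction` entries are a margin and its negative, and each class probability is the logistic (sigmoid) function of its entry. The sketch below applies a numerically stable sigmoid to the first `rawPrediction` entry of three chains from the sample predictions and recovers the printed alpha probabilities. The helper name is hypothetical.

```python
import math


def binary_lr_probabilities(raw_class0):
    """[P(class 0), P(class 1)] from rawPrediction[0] of a binary logistic regression."""
    if raw_class0 >= 0:
        p0 = 1.0 / (1.0 + math.exp(-raw_class0))
    else:
        e = math.exp(raw_class0)  # avoids overflow for large negative margins
        p0 = e / (1.0 + e)
    return [p0, 1.0 - p0]


# rawPrediction[0] values copied from the sample predictions (class 0 = alpha)
for chain_id, raw0 in [("5DI0.B", -5.0623275294473),
                       ("4OUS.A", 1.86554161773270),
                       ("1T6O.B", 11.3978061720542)]:
    p_alpha, p_beta = binary_lr_probabilities(raw0)
    print(f"{chain_id}\tP(alpha)={p_alpha:.6f}\tP(beta)={p_beta:.6f}")
```

Chain 4OUS.A is labeled beta but gets P(alpha) of about 0.87, so it is one of the confidently wrong predictions counted in the confusion matrix.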

Simple Multilayer Perceptron Classifier

In [30]:
layers = [featureCount, 64, 64, classCount]
mpc = MultilayerPerceptronClassifier().setLayers(layers) \
                                          .setBlockSize(128) \
                                          .setSeed(1234) \
                                          .setMaxIter(100)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")

 Class  Train   Test
alpha   607     75
beta    608     52

Sample predictions: MultilayerPerceptronClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0|       1.0|          beta|
|          2QF4.A| 0.01764706|  0.5117647|0.47058824|    beta|[-0.2368620687017...|         1.0|       1.0|          beta|
|          4OUS.A|        0.0|  0.5681818| 0.4318182|    beta|[0.03226048785301...|         1.0|       0.0|         alpha|
|          1XAW.A|  0.8224299|        0.0|0.17757009|   alpha|[-0.4501447076875...|         0.0|       0.0|         alpha|
|          3BRV.C| 0.78571427|        0.0|0.21428572|   alpha|[0.14613234782789...|         0.0|       0.0|         alpha|
|          4MU6.A|  0.8540146|        0.0| 0.1459854|   alpha|[-0.2809790735148...|         0.0|       0.0|         alpha|
|          4YK2.A|  0.7383177|0.037383176|0.22429906|   alpha|[-0.3205931102219...|         0.0|       0.0|         alpha|
|          4I17.A|  0.8035714|        0.0|0.19642857|   alpha|[-0.2930700773053...|         0.0|       0.0|         alpha|
|          5CWG.A| 0.86153847|        0.0|0.13846155|   alpha|[0.11841576919850...|         0.0|       0.0|         alpha|
|          4ZP0.A| 0.84438777|        0.0|0.15561225|   alpha|[0.22175398589971...|         0.0|       0.0|         alpha|
|          1T6O.B|  0.8947368|        0.0|0.10526316|   alpha|[0.67486972636298...|         0.0|       0.0|         alpha|
|          5H2F.T|        0.7|        0.0|       0.3|   alpha|[-0.3641495476637...|         0.0|       0.0|         alpha|
|          3S0A.A|  0.6386555|0.016806724|0.34453782|   alpha|[-0.6068946828008...|         0.0|       0.0|         alpha|
|          5C8G.B| 0.66101694|        0.0|0.33898306|   alpha|[-0.0469646227994...|         0.0|       0.0|         alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+

Total time taken: 13.364720821380615
Method  MultilayerPerceptronClassifier
F       0.8654660741664717
Accuracy        0.8661417322834646
Precision       0.8657956217011336
Recall  0.8661417322834646
False Positive Rase     0.1517827579244902
True Positive Rate      0.8661417322834646

Confusion Matrix
['alpha', 'beta']
DenseMatrix([[68.,  7.],
             [10., 42.]])
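According to the Spark ML documentation, the multilayer perceptron uses sigmoid activations in its hidden layers and a softmax output layer. The sketch below mirrors a forward pass for `layers = [featureCount, 64, 64, classCount]` in plain Python and counts trainable parameters, assuming one bias per neuron. The helper names, the random initialization and the stand-in feature vector are hypothetical; it does not reproduce the trained model.

```python
import math
import random


def mlp_parameter_count(layers):
    """Weights plus biases of a fully connected network with the given layer sizes."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layers, layers[1:]))


def mlp_forward(x, weights, biases):
    """Sigmoid hidden layers followed by a softmax output layer."""
    activation = x
    for layer, (w, b) in enumerate(zip(weights, biases)):
        z = [sum(wi * ai for wi, ai in zip(row, activation)) + bj
             for row, bj in zip(w, b)]
        if layer < len(weights) - 1:
            activation = [1.0 / (1.0 + math.exp(-v)) for v in z]
        else:
            m = max(z)  # subtract the max for a numerically stable softmax
            exps = [math.exp(v - m) for v in z]
            s = sum(exps)
            activation = [e / s for e in exps]
    return activation


layers = [50, 64, 64, 2]  # [featureCount, 64, 64, classCount] as in the cell above
print(mlp_parameter_count(layers))  # 7554

rng = random.Random(1234)
weights = [[[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
           for n_in, n_out in zip(layers, layers[1:])]
biases = [[0.0] * n_out for n_out in layers[1:]]
features = [rng.uniform(-1.0, 1.0) for _ in range(layers[0])]  # stand-in feature vector
print(mlp_forward(features, weights, biases))
```

With about 7,500 parameters and roughly 1,200 training chains, the perceptron has far more parameters than binary logistic regression (51: 50 coefficients plus an intercept), which is worth keeping in mind when comparing their scores.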

Terminate Spark

In [31]:
sc.stop()