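In this demo we classify protein chains as either all-alpha or all-beta proteins based on their amino acid sequence. Each sequence is split into overlapping n-grams, which are embedded with Word2Vec to produce a feature vector for the chain. Four Spark ML classifiers (decision tree, random forest, logistic regression and multilayer perceptron) are then trained on these features and compared.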
In [17]:
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer
from pyspark.sql.functions import *
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier
In [18]:
# Configure a local Spark instance that uses all available cores.
conf = (
    SparkConf()
    .setMaster("local[*]")
    .setAppName("MachineLearningDemo")
)
# getOrCreate reuses an already-running SparkContext instead of raising
# "Cannot run multiple SparkContexts at once" when this cell is re-executed.
# NOTE: if a context already exists, `conf` is not applied to it.
sc = SparkContext.getOrCreate(conf=conf)
In [19]:
# Read the reduced MMTF sample, split every structure into its polymer chains,
# and keep only the chains that pass the Pisces filter
# (sequenceIdentity=40, resolution=3.0).
pdb = (
    mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/', sc)
    .flatMap(StructureToPolymerChains())
    .filter(Pisces(sequenceIdentity=40, resolution=3.0))
)
In [20]:
# Build a per-chain secondary-structure DataFrame from the filtered chains.
# Columns used later include "alpha", "beta" and "coil" (fractional content)
# plus the sequence and DSSP Q8/Q3 codes.
data = secondaryStructureExtractor.get_dataset(pdb)
In [21]:
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''Add a "foldType" column classifying each row by its alpha/beta content.

    The fold type is one of four classes, assigned by the first matching rule:

    - "alpha":      alpha > maxThreshold and beta < minThreshold
    - "beta":       beta > maxThreshold and alpha < minThreshold
    - "alpha+beta": alpha > maxThreshold and beta > maxThreshold
    - "other":      every remaining row (including rows where alpha or beta is null)

    Args:
        data (DataFrame): input dataset with numeric "alpha" and "beta"
            columns holding the secondary-structure composition
        minThreshold (float): a secondary-structure fraction below this value
            is treated as absent
        maxThreshold (float): a secondary-structure fraction above this value
            is treated as present

    Returns:
        DataFrame: ``data`` with an additional string column "foldType"
    '''
    # Imported locally so the function does not rely on the notebook's
    # wildcard import of pyspark.sql.functions.
    from pyspark.sql.functions import col, when

    alpha = col("alpha")
    beta = col("beta")
    fold_type = (
        when((alpha > maxThreshold) & (beta < minThreshold), "alpha")
        .when((beta > maxThreshold) & (alpha < minThreshold), "beta")
        .when((alpha > maxThreshold) & (beta > maxThreshold), "alpha+beta")
        .otherwise("other")
    )
    return data.withColumn("foldType", fold_type)
In [22]:
# Label every chain: "alpha"/"beta" require more than 15% of one element and
# less than 5% of the other.
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)

# Hyperparameters for the n-gram / Word2Vec sequence encoding
n = 2            # length of each overlapping n-gram (2-grams)
windowSize = 25  # Word2Vec context window size
vectorSize = 50  # dimension of the resulting feature vector
In [23]:
# Encode each sequence as overlapping n-grams ("ngram" column) and embed them
# with Word2Vec ("features" column), using the hyperparameters defined in the
# previous cell rather than repeating their values here.
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(
    n=n, windowSize=windowSize, vectorSize=vectorSize
).cache()

# Preview a few rows; limit() avoids collecting the whole dataset to the driver.
data.limit(5).toPandas()
Out[23]:
structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 1RCQ.A | MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE... | 0.316527 | 0.240896 | 0.442577 | CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH... | CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH... | alpha+beta | [MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A... | [0.22282994887529967, -0.20568346063700618, -0... |
1 | 1REG.Y | MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY... | 0.308333 | 0.291667 | 0.400000 | CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE... | CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE... | alpha+beta | [MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F... | [-0.4225819193534861, -0.0816098772420371, -0.... |
2 | 1REQ.B | SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR... | 0.470113 | 0.121163 | 0.408724 | XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC... | XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC... | alpha+beta | [SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T... | [0.013261343847444785, -0.16321914651542435, -... |
3 | 1RFE.A | GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI... | 0.312500 | 0.356250 | 0.331250 | XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE... | XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE... | alpha+beta | [GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E... | [-0.001911293541700203, -0.26975917786082126, ... |
4 | 1RG8.B | HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ... | 0.063830 | 0.375887 | 0.560284 | XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT... | XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEECCCCCC... | other | [HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G... | [-0.3250424425727848, 0.34787580900151155, 0.2... |
In [24]:
# Keep the chain id, secondary-structure fractions, class label and feature vector.
data = data.select('structureChainId', 'alpha', 'beta', 'coil', 'foldType', 'features')
In [25]:
# Restrict to the two classes of interest (add 'other' to train a 3-class model).
data = data.where(data.foldType.isin('alpha', 'beta'))
print(f"Total number of data: {data.count()}")

# Preview a few rows without collecting the whole dataset to the driver.
data.limit(5).toPandas()
Total number of data: 2584
Out[25]:
structureChainId | alpha | beta | coil | foldType | features | |
---|---|---|---|---|---|---|
0 | 1RI6.A | 0.018018 | 0.552553 | 0.429429 | beta | [-0.09249431734948217, 0.09015498735141335, -0... |
1 | 1RJU.V | 0.166667 | 0.000000 | 0.833333 | alpha | [-0.2477414113602468, 0.4224771835974284, -0.2... |
2 | 1RK8.C | 0.000000 | 0.393939 | 0.606061 | beta | [-0.18497365634692342, -0.04471376525205478, -... |
3 | 1RKT.B | 0.795000 | 0.000000 | 0.205000 | alpha | [-0.23382538002824374, -0.11802027330679052, -... |
4 | 1RR7.A | 0.702128 | 0.000000 | 0.297872 | alpha | [-0.08483699508360587, 0.024782998094451614, 0... |
In [26]:
# Training configuration
label = 'foldType'   # column holding the class label
testFraction = 0.1   # fraction of rows held out for testing
seed = 123           # random seed passed to SparkMultiClassClassifier

vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count : {featureCount}")

classCount = data.select(label).distinct().count()  # count() already returns an int
print(f"Class count : {classCount}")

print(f"Dataset size (unbalanced) : {data.count()}")
data.groupby(label).count().show(classCount)

# Downsample the majority class so the classes are roughly balanced.
# NOTE(review): the third argument (1) presumably seeds the sampling - confirm.
# The balanced DataFrame is reused by every classifier below, so cache it to
# avoid recomputing the sampling for each action.
data = datasetBalancer.downsample(data, label, 1).cache()
print(f"Dataset size (balanced) : {data.count()}")
data.groupby(label).count().show(classCount)
Feature count : 50
Class count : 2
Dataset size (unbalanced) : 2584
+--------+-----+
|foldType|count|
+--------+-----+
| beta| 660|
| alpha| 1924|
+--------+-----+
Dataset size (balanced) : 1342
+--------+-----+
|foldType|count|
+--------+-----+
| beta| 660|
| alpha| 682|
+--------+-----+
In [27]:
# Decision tree with default hyperparameters.
dtc = DecisionTreeClassifier()
mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)
metrics = mcc.fit(data)
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}\t{metric_value}")
Class Train Test
alpha 607 75
beta 608 52
Sample predictions: DecisionTreeClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|structureChainId| alpha| beta| coil|foldType| features|indexedLabel|rawPrediction| probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|
| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0| [5.0,16.0]|[0.23809523809523...| 1.0| beta|
| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|
| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|
| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0| [91.0,13.0]| [0.875,0.125]| 0.0| alpha|
| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0| [0.0,2.0]| [0.0,1.0]| 1.0| beta|
| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0| [91.0,13.0]| [0.875,0.125]| 0.0| alpha|
| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0| [12.0,63.0]| [0.16,0.84]| 1.0| beta|
| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0| [31.0,0.0]| [1.0,0.0]| 0.0| alpha|
| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
Total time taken: 3.21268630027771
Method DecisionTreeClassifier
F 0.8039005137708921
Accuracy 0.8031496062992126
Precision 0.8055535671677405
Recall 0.8031496062992125
False Positive Rase 0.2013547345043408
True Positive Rate 0.8031496062992125
Confusion Matrix
['alpha', 'beta']
DenseMatrix([[61., 14.],
[11., 41.]])
In [28]:
# Random forest with default hyperparameters.
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}\t{metric_value}")
Class Train Test
alpha 607 75
beta 608 52
Sample predictions: RandomForestClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0|[0.51578998431281...|[0.02578949921564...| 1.0| beta|
| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0|[13.9982189640870...|[0.69991094820435...| 0.0| alpha|
| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0|[2.13188897197973...|[0.10659444859898...| 1.0| beta|
| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0|[2.75569007720908...|[0.13778450386045...| 1.0| beta|
| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0|[9.85487531968060...|[0.49274376598403...| 1.0| beta|
| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0|[18.2154681085900...|[0.91077340542950...| 0.0| alpha|
| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0|[18.2037992417524...|[0.91018996208762...| 0.0| alpha|
| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0|[19.4849307860132...|[0.97424653930066...| 0.0| alpha|
| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0|[16.8241875964511...|[0.84120937982255...| 0.0| alpha|
| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0|[13.1749067714244...|[0.65874533857122...| 0.0| alpha|
| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0|[19.4353866362187...|[0.97176933181093...| 0.0| alpha|
| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0|[17.3230275572504...|[0.86615137786252...| 0.0| alpha|
| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0|[18.6141301097890...|[0.93070650548945...| 0.0| alpha|
| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0|[8.24401892548523...|[0.41220094627426...| 1.0| beta|
| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0|[11.5615661468762...|[0.57807830734381...| 0.0| alpha|
| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0|[16.9153051488506...|[0.84576525744253...| 0.0| alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
Total time taken: 4.367412090301514
Method RandomForestClassifier
F 0.8512456189733639
Accuracy 0.8503937007874016
Precision 0.8547451305428928
Recall 0.8503937007874016
False Positive Rase 0.14500908540278618
True Positive Rate 0.8503937007874016
Confusion Matrix
['alpha', 'beta']
DenseMatrix([[63., 12.],
[ 7., 45.]])
In [29]:
# Logistic regression with default hyperparameters.
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}\t{metric_value}")
Class Train Test
alpha 607 75
beta 608 52
Sample predictions: LogisticRegression
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0|[-5.0623275294473...|[0.00629098022815...| 1.0| beta|
| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0|[-1.1865565882185...|[0.23387535279941...| 1.0| beta|
| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0|[-5.2476892464597...|[0.00523213887961...| 1.0| beta|
| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0|[-4.5674722465786...|[0.01027745233916...| 1.0| beta|
| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0|[1.86554161773270...|[0.86594156213690...| 0.0| alpha|
| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0|[2.45521137787059...|[0.92094171451502...| 0.0| alpha|
| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0|[3.79818850368450...|[0.97807992495051...| 0.0| alpha|
| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0|[5.18188108968423...|[0.99441394966369...| 0.0| alpha|
| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0|[1.90968820922997...|[0.87098411573233...| 0.0| alpha|
| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0|[3.98665134253246...|[0.98177649332300...| 0.0| alpha|
| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0|[5.79842445508542...|[0.99697683866489...| 0.0| alpha|
| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0|[4.00761355513002...|[0.98214777367986...| 0.0| alpha|
| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0|[11.3978061720542...|[0.99998878005311...| 0.0| alpha|
| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0|[-1.9088191172433...|[0.12911357630884...| 1.0| beta|
| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0|[2.08239502966998...|[0.88918025653597...| 0.0| alpha|
| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0|[2.88630893525889...|[0.94716547404562...| 0.0| alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
Total time taken: 7.353407859802246
Method LogisticRegression
F 0.8825623307684451
Accuracy 0.8818897637795275
Precision 0.8859846466560102
Recall 0.8818897637795275
False Positive Rase 0.11137694326670705
True Positive Rate 0.8818897637795275
Confusion Matrix
['alpha', 'beta']
DenseMatrix([[65., 10.],
[ 5., 47.]])
In [30]:
# Multilayer perceptron: an input layer sized to the feature vector, two hidden
# layers of 64 units each, and one output unit per class.
layers = [featureCount, 64, 64, classCount]
mpc = MultilayerPerceptronClassifier(layers=layers, blockSize=128, seed=1234, maxIter=100)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)
for metric_name, metric_value in metrics.items():
    print(f"{metric_name}\t{metric_value}")
Class Train Test
alpha 607 75
beta 608 52
Sample predictions: MultilayerPerceptronClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
|structureChainId| alpha| beta| coil|foldType| features|indexedLabel|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0| 1.0| beta|
| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0| 1.0| beta|
| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0| 1.0| beta|
| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0| 1.0| beta|
| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0| 0.0| alpha|
| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0| 0.0| alpha|
| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0| 0.0| alpha|
| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0| 0.0| alpha|
| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0| 0.0| alpha|
| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0| 0.0| alpha|
| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0| 0.0| alpha|
| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0| 0.0| alpha|
| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0| 0.0| alpha|
| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0| 0.0| alpha|
| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0| 0.0| alpha|
| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0| 0.0| alpha|
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
Total time taken: 13.364720821380615
Method MultilayerPerceptronClassifier
F 0.8654660741664717
Accuracy 0.8661417322834646
Precision 0.8657956217011336
Recall 0.8661417322834646
False Positive Rase 0.1517827579244902
True Positive Rate 0.8661417322834646
Confusion Matrix
['alpha', 'beta']
DenseMatrix([[68., 7.],
[10., 42.]])
In [31]:
# Shut down the SparkContext created at the top of the notebook.
sc.stop()