# Machine learning - Protein Chain Classification

In this demo we classify protein chains as either all-alpha or all-beta proteins based on their amino acid sequence. Each sequence is split into overlapping n-grams, and a Word2Vec model trained on these n-grams turns every sequence into a fixed-length feature vector.

[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)

[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)

## Imports

```python
from pyspark import SparkConf, SparkContext
from pyspark.sql.functions import col, when  # used by add_protein_fold_type
from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier
from mmtfPyspark.io import mmtfReader
from mmtfPyspark.webfilters import Pisces
from mmtfPyspark.filters import ContainsLProteinChain
from mmtfPyspark.mappers import StructureToPolymerChains
from mmtfPyspark.datasets import secondaryStructureExtractor
from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer
```

## Configure Spark Context

```python
conf = SparkConf() \
    .setMaster("local[*]") \
    .setAppName("MachineLearningDemo")

sc = SparkContext(conf=conf)
```

## Read MMTF files and create a non-redundant set (<=40% sequence identity, resolution <= 3.0 Å) of L-protein chains

```python
path = '../../resources/mmtf_reduced_sample/'

pdb = mmtfReader.read_sequence_file(path, sc) \
                .flatMap(StructureToPolymerChains()) \
                .filter(Pisces(sequenceIdentity=40, resolution=3.0)) \
                .filter(ContainsLProteinChain())  # keep only L-protein chains
```

## Get secondary structure content

```python
data = secondaryStructureExtractor.get_dataset(pdb)
```
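As the sample output further down shows, each chain carries an 8-state DSSP string (`dsspQ8Code`), a 3-state reduction of it (`dsspQ3Code`) and the fractions `alpha`, `beta` and `coil`. The sketch below illustrates that reduction in plain Python. It assumes the common DSSP mapping (H, G, I -> helix; E, B -> strand; everything else -> coil), which agrees with the sample rows (G becomes H, S and T become C), and it excludes unassigned `X` positions from the denominator; `secondaryStructureExtractor` may handle `X` differently. The helper names are hypothetical, not part of mmtfPyspark.

```python
# Hypothetical helpers illustrating a Q8 -> Q3 reduction; not the mmtfPyspark API.
Q8_TO_Q3 = {
    "H": "H", "G": "H", "I": "H",  # alpha, 3-10 and pi helices -> helix
    "E": "E", "B": "E",            # extended strand and isolated bridge -> strand
}

def q8_to_q3(q8: str) -> str:
    """Reduce an 8-state DSSP string to 3 states; 'X' (no assignment) is kept as is."""
    return "".join(c if c == "X" else Q8_TO_Q3.get(c, "C") for c in q8)

def secondary_structure_fractions(q3: str) -> dict:
    """Fractions of helix, strand and coil among assigned (non-'X') positions (assumption)."""
    assigned = [c for c in q3 if c != "X"]
    n = len(assigned)
    if n == 0:
        return {"alpha": 0.0, "beta": 0.0, "coil": 0.0}
    return {
        "alpha": assigned.count("H") / n,
        "beta": assigned.count("E") / n,
        "coil": assigned.count("C") / n,
    }

q3 = q8_to_q3("CEEEECSSGGHHHH")  # start of the 1REG.Y DSSP string
print(q3)                        # CEEEECCCHHHHHH
print(secondary_structure_fractions(q3))
```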

## Define the add_protein_fold_type function

```python
def add_protein_fold_type(data, minThreshold, maxThreshold):
    '''
    Adds a column "foldType" with one of four secondary structure classes:
    "alpha", "beta", "alpha+beta", or "other", based on the fraction of
    alpha and beta content.

    The simplified syntax used in this method relies on two imports:
        from pyspark.sql.functions import when
        from pyspark.sql.functions import col

    Args:
        data (DataFrame): input dataset with "alpha" and "beta" fraction columns
        minThreshold (float): a fraction below this threshold counts as absent
        maxThreshold (float): a fraction above this threshold counts as present

    Returns:
        DataFrame: the input dataset with an added "foldType" column
    '''
    # Conditions are evaluated in order; the first matching one wins.
    return data.withColumn(
        "foldType",
        when((col("alpha") > maxThreshold) & (col("beta") < minThreshold), "alpha")
        .when((col("beta") > maxThreshold) & (col("alpha") < minThreshold), "beta")
        .when((col("alpha") > maxThreshold) & (col("beta") > maxThreshold), "alpha+beta")
        .otherwise("other"),
    )
```

## Classify chains by secondary structure type

```python
# alpha: alpha > 15% and beta < 5%; beta: beta > 15% and alpha < 5%
data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)
```
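To make the threshold logic concrete, here is the same rule as a plain-Python function. `fold_type` is a hypothetical helper that mirrors the `when(...)` chain above; it is not part of mmtfPyspark. The `(alpha, beta)` values are copied from sample rows shown later in this notebook. With `minThreshold=0.05` and `maxThreshold=0.15`, 1RG8.B ends up as "other": its beta content is high, but its alpha content (about 6%) is not below 5%.

```python
def fold_type(alpha: float, beta: float,
              min_threshold: float = 0.05, max_threshold: float = 0.15) -> str:
    """Plain-Python mirror of add_protein_fold_type; the first matching rule wins."""
    if alpha > max_threshold and beta < min_threshold:
        return "alpha"
    if beta > max_threshold and alpha < min_threshold:
        return "beta"
    if alpha > max_threshold and beta > max_threshold:
        return "alpha+beta"
    return "other"

# (alpha, beta) fractions copied from the sample output of this notebook
samples = {
    "1RCQ.A": (0.316527, 0.240896),  # alpha+beta
    "1RG8.B": (0.06383, 0.375887),   # other
    "1RJU.V": (0.166667, 0.0),       # alpha
    "1RI6.A": (0.018018, 0.552553),  # beta
}
for chain_id, (a, b) in samples.items():
    print(chain_id, fold_type(a, b))
```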

## Create a Word2Vec representation of the protein sequences

- **n = 2**: create overlapping 2-grams
- **windowSize = 25**: Word2Vec context window; n-grams up to 25 positions before or after a given n-gram count as its context
- **vectorSize = 50**: dimension of the feature vector
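With `n = 2`, each sequence is split into overlapping 2-grams that serve as the "words" of the Word2Vec model; in the `ngram` column of the output below, for example, `MRPAR...` becomes `[MR, RP, PA, AR, ...]`. A minimal plain-Python version of that split (the helper name `overlapping_ngrams` is hypothetical):

```python
def overlapping_ngrams(sequence: str, n: int = 2) -> list:
    """Split a sequence into overlapping n-grams, advancing one residue at a time."""
    return [sequence[i:i + n] for i in range(len(sequence) - n + 1)]

print(overlapping_ngrams("MRPAR"))       # ['MR', 'RP', 'PA', 'AR']
print(overlapping_ngrams("MRPAR", n=3))  # ['MRP', 'RPA', 'PAR']
```

A sequence of length L yields L - n + 1 n-grams, so neighboring words share n - 1 residues.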

```python
encoder = ProteinSequenceEncoder(data)
data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()

data.toPandas().head(5)
```

| | structureChainId | sequence | alpha | beta | coil | dsspQ8Code | dsspQ3Code | foldType | ngram | features |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1RCQ.A | MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE... | 0.316527 | 0.240896 | 0.442577 | CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH... | CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH... | alpha+beta | [MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A... | [0.22282994887529967, -0.20568346063700618, -0... |
| 1 | 1REG.Y | MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY... | 0.308333 | 0.291667 | 0.4 | CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE... | CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE... | alpha+beta | [MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F... | [-0.4225819193534861, -0.0816098772420371, -0.... |
| 2 | 1REQ.B | SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR... | 0.470113 | 0.121163 | 0.408724 | XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC... | XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC... | alpha+beta | [SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T... | [0.013261343847444785, -0.16321914651542435, -... |
| 3 | 1RFE.A | GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI... | 0.3125 | 0.35625 | 0.33125 | XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE... | XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE... | alpha+beta | [GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E... | [-0.001911293541700203, -0.26975917786082126, ... |
| 4 | 1RG8.B | HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ... | 0.06383 | 0.375887 | 0.560284 | XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT... | XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEECCCCCC... | other | [HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G... | [-0.3250424425727848, 0.34787580900151155, 0.2... |
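Each chain has a different number of n-grams, yet every `features` vector has the same length (50). Assuming `ProteinSequenceEncoder` builds on Spark ML's `Word2Vec`, this is because the fitted model represents a document by the average of its word vectors, as the Spark ML feature documentation linked at the top describes. The sketch below shows only that averaging step, with a hypothetical 3-dimensional toy embedding; the real 50-dimensional vectors are learned by Word2Vec, and the helper name is illustrative.

```python
from typing import Dict, List

def average_word_vectors(words: List[str], embedding: Dict[str, List[float]]) -> List[float]:
    """Element-wise mean of the word vectors of one document.

    Assumes every word is in the embedding vocabulary (a KeyError is raised otherwise).
    """
    if not words:
        raise ValueError("empty document")
    vectors = [embedding[w] for w in words]
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]

# Hypothetical toy embedding for three 2-grams
toy_embedding = {
    "MR": [1.0, 0.0, 2.0],
    "RP": [0.0, 1.0, 2.0],
    "PA": [2.0, 2.0, -1.0],
}
print(average_word_vectors(["MR", "RP", "PA"], toy_embedding))  # [1.0, 1.0, 1.0]
```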


## Keep only a subset of relevant fields for further processing

```python
data = data.select(['structureChainId', 'alpha', 'beta', 'coil', 'foldType', 'features'])
```

## Keep only the alpha and beta fold types

```python
# To also keep the "other" class, add: | (data.foldType == 'other')
data = data.where((data.foldType == 'alpha') | (data.foldType == 'beta'))

print(f"Total number of data: {data.count()}")
data.toPandas().head()
```

```text
Total number of data: 2584
```

| | structureChainId | alpha | beta | coil | foldType | features |
|---|---|---|---|---|---|---|
| 0 | 1RI6.A | 0.018018 | 0.552553 | 0.429429 | beta | [-0.09249431734948217, 0.09015498735141335, -0... |
| 1 | 1RJU.V | 0.166667 | 0.0 | 0.833333 | alpha | [-0.2477414113602468, 0.4224771835974284, -0.2... |
| 2 | 1RK8.C | 0.0 | 0.393939 | 0.606061 | beta | [-0.18497365634692342, -0.04471376525205478, -... |
| 3 | 1RKT.B | 0.795 | 0.0 | 0.205 | alpha | [-0.23382538002824374, -0.11802027330679052, -... |
| 4 | 1RR7.A | 0.702128 | 0.0 | 0.297872 | alpha | [-0.08483699508360587, 0.024782998094451614, 0... |


## Basic dataset information and settings

```python
label = 'foldType'
testFraction = 0.1
seed = 123

vector = data.first()["features"]
featureCount = len(vector)
print(f"Feature count    : {featureCount}")

classCount = data.select(label).distinct().count()
print(f"Class count    : {classCount}")

print(f"Dataset size (unbalanced)    : {data.count()}")
data.groupby(label).count().show(classCount)

# Randomly downsample the majority class toward the size of the minority class
data = datasetBalancer.downsample(data, label, 1)
print(f"Dataset size (balanced)  : {data.count()}")
data.groupby(label).count().show(classCount)
```

```text
Feature count    : 50
Class count    : 2
Dataset size (unbalanced)    : 2584
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  660|
|   alpha| 1924|
+--------+-----+

Dataset size (balanced)  : 1342
+--------+-----+
|foldType|count|
+--------+-----+
|    beta|  660|
|   alpha|  682|
+--------+-----+
```
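The balanced counts (660 beta, 682 alpha) are close but not identical, which suggests that `datasetBalancer.downsample` keeps each row with a class-dependent probability rather than drawing an exact number of rows. The sketch below illustrates that idea in plain Python under this assumption: each row is kept with probability `minority_count / class_count`, so the minority class survives whole and larger classes shrink to about the same size on average. It is not the mmtfPyspark implementation, and the row IDs are placeholders; in Spark, a similar stratified sample can be drawn with `DataFrame.sampleBy(col, fractions, seed)`.

```python
import random
from collections import Counter
from typing import List, Tuple

def downsample(rows: List[Tuple[str, str]], seed: int = 1) -> List[Tuple[str, str]]:
    """Keep each (id, label) row with probability min_count / count(label) (assumed scheme)."""
    counts = Counter(label for _, label in rows)
    min_count = min(counts.values())
    fractions = {label: min_count / n for label, n in counts.items()}
    rng = random.Random(seed)
    # random() is in [0, 1), so a fraction of 1.0 keeps every minority-class row
    return [row for row in rows if rng.random() < fractions[row[1]]]

# Class sizes taken from the unbalanced dataset above; the IDs are placeholders.
rows = [(f"a{i}", "alpha") for i in range(1924)] + [(f"b{i}", "beta") for i in range(660)]
balanced = downsample(rows, seed=1)
print(Counter(label for _, label in balanced))
```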



## Decision Tree Classifier

```python
dtc = DecisionTreeClassifier()
mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")
```


```text
 Class	Train	Test
alpha	607	75
beta	608	52

Sample predictions: DecisionTreeClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0| [13.0,377.0]|[0.03333333333333...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|   [5.0,16.0]|[0.23809523809523...|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0| [13.0,377.0]|[0.033333333
```
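In the decision-tree predictions above, `rawPrediction` holds the class counts of the training rows in the leaf a chain falls into, and `probability` is those counts normalized to sum to 1: the leaf `[13.0, 377.0]` gives 13 / 390 ≈ 0.0333 for class 0 (alpha), and `[5.0, 16.0]` gives 5 / 21 ≈ 0.238, matching the truncated values shown. A minimal sketch of that normalization and of picking the predicted class index (the helper names are illustrative):

```python
from typing import List

def leaf_probability(leaf_counts: List[float]) -> List[float]:
    """Normalize the class counts of a decision-tree leaf into probabilities."""
    total = sum(leaf_counts)
    return [c / total for c in leaf_counts]

def predict_index(probability: List[float]) -> int:
    """Index of the most probable class (ties go to the lowest index)."""
    return max(range(len(probability)), key=lambda i: probability[i])

for counts in ([13.0, 377.0], [5.0, 16.0]):
    p = leaf_probability(counts)
    print(counts, [round(x, 4) for x in p], predict_index(p))
```

Both leaves predict index 1, which the `predictedLabel` column maps back to "beta".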

## Random Forest Classifier

```python
rfc = RandomForestClassifier()
mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")
```


```text
 Class	Train	Test
alpha	607	75
beta	608	52

Sample predictions: RandomForestClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|[0.51578998431281...|[0.02578949921564...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|[13.9982189640870...|[0.69991094820435...|       0.0|         alpha|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|    
```
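For the random forest, the `rawPrediction` values are consistent with summing the per-tree class probabilities over the trees of the forest (Spark's default is `numTrees=20`): 13.998 / 20 ≈ 0.6999 and 0.5158 / 20 ≈ 0.0258 match the truncated `probability` values for 3BWU.F and 5DI0.B. The sketch below aggregates hypothetical per-tree probability vectors in the same way; the 4-tree forest and its values are invented for illustration.

```python
from typing import List

def forest_raw_prediction(tree_probabilities: List[List[float]]) -> List[float]:
    """Sum the class-probability vectors returned by the individual trees."""
    n_classes = len(tree_probabilities[0])
    return [sum(tp[c] for tp in tree_probabilities) for c in range(n_classes)]

def normalize(raw: List[float]) -> List[float]:
    """Rescale a raw score vector so that it sums to 1."""
    total = sum(raw)
    return [r / total for r in raw]

# Hypothetical per-tree leaf probabilities [alpha, beta] for one chain in a 4-tree forest
trees = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.7, 0.3]]
raw = forest_raw_prediction(trees)
print(raw, normalize(raw))
```

Because each tree contributes a probability vector that sums to 1, the raw scores sum to the number of trees, and normalizing is the same as dividing by it.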

## Logistic Regression Classifier

```python
lr = LogisticRegression()
mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")
```


```text
 Class	Train	Test
alpha	607	75
beta	608	52

Sample predictions: LogisticRegression
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|       rawPrediction|         probability|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|[-5.0623275294473...|[0.00629098022815...|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|[-1.1865565882185...|[0.23387535279941...|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|        
```
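For binary logistic regression, the sample values are consistent with `rawPrediction = [-m, m]`, where m = w·x + b is the margin, and `probability = [σ(-m), σ(m)]` with σ the logistic function: σ(-5.0623) ≈ 0.00629 for 5DI0.B and σ(-1.1866) ≈ 0.2339 for 3BWU.F. A minimal sketch of that mapping, using the leading digits of the `rawPrediction` values shown above (the helper names are illustrative):

```python
import math
from typing import List

def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_probability(raw_prediction: List[float]) -> List[float]:
    """Map a binary rawPrediction [-m, m] to class probabilities [sigmoid(-m), sigmoid(m)]."""
    return [sigmoid(r) for r in raw_prediction]

# First rawPrediction entry (class 0 = alpha) for 5DI0.B and 3BWU.F
for r0 in (-5.0623275294473, -1.1865565882185):
    p = binary_probability([r0, -r0])
    print([round(x, 5) for x in p])
```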

## Simple Multilayer Perceptron Classifier

```python
# Layer sizes: input features, two hidden layers of 64 nodes, one output node per class
layers = [featureCount, 64, 64, classCount]
mpc = MultilayerPerceptronClassifier().setLayers(layers) \
                                      .setBlockSize(128) \
                                      .setSeed(1234) \
                                      .setMaxIter(100)
mcc = SparkMultiClassClassifier(mpc, label, testFraction, seed)
metrics = mcc.fit(data)
for k, v in metrics.items():
    print(f"{k}\t{v}")
```


```text
 Class	Train	Test
alpha	607	75
beta	608	52

Sample predictions: MultilayerPerceptronClassifier
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
|structureChainId|      alpha|       beta|      coil|foldType|            features|indexedLabel|prediction|predictedLabel|
+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+
|          5DI0.B|0.018867925|  0.6415094|0.33962265|    beta|[-0.2703873959570...|         1.0|       1.0|          beta|
|          3BWU.F|  0.0234375|  0.4921875|  0.484375|    beta|[0.12650794831784...|         1.0|       1.0|          beta|
|          3X0T.A|0.026785715| 0.51785713|0.45535713|    beta|[-0.1733142428289...|         1.0|       1.0|          beta|
|          2QF4.A| 0.01764706|  0.5117647|0.47058824|    beta|[-0.2368620687017...|         1.0|       1.0|          beta|
|          4OUS.A|        0.0|  0.5681818| 
```
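With 50 features and 2 classes, the layer list `[featureCount, 64, 64, classCount]` is `[50, 64, 64, 2]`: a 50-value input layer, two fully connected hidden layers of 64 nodes and a 2-node output layer. According to the Spark ML classification guide, hidden nodes use a sigmoid activation and output nodes a softmax. Assuming one weight per connection plus one bias per non-input node, the network has 7,554 trainable parameters, as this small sketch computes (the helper name is illustrative):

```python
from typing import List

def mlp_parameter_count(layers: List[int]) -> int:
    """Weights plus biases of a fully connected network with the given layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))

# 50*64 + 64 = 3264, 64*64 + 64 = 4160, 64*2 + 2 = 130
print(mlp_parameter_count([50, 64, 64, 2]))  # 7554
```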

## Terminate Spark

```python
sc.stop()
```