{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine learning - Protein Chain Classification\n", "\n", "In this demo we try to classify a protein chain as either all-alpha or all-beta based on its sequence. We split each sequence into overlapping n-grams and use their Word2Vec representation as the feature vector.\n", "\n", "[Word2Vec model](https://spark.apache.org/docs/latest/mllib-feature-extraction.html#word2vec)\n", "\n", "[Word2Vec example](https://spark.apache.org/docs/latest/ml-features.html#word2vec)\n", "\n", "## Imports" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from pyspark import SparkConf, SparkContext, SQLContext\n", "from pyspark.sql import SparkSession\n", "from pyspark.sql.functions import *\n", "from pyspark.sql.types import *\n", "from mmtfPyspark.io import mmtfReader\n", "from mmtfPyspark.webfilters import Pisces\n", "from mmtfPyspark.filters import ContainsLProteinChain\n", "from mmtfPyspark.mappers import StructureToPolymerChains\n", "from mmtfPyspark.datasets import secondaryStructureExtractor\n", "from mmtfPyspark.ml import ProteinSequenceEncoder, SparkMultiClassClassifier, datasetBalancer\n", "from pyspark.ml.classification import DecisionTreeClassifier, LogisticRegression, MultilayerPerceptronClassifier, RandomForestClassifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure Spark Context" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "conf = SparkConf() \\\n", "    .setMaster(\"local[*]\") \\\n", "    .setAppName(\"MachineLearningDemo\")\n", "\n", "sc = SparkContext(conf=conf)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read MMTF File and create a non-redundant set (<=40% seq. 
identity) of L-protein chains" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "pdb = mmtfReader.read_sequence_file('../../resources/mmtf_reduced_sample/', sc) \\\n", "    .flatMap(StructureToPolymerChains()) \\\n", "    .filter(Pisces(sequenceIdentity=40,resolution=3.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get secondary structure content" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "data = secondaryStructureExtractor.get_dataset(pdb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define add_protein_fold_type function" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "def add_protein_fold_type(data, minThreshold, maxThreshold):\n", "    '''\n", "    Adds a column \"foldType\" with one of four secondary structure classes:\n", "    \"alpha\", \"beta\", \"alpha+beta\", or \"other\", based on the fraction of alpha/beta content.\n", "\n", "    The simplified syntax used in this method relies on two imports:\n", "        from pyspark.sql.functions import when\n", "        from pyspark.sql.functions import col\n", "\n", "    Args:\n", "        data (Dataset): input dataset with alpha, beta composition\n", "        minThreshold (float): a secondary structure type with a fraction below this threshold is treated as absent\n", "        maxThreshold (float): a secondary structure type with a fraction above this threshold is treated as present\n", "    '''\n", "\n", "    return data.withColumn(\"foldType\", \\\n", "        when((col(\"alpha\") > maxThreshold) & (col(\"beta\") < minThreshold), \"alpha\"). \\\n", "        when((col(\"beta\") > maxThreshold) & (col(\"alpha\") < minThreshold), \"beta\"). \\\n", "        when((col(\"alpha\") > maxThreshold) & (col(\"beta\") > maxThreshold), \"alpha+beta\"). 
\\\n", "        otherwise(\"other\")\\\n", "    )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classify chains by secondary structure type" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "data = add_protein_fold_type(data, minThreshold=0.05, maxThreshold=0.15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create a Word2Vec representation of the protein sequences\n", "\n", "**n = 2** # create overlapping 2-grams \n", "\n", "**windowSize = 25** # window size of 25 residues for Word2Vec\n", "\n", "**vectorSize = 50** # dimension of the feature vector" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table>\n",
"<thead>\n",
"<tr><th></th><th>structureChainId</th><th>sequence</th><th>alpha</th><th>beta</th><th>coil</th><th>dsspQ8Code</th><th>dsspQ3Code</th><th>foldType</th><th>ngram</th><th>features</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>1RCQ.A</td><td>MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE...</td><td>0.316527</td><td>0.240896</td><td>0.442577</td><td>CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH...</td><td>CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH...</td><td>alpha+beta</td><td>[MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A...</td><td>[0.22282994887529967, -0.20568346063700618, -0...</td></tr>\n",
"<tr><th>1</th><td>1REG.Y</td><td>MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY...</td><td>0.308333</td><td>0.291667</td><td>0.400000</td><td>CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE...</td><td>CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE...</td><td>alpha+beta</td><td>[MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F...</td><td>[-0.4225819193534861, -0.0816098772420371, -0....</td></tr>\n",
"<tr><th>2</th><td>1REQ.B</td><td>SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR...</td><td>0.470113</td><td>0.121163</td><td>0.408724</td><td>XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC...</td><td>XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC...</td><td>alpha+beta</td><td>[SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T...</td><td>[0.013261343847444785, -0.16321914651542435, -...</td></tr>\n",
"<tr><th>3</th><td>1RFE.A</td><td>GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI...</td><td>0.312500</td><td>0.356250</td><td>0.331250</td><td>XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE...</td><td>XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE...</td><td>alpha+beta</td><td>[GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E...</td><td>[-0.001911293541700203, -0.26975917786082126, ...</td></tr>\n",
"<tr><th>4</th><td>1RG8.B</td><td>HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ...</td><td>0.063830</td><td>0.375887</td><td>0.560284</td><td>XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT...</td><td>XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEECCCCCC...</td><td>other</td><td>[HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G...</td><td>[-0.3250424425727848, 0.34787580900151155, 0.2...</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>" ], "text/plain": [ " structureChainId sequence \\\n", "0 1RCQ.A MRPARALIDLQALRHNYRLAREATGARALAVIKADAYGHGAVRCAE... \n", "1 1REG.Y MIEITLKKPEDFLKVKETLTRMGIANNKDKVLYQSCHILQKKGLYY... \n", "2 1REQ.B SSTDQGTNPADTDDLTPTTLSLAGDFPKATEEQWEREVEKVLNRGR... \n", "3 1RFE.A GTKQRADIVMSEAEIADFVNSSRTGTLATIGPDGQPHLTAMWYAVI... \n", "4 1RG8.B HHHHHHFNLPPGNYKKPKLLYCSNGGHFLRILPDGTVDGTRDRSDQ... \n", "\n", " alpha beta coil \\\n", "0 0.316527 0.240896 0.442577 \n", "1 0.308333 0.291667 0.400000 \n", "2 0.470113 0.121163 0.408724 \n", "3 0.312500 0.356250 0.331250 \n", "4 0.063830 0.375887 0.560284 \n", "\n", " dsspQ8Code \\\n", "0 CCCCEEEEEHHHHHHHHHHHHHHHCSEEEEECHHHHHTTCHHHHHH... \n", "1 CEEEECSSGGHHHHHHHHHTTEEEEETTTTEEEECEEEEEETTEEE... \n", "2 XXXXXXXXXXXXXXXXXXCCCSGGGSCCCCHHHHHHHHHHHHHTTC... \n", "3 XCCCCTTTCCCHHHHHHHHHHCCCEEEEEECTTSCEEEEEECCEEE... \n", "4 XXCCSCCCCCSCCSSSCEEEEETTTTEEEEECTTSCEEEESCTTCT... \n", "\n", " dsspQ3Code foldType \\\n", "0 CCCCEEEEEHHHHHHHHHHHHHHHCCEEEEECHHHHHCCCHHHHHH... alpha+beta \n", "1 CEEEECCCHHHHHHHHHHHCCEEEEECCCCEEEECEEEEEECCEEE... alpha+beta \n", "2 XXXXXXXXXXXXXXXXXXCCCCHHHCCCCCHHHHHHHHHHHHHCCC... alpha+beta \n", "3 XCCCCCCCCCCHHHHHHHHHHCCCEEEEEECCCCCEEEEEECCEEE... alpha+beta \n", "4 XXCCCCCCCCCCCCCCCEEEEECCCCEEEEECCCCCEEEECCCCCC... other \n", "\n", " ngram \\\n", "0 [MR, RP, PA, AR, RA, AL, LI, ID, DL, LQ, QA, A... \n", "1 [MI, IE, EI, IT, TL, LK, KK, KP, PE, ED, DF, F... \n", "2 [SS, ST, TD, DQ, QG, GT, TN, NP, PA, AD, DT, T... \n", "3 [GT, TK, KQ, QR, RA, AD, DI, IV, VM, MS, SE, E... \n", "4 [HH, HH, HH, HH, HH, HF, FN, NL, LP, PP, PG, G... \n", "\n", " features \n", "0 [0.22282994887529967, -0.20568346063700618, -0... \n", "1 [-0.4225819193534861, -0.0816098772420371, -0.... \n", "2 [0.013261343847444785, -0.16321914651542435, -... \n", "3 [-0.001911293541700203, -0.26975917786082126, ... \n", "4 [-0.3250424425727848, 0.34787580900151155, 0.2... " ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoder = ProteinSequenceEncoder(data)\n", "data = encoder.overlapping_ngram_word2vec_encode(n=2, windowSize=25, vectorSize=50).cache()\n", "\n", "data.toPandas().head(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Keep only a subset of relevant fields for further processing" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "data = data.select(['structureChainId','alpha','beta','coil','foldType','features'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Keep only the alpha and beta fold types" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of data: 2584\n" ] }, { "data": { "text/html": [ "
<div>\n", "<table>\n",
"<thead>\n",
"<tr><th></th><th>structureChainId</th><th>alpha</th><th>beta</th><th>coil</th><th>foldType</th><th>features</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>1RI6.A</td><td>0.018018</td><td>0.552553</td><td>0.429429</td><td>beta</td><td>[-0.09249431734948217, 0.09015498735141335, -0...</td></tr>\n",
"<tr><th>1</th><td>1RJU.V</td><td>0.166667</td><td>0.000000</td><td>0.833333</td><td>alpha</td><td>[-0.2477414113602468, 0.4224771835974284, -0.2...</td></tr>\n",
"<tr><th>2</th><td>1RK8.C</td><td>0.000000</td><td>0.393939</td><td>0.606061</td><td>beta</td><td>[-0.18497365634692342, -0.04471376525205478, -...</td></tr>\n",
"<tr><th>3</th><td>1RKT.B</td><td>0.795000</td><td>0.000000</td><td>0.205000</td><td>alpha</td><td>[-0.23382538002824374, -0.11802027330679052, -...</td></tr>\n",
"<tr><th>4</th><td>1RR7.A</td><td>0.702128</td><td>0.000000</td><td>0.297872</td><td>alpha</td><td>[-0.08483699508360587, 0.024782998094451614, 0...</td></tr>\n",
"</tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " structureChainId alpha beta coil foldType \\\n", "0 1RI6.A 0.018018 0.552553 0.429429 beta \n", "1 1RJU.V 0.166667 0.000000 0.833333 alpha \n", "2 1RK8.C 0.000000 0.393939 0.606061 beta \n", "3 1RKT.B 0.795000 0.000000 0.205000 alpha \n", "4 1RR7.A 0.702128 0.000000 0.297872 alpha \n", "\n", " features \n", "0 [-0.09249431734948217, 0.09015498735141335, -0... \n", "1 [-0.2477414113602468, 0.4224771835974284, -0.2... \n", "2 [-0.18497365634692342, -0.04471376525205478, -... \n", "3 [-0.23382538002824374, -0.11802027330679052, -... \n", "4 [-0.08483699508360587, 0.024782998094451614, 0... " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = data.where((data.foldType == 'alpha') | (data.foldType == 'beta')) #| (data.foldType == 'other'))\n", "\n", "print(f\"Total number of data: {data.count()}\")\n", "data.toPandas().head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic dataset information and setting" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature count : 50\n", "Class count : 2\n", "Dataset size (unbalanced) : 2584\n", "+--------+-----+\n", "|foldType|count|\n", "+--------+-----+\n", "| beta| 660|\n", "| alpha| 1924|\n", "+--------+-----+\n", "\n", "Dataset size (balanced) : 1342\n", "+--------+-----+\n", "|foldType|count|\n", "+--------+-----+\n", "| beta| 660|\n", "| alpha| 682|\n", "+--------+-----+\n", "\n" ] } ], "source": [ "label = 'foldType'\n", "testFraction = 0.1\n", "seed = 123\n", "\n", "vector = data.first()[\"features\"]\n", "featureCount = len(vector)\n", "print(f\"Feature count : {featureCount}\")\n", " \n", "classCount = int(data.select(label).distinct().count())\n", "print(f\"Class count : {classCount}\")\n", "\n", "print(f\"Dataset size (unbalanced) : {data.count()}\")\n", " \n", "data.groupby(label).count().show(classCount)\n", "data = 
datasetBalancer.downsample(data, label, 1)\n", "print(f\"Dataset size (balanced) : {data.count()}\")\n", " \n", "data.groupby(label).count().show(classCount)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decision Tree Classifier" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Class\tTrain\tTest\n", "alpha\t607\t75\n", "beta\t608\t52\n", "\n", "Sample predictions: DecisionTreeClassifier\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+\n", "|structureChainId| alpha| beta| coil|foldType| features|indexedLabel|rawPrediction| probability|prediction|predictedLabel|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+\n", "| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|\n", "| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0| [5.0,16.0]|[0.23809523809523...| 1.0| beta|\n", "| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|\n", "| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0| [13.0,377.0]|[0.03333333333333...| 1.0| beta|\n", "| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0| [91.0,13.0]| [0.875,0.125]| 0.0| alpha|\n", "| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 4YK2.A| 0.7383177|0.037383176|0.22429906| 
alpha|[-0.3205931102219...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0| [0.0,2.0]| [0.0,1.0]| 1.0| beta|\n", "| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0| [91.0,13.0]| [0.875,0.125]| 0.0| alpha|\n", "| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0| [12.0,63.0]| [0.16,0.84]| 1.0| beta|\n", "| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0| [31.0,0.0]| [1.0,0.0]| 0.0| alpha|\n", "| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0| [320.0,3.0]|[0.99071207430340...| 0.0| alpha|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+-------------+--------------------+----------+--------------+\n", "\n", "Total time taken: 3.21268630027771\n", "Method\tDecisionTreeClassifier\n", "F\t0.8039005137708921\n", "Accuracy\t0.8031496062992126\n", "Precision\t0.8055535671677405\n", "Recall\t0.8031496062992125\n", "False Positive Rase\t0.2013547345043408\n", "True Positive Rate\t0.8031496062992125\n", "\t\n", "Confusion Matrix\n", "['alpha', 'beta']\n", "DenseMatrix([[61., 14.],\n", " [11., 41.]])\n" ] } ], "source": [ "dtc = DecisionTreeClassifier()\n", "mcc = SparkMultiClassClassifier(dtc, label, testFraction, seed)\n", "matrics = mcc.fit(data)\n", "for k,v in matrics.items(): print(f\"{k}\\t{v}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Random Forest Classifier" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Class\tTrain\tTest\n", "alpha\t607\t75\n", "beta\t608\t52\n", "\n", "Sample predictions: 
RandomForestClassifier\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "|structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| probability|prediction|predictedLabel|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0|[0.51578998431281...|[0.02578949921564...| 1.0| beta|\n", "| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0|[13.9982189640870...|[0.69991094820435...| 0.0| alpha|\n", "| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0|[2.13188897197973...|[0.10659444859898...| 1.0| beta|\n", "| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0|[2.75569007720908...|[0.13778450386045...| 1.0| beta|\n", "| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0|[9.85487531968060...|[0.49274376598403...| 1.0| beta|\n", "| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0|[18.2154681085900...|[0.91077340542950...| 0.0| alpha|\n", "| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0|[18.2037992417524...|[0.91018996208762...| 0.0| alpha|\n", "| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0|[19.4849307860132...|[0.97424653930066...| 0.0| alpha|\n", "| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0|[16.8241875964511...|[0.84120937982255...| 0.0| alpha|\n", "| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0|[13.1749067714244...|[0.65874533857122...| 0.0| alpha|\n", "| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0|[19.4353866362187...|[0.97176933181093...| 0.0| alpha|\n", "| 4ZP0.A| 0.84438777| 0.0|0.15561225| 
alpha|[0.22175398589971...| 0.0|[17.3230275572504...|[0.86615137786252...| 0.0| alpha|\n", "| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0|[18.6141301097890...|[0.93070650548945...| 0.0| alpha|\n", "| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0|[8.24401892548523...|[0.41220094627426...| 1.0| beta|\n", "| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0|[11.5615661468762...|[0.57807830734381...| 0.0| alpha|\n", "| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0|[16.9153051488506...|[0.84576525744253...| 0.0| alpha|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "\n", "Total time taken: 4.367412090301514\n", "Method\tRandomForestClassifier\n", "F\t0.8512456189733639\n", "Accuracy\t0.8503937007874016\n", "Precision\t0.8547451305428928\n", "Recall\t0.8503937007874016\n", "False Positive Rase\t0.14500908540278618\n", "True Positive Rate\t0.8503937007874016\n", "\t\n", "Confusion Matrix\n", "['alpha', 'beta']\n", "DenseMatrix([[63., 12.],\n", " [ 7., 45.]])\n" ] } ], "source": [ "rfc = RandomForestClassifier()\n", "mcc = SparkMultiClassClassifier(rfc, label, testFraction, seed)\n", "matrics = mcc.fit(data)\n", "for k,v in matrics.items(): print(f\"{k}\\t{v}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logistic Regression Classifier" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Class\tTrain\tTest\n", "alpha\t607\t75\n", "beta\t608\t52\n", "\n", "Sample predictions: LogisticRegression\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "|structureChainId| alpha| beta| coil|foldType| features|indexedLabel| rawPrediction| 
probability|prediction|predictedLabel|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0|[-5.0623275294473...|[0.00629098022815...| 1.0| beta|\n", "| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0|[-1.1865565882185...|[0.23387535279941...| 1.0| beta|\n", "| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0|[-5.2476892464597...|[0.00523213887961...| 1.0| beta|\n", "| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0|[-4.5674722465786...|[0.01027745233916...| 1.0| beta|\n", "| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0|[1.86554161773270...|[0.86594156213690...| 0.0| alpha|\n", "| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0|[2.45521137787059...|[0.92094171451502...| 0.0| alpha|\n", "| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0|[3.79818850368450...|[0.97807992495051...| 0.0| alpha|\n", "| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0|[5.18188108968423...|[0.99441394966369...| 0.0| alpha|\n", "| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0|[1.90968820922997...|[0.87098411573233...| 0.0| alpha|\n", "| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0|[3.98665134253246...|[0.98177649332300...| 0.0| alpha|\n", "| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0|[5.79842445508542...|[0.99697683866489...| 0.0| alpha|\n", "| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0|[4.00761355513002...|[0.98214777367986...| 0.0| alpha|\n", "| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0|[11.3978061720542...|[0.99998878005311...| 0.0| alpha|\n", "| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 
0.0|[-1.9088191172433...|[0.12911357630884...| 1.0| beta|\n", "| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0|[2.08239502966998...|[0.88918025653597...| 0.0| alpha|\n", "| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0|[2.88630893525889...|[0.94716547404562...| 0.0| alpha|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+--------------------+--------------------+----------+--------------+\n", "\n", "Total time taken: 7.353407859802246\n", "Method\tLogisticRegression\n", "F\t0.8825623307684451\n", "Accuracy\t0.8818897637795275\n", "Precision\t0.8859846466560102\n", "Recall\t0.8818897637795275\n", "False Positive Rase\t0.11137694326670705\n", "True Positive Rate\t0.8818897637795275\n", "\t\n", "Confusion Matrix\n", "['alpha', 'beta']\n", "DenseMatrix([[65., 10.],\n", " [ 5., 47.]])\n" ] } ], "source": [ "lr = LogisticRegression()\n", "mcc = SparkMultiClassClassifier(lr, label, testFraction, seed)\n", "matrics = mcc.fit(data)\n", "for k,v in matrics.items(): print(f\"{k}\\t{v}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Multilayer Perceptron Classifier" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " Class\tTrain\tTest\n", "alpha\t607\t75\n", "beta\t608\t52\n", "\n", "Sample predictions: MultilayerPerceptronClassifier\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+\n", "|structureChainId| alpha| beta| coil|foldType| features|indexedLabel|prediction|predictedLabel|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+\n", "| 5DI0.B|0.018867925| 0.6415094|0.33962265| beta|[-0.2703873959570...| 1.0| 1.0| beta|\n", "| 3BWU.F| 0.0234375| 0.4921875| 0.484375| beta|[0.12650794831784...| 1.0| 1.0| 
beta|\n", "| 3X0T.A|0.026785715| 0.51785713|0.45535713| beta|[-0.1733142428289...| 1.0| 1.0| beta|\n", "| 2QF4.A| 0.01764706| 0.5117647|0.47058824| beta|[-0.2368620687017...| 1.0| 1.0| beta|\n", "| 4OUS.A| 0.0| 0.5681818| 0.4318182| beta|[0.03226048785301...| 1.0| 0.0| alpha|\n", "| 1XAW.A| 0.8224299| 0.0|0.17757009| alpha|[-0.4501447076875...| 0.0| 0.0| alpha|\n", "| 3BRV.C| 0.78571427| 0.0|0.21428572| alpha|[0.14613234782789...| 0.0| 0.0| alpha|\n", "| 4MU6.A| 0.8540146| 0.0| 0.1459854| alpha|[-0.2809790735148...| 0.0| 0.0| alpha|\n", "| 4YK2.A| 0.7383177|0.037383176|0.22429906| alpha|[-0.3205931102219...| 0.0| 0.0| alpha|\n", "| 4I17.A| 0.8035714| 0.0|0.19642857| alpha|[-0.2930700773053...| 0.0| 0.0| alpha|\n", "| 5CWG.A| 0.86153847| 0.0|0.13846155| alpha|[0.11841576919850...| 0.0| 0.0| alpha|\n", "| 4ZP0.A| 0.84438777| 0.0|0.15561225| alpha|[0.22175398589971...| 0.0| 0.0| alpha|\n", "| 1T6O.B| 0.8947368| 0.0|0.10526316| alpha|[0.67486972636298...| 0.0| 0.0| alpha|\n", "| 5H2F.T| 0.7| 0.0| 0.3| alpha|[-0.3641495476637...| 0.0| 0.0| alpha|\n", "| 3S0A.A| 0.6386555|0.016806724|0.34453782| alpha|[-0.6068946828008...| 0.0| 0.0| alpha|\n", "| 5C8G.B| 0.66101694| 0.0|0.33898306| alpha|[-0.0469646227994...| 0.0| 0.0| alpha|\n", "+----------------+-----------+-----------+----------+--------+--------------------+------------+----------+--------------+\n", "\n", "Total time taken: 13.364720821380615\n", "Method\tMultilayerPerceptronClassifier\n", "F\t0.8654660741664717\n", "Accuracy\t0.8661417322834646\n", "Precision\t0.8657956217011336\n", "Recall\t0.8661417322834646\n", "False Positive Rase\t0.1517827579244902\n", "True Positive Rate\t0.8661417322834646\n", "\t\n", "Confusion Matrix\n", "['alpha', 'beta']\n", "DenseMatrix([[68., 7.],\n", " [10., 42.]])\n" ] } ], "source": [ "layers = [featureCount, 64, 64, classCount]\n", "mpc = MultilayerPerceptronClassifier().setLayers(layers) \\\n", " .setBlockSize(128) \\\n", " .setSeed(1234) \\\n", " .setMaxIter(100)\n", "mcc = 
SparkMultiClassClassifier(mpc, label, testFraction, seed)\n", "matrics = mcc.fit(data)\n", "for k,v in matrics.items(): print(f\"{k}\\t{v}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Terminate Spark" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "sc.stop()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.0" } }, "nbformat": 4, "nbformat_minor": 2 }