Source code for mmtfPyspark.webfilters.customReportQuery
'''customReportQuery.py
This filter runs an SQL query on specified PDB metadata and annotation fields retrieved using
RCSB PDB RESTful web services. The fields are queried and the resulting PDB IDs are
used to filter the data. The input to the filter consists of an SQL WHERE clause and a list
of data columns available from RCSB PDB web services.
References
----------
- List of supported field names: `reportField <http://www.rcsb.org/pdb/results/reportField.do>`_
- Examples of SQL WHERE clauses: `SQL where <https://www.w3schools.com/sql/sql_where.asp>`_
Examples
--------
Find PDB entries with Enzyme classification number 2.7.11.1
and source organism Homo sapiens:
>>> pdb = read_full_sequence_files(sc)
>>> whereClause = "WHERE ecNo='2.7.11.1' AND source='Homo sapiens'"
>>> pdb = pdb.filter(CustomReportQuery(whereClause, ["ecNo", "source"]))
'''
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = "marshuang80@gmail.com"
__status__ = "Done"
from mmtfPyspark.datasets import customReportService
from pyspark.sql import SparkSession
class CustomReportQuery(object):
    '''Filters using an SQL query on the specified fields

    The requested fields are fetched with ``customReportService.get_dataset``,
    registered as a Spark temporary view named ``table`` (replacing any view
    of that name in the session), and queried with ``whereClause``. The
    matching IDs are collected on the driver and used to filter
    ``(key, structure)`` records.

    Attributes
    ----------
    whereClause : str
       WHERE Clause of SQL statement
    fields : str, list
       one or more field names to be used in query
    chainLevel : bool
       True when the fetched dataset has a ``structureChainId`` column
    pdbIds : list
       distinct IDs that satisfied the query (entry IDs, plus chain IDs
       when ``chainLevel`` is True)
    '''

    def __init__(self, whereClause, fields):
        '''Run the query and remember the matching IDs.

        Parameters
        ----------
        whereClause : str
           WHERE clause (and any trailing SQL) appended verbatim to the
           SELECT statement. It is not escaped, so it must not come from an
           untrusted source.
        fields : str, list
           a single field name, a comma-separated string of field names, or
           a list of field names

        Exceptions raised while fetching or querying the data are not caught.
        '''
        # Normalise a (possibly comma-separated) string into a list of names,
        # tolerating whitespace such as "ecNo, source".
        if isinstance(fields, str):
            fields = [name.strip() for name in fields.split(',') if name.strip()]

        # Get requested data columns
        dataset = customReportService.get_dataset(fields)

        # Chain-level results carry a per-chain ID column
        self.chainLevel = "structureChainId" in dataset.columns

        dataset.createOrReplaceTempView("table")
        spark = SparkSession.builder.getOrCreate()

        if self.chainLevel:
            # chainId is selected to keep the query shape of the original
            # statement; only the first two columns are used below.
            sql = "SELECT structureChainId, structureId, chainId FROM table " \
                + whereClause
            rows = spark.sql(sql).distinct().collect()
            # Keep both chain IDs and entry IDs so that records keyed either
            # way can be matched.
            ids = {row[0] for row in rows} | {row[1] for row in rows}
        else:
            sql = "SELECT structureId FROM table " + whereClause
            rows = spark.sql(sql).distinct().collect()
            ids = {row[0] for row in rows}

        # Public list kept for backward compatibility; the frozenset gives
        # O(1) membership tests in __call__.
        self.pdbIds = list(ids)
        self._pdbIdSet = frozenset(ids)

    def __call__(self, t):
        '''Return True if the record's key matched the query.

        Parameters
        ----------
        t : tuple
           a record whose first element is the key (e.g. "4HHB" or "4HHB.A")

        Returns
        -------
        bool
        '''
        key = t[0]
        ids = self._pdbIdSet
        if key in ids:
            return True
        # If results are PDB IDs but the key contains a chain name,
        # truncate the chain name before matching (e.g., 4HHB.A -> 4HHB)
        if not self.chainLevel and len(key) > 4:
            return key[:4] in ids
        return False