Source code for mmtfPyspark.interactions.polymerInteractionFingerprint

#!/usr/bin/env python
'''polymerInteractionFingerprint.py

Finds interactions between polymer chains and maps them onto polymer sequences.

'''
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = "marshuang80@gmail.com"
__version__ = "0.2.0"
__status__ = "done"

from mmtfPyspark.interactions import InteractionFilter, AtomInteraction, InteractionCenter
from mmtfPyspark.utils import ColumnarStructure
from mmtfPyspark.utils import DistanceBox
from pyspark.sql import Row
import numpy as np


class PolymerInteractionFingerprint(object):
    """Find residue-level interactions between pairs of polymer chains.

    An instance is a callable applied to ``(structureId, structure)`` tuples;
    it returns a list of ``Row`` objects (presumably consumed via ``flatMap``).
    For every pair of polymer chains it collects the residues of each chain
    that have an atom within the filter's distance cutoff of the other chain,
    and emits one row per chain whose count of interacting residues reaches
    the filter's minimum.
    """

    def __init__(self, interactionFilter):
        """Store the interaction criteria.

        Parameters
        ----------
        interactionFilter : InteractionFilter
            Supplies the distance cutoff, the target/query group, atom-name
            and element criteria, the prohibited target groups and the
            minimum number of interacting residues per chain.
        """
        self.filter = interactionFilter

    def __call__(self, t):
        """Return the interaction rows for a ``(structureId, structure)`` tuple."""
        structureId = t[0]
        structure = t[1]
        return self.get_interactions(structureId, structure)

    def get_interactions(self, structureId, structure):
        """Return polymer-polymer interaction rows for one structure.

        Parameters
        ----------
        structureId : str
            Identifier used as the prefix of each row's ``structureId.chain`` key.
        structure
            Decoded structure accepted by ``ColumnarStructure``; its
            ``entity_list`` entries must provide a ``'sequence'`` key.

        Returns
        -------
        list of Row
            Each row is ``(structureId.chain, partner chain, chain, group
            numbers, sequence indices, entity sequence)``. The group numbers
            and sequence indices are aligned element-wise and ordered by
            sequence index.
        """
        cutoff = self.filter.get_distance_cutoff()
        cutoffDistanceSquared = cutoff ** 2

        arrays = ColumnarStructure(structure, True)
        chainNames = arrays.get_chain_names()
        groupNames = arrays.get_group_names()
        groupNumbers = arrays.get_group_numbers()
        atomNames = arrays.get_atom_names()
        entityIndices = arrays.get_entity_indices()
        elements = arrays.get_elements()
        polymer = arrays.is_polymer()
        sequenceMapIndices = arrays.get_sequence_positions()
        x = arrays.get_x_coords()
        y = arrays.get_y_coords()
        z = arrays.get_z_coords()

        # Bin every candidate polymer atom into a per-chain DistanceBox so
        # that only atoms in nearby bins of two chains need to be compared.
        boxes = {}
        for i in range(arrays.get_num_atoms()):
            if not (polymer[i]
                    and self._is_candidate_atom(groupNames[i], atomNames[i], elements[i])):
                continue
            chain = chainNames[i]
            if chain not in boxes:
                boxes[chain] = DistanceBox(cutoff)
            boxes[chain].add_point(np.array([x[i], y[i], z[i]]), i)

        rows = []
        minInteractions = self.filter.get_min_interactions()
        chainBoxes = list(boxes.items())

        # Visit every unordered pair of chains exactly once.
        for i, (chainI, boxI) in enumerate(chainBoxes):
            for chainJ, boxJ in chainBoxes[i + 1:]:
                # NOTE(review): method name kept as in the original; confirm
                # DistanceBox exposes getIntersection (add_point is snake_case).
                intersectionI = boxI.getIntersection(boxJ)
                intersectionJ = boxJ.getIntersection(boxI)

                # sequence index -> group number of each interacting residue
                indicesI = {}
                indicesJ = {}
                entityIndexI = -1
                entityIndexJ = -1

                for n in intersectionI:
                    for m in intersectionJ:
                        dx = x[n] - x[m]
                        dy = y[n] - y[m]
                        dz = z[n] - z[m]
                        if dx * dx + dy * dy + dz * dz > cutoffDistanceSquared:
                            continue
                        # A close atom pair counts for whichever chain holds
                        # the atom that satisfies the target criteria.
                        if self._is_target_query_pair(n, m, groupNames, atomNames, elements):
                            entityIndexI = entityIndices[n]
                            indicesI[sequenceMapIndices[n]] = groupNumbers[n]
                        if self._is_target_query_pair(m, n, groupNames, atomNames, elements):
                            entityIndexJ = entityIndices[m]
                            indicesJ[sequenceMapIndices[m]] = groupNumbers[m]

                # Require at least one interaction: with an empty map the
                # entity index is still -1, which would silently select the
                # last entry of entity_list.
                if indicesI and len(indicesI) >= minInteractions:
                    rows.append(self._make_row(structureId, structure, chainI,
                                               chainJ, indicesI, entityIndexI))
                if indicesJ and len(indicesJ) >= minInteractions:
                    rows.append(self._make_row(structureId, structure, chainJ,
                                               chainI, indicesJ, entityIndexJ))
        return rows

    def _is_candidate_atom(self, groupName, atomName, element):
        """Return True if an atom may take part as either a target or a query atom.

        Group, atom name and element are each tested against the target OR
        the query criteria; the exact target/query pairing is checked per
        atom pair in ``_is_target_query_pair``. Atoms in prohibited target
        groups are excluded.
        """
        f = self.filter
        return ((f.is_target_group(groupName) or f.is_query_group(groupName))
                and (f.is_target_atom_name(atomName) or f.is_query_atom_name(atomName))
                and (f.is_target_element(element) or f.is_query_element(element))
                and not f.is_prohibited_target_group(groupName))

    def _is_target_query_pair(self, target, query, groupNames, atomNames, elements):
        """Return True if atom ``target`` meets the target criteria and atom ``query`` the query criteria."""
        f = self.filter
        return (f.is_target_group(groupNames[target])
                and f.is_target_atom_name(atomNames[target])
                and f.is_target_element(elements[target])
                and f.is_query_group(groupNames[query])
                and f.is_query_atom_name(atomNames[query])
                and f.is_query_element(elements[query]))

    @staticmethod
    def _sorted_residues(indices):
        """Split a ``{sequence index: group number}`` map into two aligned lists.

        Both lists follow ascending sequence index, so the k-th group number
        belongs to the k-th sequence index. Sequence indices are converted to
        ``int``; group numbers are returned unchanged.
        """
        items = sorted(indices.items(), key=lambda kv: int(kv[0]))
        sequenceIndices = [int(k) for k, _ in items]
        groupNumbers = [v for _, v in items]
        return sequenceIndices, groupNumbers

    def _make_row(self, structureId, structure, chain, partnerChain, indices, entityIndex):
        """Build the Row describing ``chain``'s residues that interact with ``partnerChain``."""
        sequenceIndices, groupNumbers = self._sorted_residues(indices)
        return Row(structureId + '.' + chain, partnerChain, chain,
                   groupNumbers, sequenceIndices,
                   structure.entity_list[entityIndex]['sequence'])