Source code for mmtfPyspark.interactions.structureToAtomInteractions

#!/usr/bin/env python
'''structureToAtomInteractions.py

Finds interactions that match the criteria specified by the InteractionFilter

'''
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = "marshuang80@gmail.com"
__version__ = "0.2.0"
__status__ = "done"

from mmtfPyspark.interactions import InteractionFilter, AtomInteraction, InteractionCenter
from mmtfPyspark.utils import ColumnarStructureX
from mmtfPyspark.utils import DistanceBox
import numpy as np


class StructureToAtomInteractions(object):
    '''Class that finds structure to atom interactions.

    Instances are callable: given a ``(structureId, structure)`` tuple they
    return a list of interaction rows for that structure.

    Attributes
    ----------
    bfilter : obj
        Specifies the conditions for calculating interactions. The filter
        itself is read from ``bfilter.value`` (presumably a broadcast
        variable wrapping an InteractionFilter - confirm against caller).
    pairwise : bool
        If True, results as one row per pair interactions. If False, the
        interactions of one atom with all other atoms are returned as a
        single row.
    '''

    def __init__(self, bfilter, pairwise=False):
        self.filter = bfilter.value
        self.pairwise = pairwise

    def __call__(self, t):
        '''Find all interactions in one structure that satisfy the filter

        Parameters
        ----------
        t : tuple
            t[0] is the structure id, t[1] is the structure

        Returns
        -------
        list
            the interaction rows; empty if the structure has no query atoms
        '''
        interactions = []
        structureId = t[0]

        # convert structure to an array-based format for efficient processing
        arrays = ColumnarStructureX(t[1], True)

        # create a list of query atoms for which interactions should be
        # calculated
        queryAtomIndices = self._get_query_atom_indices(arrays)

        if len(queryAtomIndices) == 0:
            return interactions

        # Add atom (indices) on grid for rapid indexing of atom neighbors on a
        # grid based on a cutoff distance
        box = self._get_distance_box(arrays)

        # the limits do not change while processing a structure
        minInteractions = self.filter.get_min_interactions()
        maxInteractions = self.filter.get_max_interactions()

        for queryAtomIndex in queryAtomIndices:
            # find interactions of query atom specified by atom index
            interaction = self._get_interactions(arrays, queryAtomIndex, box)
            interaction.set_structure_id(structureId)

            # only add interactions that are within the given limits of
            # interactions
            numInteractions = interaction.get_num_interactions()
            if minInteractions <= numInteractions <= maxInteractions:
                # return interactions as either pairs or all interactions of
                # one atom as a row
                if self.pairwise:
                    interactions += interaction.get_pair_interactions_as_rows()
                else:
                    multiInteract = \
                        interaction.get_multiple_interactions_as_row(
                            maxInteractions)
                    interactions += multiInteract

        return interactions

    @staticmethod
    def _flatten_neighbor_indices(neighborIndices):
        '''Flatten a list of lists of atom indices into a single flat list

        Parameters
        ----------
        neighborIndices : list
            either a flat sequence of atom indices or a list of lists of
            atom indices

        Returns
        -------
        list
            a flat sequence of atom indices; the input is returned unchanged
            when it is empty or already flat
        '''
        # guard against an empty result before peeking at the first element
        if len(neighborIndices) > 0 and isinstance(neighborIndices[0], list):
            return [n for neighbors in neighborIndices for n in neighbors]
        return neighborIndices

    def _get_interactions(self, arrays, queryAtomIndex, box):
        '''Get the interacting neighbors of an atom in a structure

        Parameters
        ----------
        arrays : columnarStructure
            structure in columnarStructure format
        queryAtomIndex : int
            the index of the querying atom
        box : distanceBox
            the distance box holding the candidate target atoms

        Returns
        -------
        AtomInteraction
            an AtomInteraction class with interacting neighbors. An empty
            AtomInteraction is returned as soon as a neighbor within the
            cutoff is in a prohibited group, exceeds the normalized b-factor
            cutoff, or has partial occupancy.
        '''
        interaction = AtomInteraction()

        # get the x,y,z coordinates of the structure
        x = arrays.get_x_coords()
        y = arrays.get_y_coords()
        z = arrays.get_z_coords()

        # get the query atom coordinates
        qx = x[queryAtomIndex]
        qy = y[queryAtomIndex]
        qz = z[queryAtomIndex]

        # get required information of the columnarStructure
        atomToGroupIndices = arrays.get_atom_to_group_indices()
        occupancies = arrays.get_occupancies()
        normalizedbFactors = arrays.get_normalized_b_factors()
        groupNames = arrays.get_group_names()

        # record query atom info
        queryCenter = InteractionCenter(arrays, queryAtomIndex)
        interaction.set_center(queryCenter)

        # values that stay constant while scanning the neighbors
        cutoffDistanceSq = self.filter.get_distance_cutoff() ** 2
        bFactorCutoff = self.filter.get_normalized_b_factor_cutoff()
        maxInteractions = self.filter.get_max_interactions()
        queryGroupIndex = atomToGroupIndices[queryAtomIndex]

        # Retrieve atom indices of atoms that lay within grid cubes
        # that are within cutoff distance of the query atom
        neighborIndices = self._flatten_neighbor_indices(
            box.get_neighbors(np.array([qx, qy, qz])))

        # determine and record interactions with neighbor atoms
        for neighborIndex in neighborIndices:
            # exclude self interactions with a group
            if atomToGroupIndices[neighborIndex] == queryGroupIndex:
                continue

            # check if interaction is within distance cutoff
            dx = qx - x[neighborIndex]
            dy = qy - y[neighborIndex]
            dz = qz - z[neighborIndex]
            distSq = dx * dx + dy * dy + dz * dz

            if distSq <= cutoffDistanceSq:
                # Exclude interactions with undesired groups and
                # atoms with partial occupancy (< 1.0)
                if self.filter.is_prohibited_target_group(
                        groupNames[neighborIndex]) \
                        or bFactorCutoff < normalizedbFactors[neighborIndex] \
                        or occupancies[neighborIndex] < 1.0:
                    # return an empty atom interaction
                    return AtomInteraction()

                # add interacting atom info
                neighbor = InteractionCenter(arrays, neighborIndex)
                interaction.add_neighbor(neighbor)

                # terminate early if the number of interactions exceeds limit
                if interaction.get_num_interactions() > maxInteractions:
                    return interaction

        return interaction

    def _get_distance_box(self, arrays):
        '''Add atom indices on grid for rapid indexing of atom neighbors on a
        grid based on a cutoff distance

        Parameters
        ----------
        arrays : columnarStructure
            structure in columnarStructure format

        Returns
        -------
        DistanceBox
            a distance box containing the indices of all atoms whose group
            and element both match the target criteria of the filter
        '''
        # Get required data
        x = arrays.get_x_coords()
        y = arrays.get_y_coords()
        z = arrays.get_z_coords()
        elements = arrays.get_elements()
        groupNames = arrays.get_group_names()

        box = DistanceBox(self.filter.get_distance_cutoff())

        for i in range(arrays.get_num_atoms()):
            if self.filter.is_target_group(groupNames[i]) \
                    and self.filter.is_target_element(elements[i]):
                newPoint = np.array([x[i], y[i], z[i]])
                box.add_point(newPoint, i)

        return box

    def _get_query_atom_indices(self, arrays):
        '''Returns a list of indices to query atoms in the structure

        Parameters
        ----------
        arrays : columnarStructure
            structure in columnarStructure format

        Returns
        -------
        list
            indices of atoms in query groups that match the query element,
            are below the normalized b-factor cutoff and have full occupancy
        '''
        # Get required data
        groupNames = arrays.get_group_names()
        elements = arrays.get_elements()
        groupStartIndices = arrays.get_group_to_atom_indices()
        occupancies = arrays.get_occupancies()
        normalizedbFactors = arrays.get_normalized_b_factors()

        bFactorCutoff = self.filter.get_normalized_b_factor_cutoff()

        # Find atoms that match the query criteria and exclude atoms with
        # partial occupancy
        indices = []
        for i in range(arrays.get_num_groups()):
            # atoms of group i occupy the half-open range [start, end)
            start = groupStartIndices[i]
            end = groupStartIndices[i + 1]

            # the group name is looked up through the group's first atom
            if self.filter.is_query_group(groupNames[start]):
                indices += [j for j in range(start, end)
                            if self.filter.is_query_element(elements[j])
                            and normalizedbFactors[j] < bFactorCutoff
                            and occupancies[j] >= 1.0]

        return indices