Source code for

#!/usr/bin/env python

'''
Creates a balanced dataset for classification problems by either
downsampling the majority classes or upsampling the minority classes.
It randomly samples each class and returns a dataset with approximately
the same number of samples in each class.
'''

# Module metadata. Kept as simple one-line string assignments; presumably
# read textually by packaging/release tooling (e.g. a regex on __version__)
# -- TODO confirm before changing their form.
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = ""
__version__ = "0.2.0"
__status__ = "Done"

from pyspark.sql import DataFrame
from pyspark.sql import Row
from functools import reduce
import math

def downsample(data, columnName, seed=7):
    '''Returns a balanced dataset for the given column name by downsampling
    the majority classes.

    Every class larger than the smallest one is sampled without replacement
    down to approximately the size of the smallest class; the smallest class
    is kept as-is. Spark's ``sample`` is approximate, so the resulting class
    sizes are approximate as well.

    The classification column may be of any type that supports equality
    comparison; rows whose value is null are treated as their own class.

    Parameters
    ----------
    data : Dataframe
        input dataset
    columnName : str
        column to be balanced by
    seed : int
        random number seed

    Returns
    -------
    Dataframe
        union of the per-class samples, or ``data`` itself when it has no rows
    '''
    counts = data.groupBy(columnName).count().collect()
    if not counts:
        # Empty input: nothing to balance (min() over no counts would raise).
        return data

    minCount = min(int(row[1]) for row in counts)
    column = data[columnName]

    samples = []
    for row in counts:
        name, count = row[0], int(row[1])
        # Compare as a Column rather than building a SQL string: values that
        # contain quotes and non-string columns work, and eqNullSafe keeps the
        # null class (a "col='None'" string filter never matches null rows).
        subset = data.filter(column.eqNullSafe(name))
        if count > minCount:
            subset = subset.sample(False, minCount / float(count), seed)
        samples.append(subset)

    return reduce(lambda x, y: x.union(y), samples)
def upsample(data, columnName, seed=7):
    '''Returns a balanced dataset for the given column name by upsampling
    the minority classes.

    Every original row is kept. For each class smaller than the largest one,
    additional rows are drawn from that class with replacement so that its
    expected size matches the largest class. Spark's ``sample`` is
    approximate, so the resulting class sizes are approximate as well.

    The classification column may be of any type that supports equality
    comparison; rows whose value is null are treated as their own class.

    Parameters
    ----------
    data : Dataframe
        input dataset
    columnName : str
        column to be balanced by
    seed : int
        random number seed

    Returns
    -------
    Dataframe
        union of the original classes plus the extra sampled rows, or
        ``data`` itself when it has no rows
    '''
    counts = data.groupBy(columnName).count().collect()
    if not counts:
        # Empty input: nothing to balance (max() over no counts would raise).
        return data

    maxCount = max(int(row[1]) for row in counts)
    column = data[columnName]

    samples = []
    for row in counts:
        name, count = row[0], int(row[1])
        # Column comparison instead of a SQL string (quotes, non-string
        # columns, and the null class all work with eqNullSafe).
        subset = data.filter(column.eqNullSafe(name))
        samples.append(subset)
        if count < maxCount:
            # Upsampling needs fractions that can exceed 1, which Spark only
            # accepts with replacement. Drawing just the missing rows and
            # appending them keeps every original row in the result.
            extra = (maxCount - count) / float(count)
            samples.append(subset.sample(True, extra, seed))

    return reduce(lambda x, y: x.union(y), samples)