Source code for mmtfPyspark.ml.datasetRegressor

#!/usr/bin/env python
'''datasetRegressor.py

Runs regression on a given dataset.
The dataset is read from a Parquet file. It must contain
a feature vector column named "features" and a prediction column.
The name of the prediction column must be specified
on the command line.

'''
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = "marshuang80@gmail.com"
__version__ = "0.1"
__status__ = "Done"

from mmtfPyspark.ml import SparkRegressor, datasetBalancer
from pyspark.sql import SparkSession
from pyspark.ml.regression import GBTRegressor, GeneralizedLinearRegression, LinearRegression
import sys
import time

def _fit_and_report(estimator, label, data, test_fraction, seed):
    """Fit one estimator through SparkRegressor and print every metric it returns."""
    metrics = SparkRegressor(estimator, label, test_fraction, seed).fit(data)
    for name, value in metrics.items():
        print(f"{name}\t{value}")


def main(argv):
    """Run three Spark ML regressors on a Parquet dataset and print their metrics.

    The dataset is fitted, in order, with LinearRegression, GBTRegressor and
    GeneralizedLinearRegression; the metrics of each fit are printed as
    tab-separated ``name<TAB>value`` lines, followed by the elapsed time.

    Args:
        argv: Command-line arguments without the program name. ``argv[0]`` is
            the path of the Parquet dataset and ``argv[1]`` is the name of the
            prediction (label) column.

    Raises:
        ValueError: If the dataset contains no rows.
    """
    # Name of prediction column
    label = argv[1]

    start = time.time()

    spark = SparkSession.builder \
                        .master("local[*]") \
                        .appName("datasetRegressor") \
                        .getOrCreate()

    # Cached because the same DataFrame is counted once and fitted three times.
    data = spark.read.parquet(argv[0]).cache()

    first_row = data.first()
    if first_row is None:
        raise ValueError(f"Dataset '{argv[0]}' contains no rows")
    print(first_row)

    # Measure the "features" vector itself: len() of the Row would give the
    # number of columns in the dataset, not the number of features.
    feature_count = len(first_row["features"])
    print(f"Feature count : {feature_count}")
    print(f"Dataset size (unbalanced) : {data.count()}")

    test_fraction = 0.3
    seed = 123

    # Linear Regression
    lr = LinearRegression().setLabelCol(label) \
                           .setFeaturesCol("features")
    _fit_and_report(lr, label, data, test_fraction, seed)

    # GBTRegressor
    gbt = GBTRegressor().setLabelCol(label) \
                        .setFeaturesCol("features")
    _fit_and_report(gbt, label, data, test_fraction, seed)

    # GeneralizedLinearRegression
    glr = GeneralizedLinearRegression().setLabelCol(label) \
                                       .setFeaturesCol("features") \
                                       .setFamily("gaussian") \
                                       .setLink("identity") \
                                       .setMaxIter(10) \
                                       .setRegParam(0.3)
    _fit_and_report(glr, label, data, test_fraction, seed)

    end = time.time()
    print(f"Time: {end - start:f} sec.")
# Script entry point: expects <parquet file> and <prediction column name>.
if __name__ == "__main__":
    if len(sys.argv) < 3:
        # sys.exit with a string writes it to stderr and exits with status 1,
        # so a missing argument yields a usage line instead of a traceback.
        sys.exit("Usage: python datasetRegressor.py <parquet file> <prediction column name>")

    main(sys.argv[1:])