Source code for mmtfPyspark.ml.pythonRDDToDataset

#!/usr/bin/env python
'''pythonRDDToDataset.py:

This module converts a PythonRDD<Row> to a Dataset<Row>. Only simple data
types are supported, and all data must be non-null.

'''
__author__ = "Mars (Shih-Cheng) Huang"
__maintainer__ = "Mars (Shih-Cheng) Huang"
__email__ = "marshuang80@gmail.com"
__version__ = "0.2.0"
__status__ = "Done"

from pyspark.sql.types import *
from pyspark.sql import SparkSession

def get_dataset(data, colNames):
    '''Converts a PythonRDD<Row> to a Dataset<Row>. This method only
    supports simple data types, and all data must be non-null.

    Parameters
    ----------
    data : PythonRDD
       PythonRDD of row objects
    colNames : list
       names of the columns in a row
    '''
    row = data.first()
    length = len(row)
    if length != len(colNames):
        raise Exception("colNames length does not match row length")

    # Infer a Spark SQL schema from the types found in the first row
    sf = []
    for i in range(len(colNames)):
        o = row[i]
        if type(o) == str:
            sf.append(StructField(colNames[i], StringType(), False))
        elif type(o) == int:
            sf.append(StructField(colNames[i], IntegerType(), False))
        elif type(o) == float:
            sf.append(StructField(colNames[i], FloatType(), False))
        else:
            # Python 3 has no separate long type (ints are unbounded),
            # so that branch was dropped; unsupported types are reported
            print("Data type not implemented yet")
    schema = StructType(sf)

    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(data, schema)
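
A minimal usage sketch follows; it is not part of the module. The tuple values and column names are made-up examples used only to show how a PythonRDD of simple, non-null fields is converted.

# Usage sketch (illustrative only; the data and column names below are
# hypothetical, not part of mmtfPyspark)
from pyspark.sql import SparkSession
from mmtfPyspark.ml import pythonRDDToDataset

spark = SparkSession.builder.getOrCreate()

# An RDD of rows with simple, non-null str, int and float fields
rows = spark.sparkContext.parallelize([
    ("1AKE", 214, 2.0),
    ("4HHB", 141, 1.74),
])

# Column names must match the number of fields in each row
ds = pythonRDDToDataset.get_dataset(rows, ["structureId", "numResidues", "resolution"])
ds.show()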