[BAEL-3634] Code for Spark DataFrame article (#12039)
* [BAEL-3634] Code for Spark DataFrame article
* [BAEL-3634] Improve example data sample and sort aggregations
* [BAEL-3634] Change column name for clarity
* [BAEL-3634] Update method name to U.S. English spelling

Co-authored-by: uzma khan <uzma.khan@nominet.uk>
@@ -0,0 +1,52 @@
package com.baeldung.dataframes;

public class Customer {
    String id;
    String name;
    String gender;
    int transaction_amount;

    public Customer() {
    }

    public Customer(String id, String name, String gender, int transaction_amount) {
        this.id = id;
        this.name = name;
        this.gender = gender;
        this.transaction_amount = transaction_amount;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getGender() {
        return gender;
    }

    public void setGender(String gender) {
        this.gender = gender;
    }

    public int getTransaction_amount() {
        return transaction_amount;
    }

    public void setTransaction_amount(int transaction_amount) {
        this.transaction_amount = transaction_amount;
    }
}
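Customer is a plain Java bean, which is what lets Spark derive DataFrame columns from its getters; the converter apps later in this commit rely on this. A minimal sketch of that inference, assuming the class sits in the same package (the class name here is hypothetical and not part of the commit):

// Hypothetical sketch (not part of this commit): Spark infers a DataFrame schema
// directly from the Customer bean's getter-derived properties.
package com.baeldung.dataframes;

import java.util.Collections;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CustomerSchemaSketch {
    public static void main(String[] args) {
        Dataset<Row> df = SparkDriver.getSparkSession()
          .createDataFrame(Collections.singletonList(new Customer("01", "jo", "Female", 2000)), Customer.class);
        df.printSchema(); // columns come from the bean properties, e.g. gender, id, name, transaction_amount
        df.show();
    }
}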
@@ -0,0 +1,129 @@
package com.baeldung.dataframes;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.column;
import static org.apache.spark.sql.functions.concat;
import static org.apache.spark.sql.functions.lit;

import java.util.Properties;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

public class CustomerDataAggregationPipeline {

    private static final SparkSession SPARK_SESSION = SparkDriver.getSparkSession();

    private final Properties dbProperties;

    public CustomerDataAggregationPipeline(Properties properties) {
        dbProperties = properties;
    }

    public static void main(String[] args) {
        // replace with actual DB properties
        Properties dbProps = new Properties();
        dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb");
        dbProps.setProperty("driver", "org.postgresql.Driver");
        dbProps.setProperty("user", "postgres");
        dbProps.setProperty("password", "postgres");

        new CustomerDataAggregationPipeline(dbProps).run();
    }

    public void run() {
        // ingest and normalize both sources, combine them, aggregate yearly sales
        // by gender and export the result to the database
        Dataset<Row> ebayDFRaw = ingestCustomerDataFromEbay();
        Dataset<Row> ebayDf = normalizeCustomerDataFromEbay(ebayDFRaw);

        Dataset<Row> amazonDFRaw = ingestCustomerDataFromAmazon();
        Dataset<Row> amazonDf = normalizeCustomerDataFromAmazon(amazonDFRaw);

        Dataset<Row> combineDataframes = combineDataframes(ebayDf, amazonDf);

        Dataset<Row> rowDataset = aggregateYearlySalesByGender(combineDataframes);

        exportData(rowDataset);
    }

    private static Dataset<Row> ingestCustomerDataFromAmazon() {
        return SPARK_SESSION.read()
          .format("csv")
          .option("header", "true")
          .schema(SchemaFactory.customerSchema())
          .option("dateFormat", "M/d/yyyy") // month/day/year, e.g. 3/10/2021
          .load("data/customerData.csv");
    }

    private static Dataset<Row> ingestCustomerDataFromEbay() {
        return SPARK_SESSION.read()
          .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat")
          .option("multiline", true)
          .load("data/customerData.json");
    }

    private static Dataset<Row> combineDataframes(Dataset<Row> df1, Dataset<Row> df2) {
        // union by column name, so the two normalized schemas may differ in column order
        return df1.unionByName(df2);
    }

    private static Dataset<Row> normalizeCustomerDataFromEbay(Dataset<Row> rawDataset) {
        Dataset<Row> transformedDF = rawDataset.withColumn("id", concat(rawDataset.col("zoneId"), lit("-"), rawDataset.col("customerId")))
          .drop(column("customerId"))
          .withColumn("source", lit("ebay"))
          .withColumn("city", rawDataset.col("contact.customer_city"))
          .drop(column("contact"))
          .drop(column("zoneId"))
          .withColumn("year", functions.year(col("transaction_date")))
          .drop("transaction_date")
          .withColumn("firstName", functions.split(column("name"), " ")
            .getItem(0))
          .withColumn("lastName", functions.split(column("name"), " ")
            .getItem(1))
          .drop(column("name"));

        print(transformedDF);
        return transformedDF;
    }

    private static Dataset<Row> normalizeCustomerDataFromAmazon(Dataset<Row> rawDataset) {
        Dataset<Row> transformedDF = rawDataset.withColumn("id", concat(rawDataset.col("zoneId"), lit("-"), rawDataset.col("id")))
          .withColumn("source", lit("amazon"))
          .withColumnRenamed("CITY", "city")
          .withColumnRenamed("PHONE_NO", "contactNo")
          .withColumnRenamed("POSTCODE", "postCode")
          .withColumnRenamed("FIRST_NAME", "firstName")
          .drop(column("MIDDLE_NAME"))
          .drop(column("zoneId"))
          .withColumnRenamed("LAST_NAME", "lastName")
          .withColumn("year", functions.year(col("transaction_date")))
          .drop("transaction_date");

        print(transformedDF);
        return transformedDF;
    }

    private static Dataset<Row> aggregateYearlySalesByGender(Dataset<Row> dataset) {
        Dataset<Row> aggDF = dataset.groupBy(column("year"), column("source"), column("gender"))
          .sum("transaction_amount")
          .withColumnRenamed("sum(transaction_amount)", "annual_spending")
          .orderBy(col("year").asc(), col("annual_spending").desc());

        print(aggDF);
        return aggDF;
    }

    private static void print(Dataset<Row> aggDs) {
        aggDs.show();
        aggDs.printSchema();
    }

    private void exportData(Dataset<Row> dataset) {
        String connectionURL = dbProperties.getProperty("connectionURL");
        // overwrite the "customer" table with the aggregated result
        dataset.write()
          .mode(SaveMode.Overwrite)
          .jdbc(connectionURL, "customer", dbProperties);
    }
}
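For comparison, the same aggregation can be expressed with agg() and an explicit alias instead of renaming the generated sum(transaction_amount) column. A minimal alternative sketch, not the commit's code, assuming the class's existing static imports (col, column) and the functions import:

    // Alternative sketch (not part of this commit): agg() + alias() replaces the
    // withColumnRenamed("sum(transaction_amount)", "annual_spending") step.
    private static Dataset<Row> aggregateYearlySalesByGenderWithAgg(Dataset<Row> dataset) {
        return dataset.groupBy(column("year"), column("source"), column("gender"))
          .agg(functions.sum("transaction_amount").alias("annual_spending"))
          .orderBy(col("year").asc(), col("annual_spending").desc());
    }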
@@ -0,0 +1,46 @@
package com.baeldung.dataframes;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class CustomerToDataFrameConverterApp {

    private static final List<Customer> CUSTOMERS = Arrays.asList(
      aCustomerWith("01", "jo", "Female", 2000),
      aCustomerWith("02", "jack", "Male", 1200)
    );

    public static void main(String[] args) {
        Dataset<Row> dataFrame = convertAfterMappingRows(CUSTOMERS);
        print(dataFrame);
        Dataset<Row> customerDF = convertToDataFrameWithNoChange();
        print(customerDF);
    }

    public static Dataset<Row> convertToDataFrameWithNoChange() {
        return SparkDriver.getSparkSession().createDataFrame(CUSTOMERS, Customer.class);
    }

    public static Dataset<Row> convertAfterMappingRows(List<Customer> customer) {
        List<Row> rows = customer.stream()
          .map(c -> new CustomerToRowMapper().call(c))
          .collect(Collectors.toList());

        return SparkDriver.getSparkSession()
          .createDataFrame(rows, SchemaFactory.minimumCustomerDataSchema());
    }

    private static Customer aCustomerWith(String id, String name, String gender, int amount) {
        return new Customer(id, name, gender, amount);
    }

    private static void print(Dataset<Row> dataFrame) {
        dataFrame.printSchema();
        dataFrame.show();
    }
}
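Once converted, the result supports the usual untyped DataFrame operations. A minimal follow-on sketch, not part of this commit, filtering the bean-derived frame by gender (the statements could sit in main() above):

    // Hypothetical follow-on query (not part of this commit): select female customers
    // and their amounts using the untyped column API.
    Dataset<Row> femaleSpend = convertToDataFrameWithNoChange()
      .filter(org.apache.spark.sql.functions.col("gender").equalTo("Female"))
      .select("name", "transaction_amount");
    femaleSpend.show();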
@@ -0,0 +1,18 @@
package com.baeldung.dataframes;

import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

public class CustomerToRowMapper implements MapFunction<Customer, Row> {

    @Override
    public Row call(Customer customer) {
        Row row = RowFactory.create(
          customer.getId(), customer.getName().toUpperCase(),
          StringUtils.substring(customer.getGender(), 0, 1),
          customer.getTransaction_amount());
        return row;
    }
}
@@ -0,0 +1,21 @@
package com.baeldung.dataframes;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;

public class DataFrameToCustomerConverterApp {

    public static void main(String[] args) {
        Dataset<Row> df = SparkDriver.getSparkSession()
          .read()
          .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat")
          .option("multiline", true)
          .load("data/minCustomerData.json");
        df.show();
        df.printSchema();
        Dataset<Customer> customerDS = df.map(new RowToCustomerMapper(), Encoders.bean(Customer.class));
        customerDS.show();
        customerDS.printSchema();
    }
}
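This app reads data/minCustomerData.json with multiline=true, and RowToCustomerMapper (further down in this commit) expects id, name, gender and transaction_amount fields. The actual data file is not included in this excerpt; a purely hypothetical illustration of its shape would be:

    [
      {
        "id": "01",
        "name": "jo",
        "gender": "Female",
        "transaction_amount": 2000
      }
    ]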
@@ -0,0 +1,57 @@
package com.baeldung.dataframes;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DataSetToDataFrameConverterApp {

    private static final SparkSession SPARK_SESSION = SparkDriver.getSparkSession();

    public static void main(String[] args) {
        Dataset<Customer> customerDataset = convertToDataSetFromPOJO();
        Dataset<Row> customerDataFrame = customerDataset.toDF();
        print(customerDataFrame);

        List<String> names = getNames();
        Dataset<String> namesDataset = convertToDataSetFromStrings(names);
        Dataset<Row> namesDataFrame = namesDataset.toDF();
        print(namesDataFrame);
    }

    private static Dataset<String> convertToDataSetFromStrings(List<String> names) {
        return SPARK_SESSION.createDataset(names, Encoders.STRING());
    }

    private static Dataset<Customer> convertToDataSetFromPOJO() {
        return SPARK_SESSION.createDataset(CUSTOMERS, Encoders.bean(Customer.class));
    }

    private static final List<Customer> CUSTOMERS = Arrays.asList(
      aCustomerWith("01", "jo", "Female", 2000),
      aCustomerWith("02", "jack", "Female", 1200),
      aCustomerWith("03", "ash", "male", 2000),
      aCustomerWith("04", "emma", "Female", 2000)
    );

    private static List<String> getNames() {
        return CUSTOMERS.stream()
          .map(Customer::getName)
          .collect(Collectors.toList());
    }

    private static void print(Dataset<Row> df) {
        df.show();
        df.printSchema();
    }

    private static Customer aCustomerWith(String id, String name, String gender, int amount) {
        return new Customer(id, name, gender, amount);
    }
}
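A small note on toDF(): converting the Dataset<String> of names yields a single column named value, and the overload taking column names can rename it at conversion time. A minimal sketch, not part of this commit:

    // Optional tweak (assumption, not the commit's code): name the single column
    // "name" instead of the default "value" when converting the strings Dataset.
    Dataset<Row> namesDataFrame = namesDataset.toDF("name");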
@@ -0,0 +1,19 @@
package com.baeldung.dataframes;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Row;

class RowToCustomerMapper implements MapFunction<Row, Customer> {

    @Override
    public Customer call(Row row) {
        Customer customer = new Customer();
        customer.setId(row.getAs("id"));
        customer.setName(row.getAs("name"));
        customer.setGender(row.getAs("gender"));
        customer.setTransaction_amount(Math.toIntExact(row.getAs("transaction_amount")));

        return customer;
    }
}
@@ -0,0 +1,31 @@
package com.baeldung.dataframes;

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class SchemaFactory {

    public static StructType customerSchema() {
        return DataTypes.createStructType(
          new StructField[] {
            DataTypes.createStructField("id", DataTypes.IntegerType, false),
            DataTypes.createStructField("zoneId", DataTypes.StringType, false),
            DataTypes.createStructField("FIRST_NAME", DataTypes.StringType, false),
            DataTypes.createStructField("MIDDLE_NAME", DataTypes.StringType, false),
            DataTypes.createStructField("LAST_NAME", DataTypes.StringType, false),
            DataTypes.createStructField("CITY", DataTypes.StringType, false),
            DataTypes.createStructField("gender", DataTypes.StringType, false),
            DataTypes.createStructField("transaction_date", DataTypes.DateType, false),
            DataTypes.createStructField("transaction_amount", DataTypes.IntegerType, false)
          });
    }

    public static StructType minimumCustomerDataSchema() {
        return DataTypes.createStructType(new StructField[] {
          DataTypes.createStructField("id", DataTypes.StringType, true),
          DataTypes.createStructField("name", DataTypes.StringType, true),
          DataTypes.createStructField("gender", DataTypes.StringType, true),
          DataTypes.createStructField("transaction_amount", DataTypes.IntegerType, true)
        });
    }
}
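For reference, customerSchema() is applied to data/customerData.csv (read with header=true), so the file is expected to provide these nine columns. The real data file is not part of this excerpt; a purely hypothetical sample of that shape could look like:

    id,zoneId,FIRST_NAME,MIDDLE_NAME,LAST_NAME,CITY,gender,transaction_date,transaction_amount
    1,EU-West,Jo,Ann,Smith,Paris,Female,3/10/2021,2000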
@@ -0,0 +1,16 @@
package com.baeldung.dataframes;

import java.io.Serializable;

import org.apache.spark.sql.SparkSession;

public class SparkDriver implements Serializable {

    public static SparkSession getSparkSession() {
        return SparkSession.builder()
          .appName("Customer Aggregation pipeline")
          .master("local")
          .getOrCreate();
    }
}
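Because getOrCreate() returns the shared session, all the example apps above reuse one single-threaded local session. A variant that uses every available core is sketched below; this is a local-tuning assumption, not part of the commit:

    // Variant sketch (assumption, not the commit's code): "local[*]" runs with as
    // many worker threads as there are cores, instead of the single-threaded "local".
    SparkSession session = SparkSession.builder()
      .appName("Customer Aggregation pipeline")
      .master("local[*]")
      .getOrCreate();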
@@ -0,0 +1,52 @@
package com.baeldung.dataframes;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Properties;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

class CustomerDataAggregationPipelineLiveTest {

    private static Connection conn;

    @BeforeAll
    static void beforeAll() throws SQLException {
        DriverManager.registerDriver(new org.postgresql.Driver());
        String dbURL1 = "jdbc:postgresql://localhost:5432/customerdb";
        conn = DriverManager.getConnection(dbURL1, "postgres", "postgres");

        String sql = "drop table if exists customer";

        PreparedStatement statement = conn.prepareStatement(sql);
        statement.executeUpdate();
    }

    @Test
    void givenCSVAndJSON_whenRun_thenStoresAggregatedDataFrameInDB() throws Exception {
        Properties dbProps = new Properties();
        dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb");
        dbProps.setProperty("driver", "org.postgresql.Driver");
        dbProps.setProperty("user", "postgres");
        dbProps.setProperty("password", "postgres");

        CustomerDataAggregationPipeline pipeline = new CustomerDataAggregationPipeline(dbProps);
        pipeline.run();

        String allCustomersSql = "Select count(*) from customer";

        Statement statement = conn.createStatement();
        ResultSet resultSet = statement.executeQuery(allCustomersSql);
        resultSet.next();
        int count = resultSet.getInt(1);
        assertEquals(7, count);
    }
}
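This live test assumes a PostgreSQL instance on localhost:5432 with a customerdb database and postgres/postgres credentials. One way to provide that locally, an assumption rather than part of the commit, is the official Postgres Docker image:

    docker run --name customerdb -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=customerdb -p 5432:5432 -d postgres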
@@ -0,0 +1,62 @@
package com.baeldung.dataframes;

import static com.baeldung.dataframes.CustomerToDataFrameConverterApp.convertAfterMappingRows;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.jupiter.api.Test;

class CustomerToDataFrameConverterAppUnitTest {

    @Test
    void givenCustomers_whenConvertAfterMappingRows_thenConvertsToDataSet() {
        List<Customer> customers = Arrays.asList(
          new Customer("01", "jo", "Female", 2000),
          new Customer("02", "jack", "Male", 1200)
        );

        Dataset<Row> customerDF = convertAfterMappingRows(customers);
        List<Row> rows = customerDF.collectAsList();
        Row row1 = rows.get(0);
        Row row2 = rows.get(1);

        assertEquals("01", row1.get(0));
        assertEquals("JO", row1.get(1));
        assertEquals("F", row1.get(2));
        assertEquals(2000, row1.get(3));

        assertEquals("02", row2.get(0));
        assertEquals("JACK", row2.get(1));
        assertEquals("M", row2.get(2));
        assertEquals(1200, row2.get(3));
    }

    @Test
    void givenCustomers_whenConvertWithNoChange_thenConvertsToDataSet() {
        List<Customer> customers = Arrays.asList(
          new Customer("01", "jo", "Female", 2000),
          new Customer("02", "jack", "Male", 1200)
        );

        Dataset<Row> customerDF = CustomerToDataFrameConverterApp.convertToDataFrameWithNoChange();
        List<Row> rows = customerDF.collectAsList();
        Row row1 = rows.get(0);
        Row row2 = rows.get(1);

        assertEquals("01", row1.getAs("id"));
        assertEquals("jo", row1.getAs("name"));
        assertEquals("Female", row1.getAs("gender"));
        assertEquals(2000, (int) row1.getAs("transaction_amount"));

        assertEquals("02", row2.getAs("id"));
        assertEquals("jack", row2.getAs("name"));
        assertEquals("Male", row2.getAs("gender"));
        assertEquals(1200, (int) row2.getAs("transaction_amount"));
    }
}