diff --git a/apache-spark/data/customerData.csv b/apache-spark/data/customerData.csv new file mode 100644 index 0000000000..8661bcb352 --- /dev/null +++ b/apache-spark/data/customerData.csv @@ -0,0 +1,9 @@ +id,zoneId,FIRST_NAME,MIDDLE_NAME,LAST_NAME,CITY,gender,transaction_date,transaction_amount +1,EU11,Ana1,A,MN, London,Female,1/2/2018,5200 +2,EU12,Ana2,A,MN, London,Female,29/3/2018,1000 +3,EU13,Jack1,Josh,MN, London,Male,19/6/2018,9600 +4,EU14,Jack2,Josh,MN, London,Male,9/9/2018,1000 +5,EU15,Nick1,H,Dee,London,Male,9/6/2021,3000 +6,EU16,Nick2,H,Dee,London,Male,27/9/2021,500 +7,EU17,Nick3,H,Dee,London,Male,8/12/2021,500 +8,EU18,Sara1,H,Dee,London,Female,7/4/2021,2000 \ No newline at end of file diff --git a/apache-spark/data/customerData.json b/apache-spark/data/customerData.json new file mode 100644 index 0000000000..d713d75678 --- /dev/null +++ b/apache-spark/data/customerData.json @@ -0,0 +1,114 @@ +[ + { + "zoneId": "LONDON", + "customerId": 1, + "gender": "Female", + "name": "jane austin", + "contact": { + "street_Address": "XYZ Road", + "address2": "house 47", + "zipcode": "MK1110", + "county": "DCounty", + "phone_number": "(919) 403-0025", + "customer_city": "CityA" + }, + "transaction_date": "2021-02-05", + "transaction_amount": 15000 + }, + { + "zoneId": "LONDON", + "customerId": 2, + "gender": "Male", + "name": "Jack Trevor", + "contact": { + "street_Address": " kingfisher road", + "address2": "SUITE 110", + "zipcode": "HM1190", + "county": "CCounty", + "phone_number": "(919) 403-0025", + "customer_city": "CityB" + }, + "transaction_date": "2021-02-05", + "transaction_amount": 12000 + }, + { + "zoneId": "LONDON", + "customerId": 11, + "gender": "Female", + "name": "jane1 austin", + "contact": { + "street_Address": "B Road", + "address2": "house 47", + "zipcode": "MK1110", + "county": "BCounty", + "phone_number": "(919) 403-0025", + "customer_city": "CityA" + }, + "transaction_date": "2021-02-05", + "transaction_amount": 1000 + }, + { + "zoneId": "LONDON", + "customerId": 21, + "gender": "Male", + "name": "Jack1 Trevor", + "contact": { + "street_Address": " A road", + "address2": "SUITE 777", + "zipcode": "AZ890", + "county": "ACounty", + "phone_number": "(919) 403-0025", + "customer_city": "CityB" + }, + "transaction_date": "2021-02-05", + "transaction_amount": 1000 + }, + { + "zoneId": "Wales", + "customerId": 3, + "gender": "Male", + "name": "John Jack", + "contact": { + "street_Address": "sunny croft", + "address2": "SUITE 1", + "zipcode": "SN1030", + "county": "bucks", + "phone_number": "(919) 403-0025", + "customer_city": "Cardiff" + }, + "transaction_date": "2018-02-05", + "transaction_amount": 5000 + }, + { + "zoneId": "LONDON", + "customerId": 4, + "gender": "Male", + "name": "Jack Trevor", + "contact": { + "street_Address": " kingfisher road", + "address2": "SUITE 110", + "zipcode": "HM1190", + "county": "Hampshire", + "phone_number": "(919) 403-0025", + "customer_city": "CityB" + }, + "transaction_date": "2021-02-05", + "transaction_amount": 500 + }, + { + "zoneId": "Wales", + "customerId": 5, + "gender": "Male", + "name": "John Jack", + "contact": { + "street_Address": "sunny croft", + "address2": "SUITE 1", + "zipcode": "SN1030", + "county": "ECounty", + "phone_number": "(919) 403-0025", + "customer_city": "Cardiff" + }, + "transaction_date": "2018-01-25", + "transaction_amount": 500 + } +] diff --git a/apache-spark/data/minCustomerData.json b/apache-spark/data/minCustomerData.json new file mode 100644 index 0000000000..a47a739d51 --- /dev/null +++ 
b/apache-spark/data/minCustomerData.json @@ -0,0 +1,20 @@ +[ + { + "id": "1", + "gender": "Female", + "name": "Jo", + "transaction_amount": 200 + }, + { + "id": "2", + "gender": "Male", + "name": "Mike", + "transaction_amount": 500 + }, + { + "id": "3", + "gender": "Male", + "name": "Dave", + "transaction_amount": 5000 + } +] \ No newline at end of file diff --git a/apache-spark/docker/docker-compose.yaml b/apache-spark/docker/docker-compose.yaml new file mode 100644 index 0000000000..960e543229 --- /dev/null +++ b/apache-spark/docker/docker-compose.yaml @@ -0,0 +1,33 @@ +version: "3" + +services: + postgres: + image: postgres:12.3-alpine + restart: always + environment: + POSTGRES_PASSWORD: postgres + POSTGRES_USER: postgres + expose: + - 5432 + ports: + - 5432:5432 + command: -p 5432 + volumes: + - ./init.sql:/docker-entrypoint-initdb.d/init.sql +# - postgres:/var/lib/postgresql/data + + pgadmin: + image: dpage/pgadmin4:4.23 + environment: + PGADMIN_DEFAULT_EMAIL: admin@pgadmin.com + PGADMIN_DEFAULT_PASSWORD: password + PGADMIN_LISTEN_PORT: 80 + ports: + - 15432:80 + volumes: + - pgadmin:/var/lib/pgadmin + depends_on: + - postgres + +volumes: + pgadmin: \ No newline at end of file diff --git a/apache-spark/docker/init.sql b/apache-spark/docker/init.sql new file mode 100755 index 0000000000..371135faa3 --- /dev/null +++ b/apache-spark/docker/init.sql @@ -0,0 +1 @@ +CREATE DATABASE customerdb; \ No newline at end of file diff --git a/apache-spark/pom.xml b/apache-spark/pom.xml index 05c5088662..b86e99433a 100644 --- a/apache-spark/pom.xml +++ b/apache-spark/pom.xml @@ -20,37 +20,31 @@ <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.11</artifactId> <version>${org.apache.spark.spark-core.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.11</artifactId> <version>${org.apache.spark.spark-sql.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-graphx_2.11</artifactId> <version>${org.apache.spark.spark-graphx.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>graphframes</groupId> <artifactId>graphframes</artifactId> <version>${graphframes.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_2.11</artifactId> <version>${org.apache.spark.spark-streaming.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.11</artifactId> <version>${org.apache.spark.spark-mllib.version}</version> - <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> @@ -67,6 +61,11 @@ <artifactId>spark-cassandra-connector-java_2.11</artifactId> <version>${com.datastax.spark.spark-cassandra-connector-java.version}</version> </dependency> + <dependency> + <groupId>org.postgresql</groupId> + <artifactId>postgresql</artifactId> + <version>${postgres.version}</version> + </dependency> @@ -108,6 +107,7 @@ 2.4.8 2.5.2 1.6.0-M1 + <postgres.version>42.3.3</postgres.version> \ No newline at end of file diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/Customer.java b/apache-spark/src/main/java/com/baeldung/dataframes/Customer.java new file mode 100644 index 0000000000..97fa160872 --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/Customer.java @@ -0,0 +1,52 @@ +package com.baeldung.dataframes; + +public class Customer { + String id; + String name; + String gender; + int transaction_amount; + + public Customer() { + + } + + public Customer(String id, String name, String gender, int transaction_amount) { + this.id = id; + this.name = name; + this.gender = gender; + this.transaction_amount = transaction_amount; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getGender() { + return gender; + } + + public void setGender(String gender) { + this.gender = gender; + } + + public int getTransaction_amount() { + return transaction_amount; + } + + public void setTransaction_amount(int transaction_amount) { + this.transaction_amount = transaction_amount; + } + +}
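The classes added below write to the Postgres instance defined in docker/docker-compose.yaml and initialized by docker/init.sql. Before running the pipeline or the live test, a quick connectivity check can save a confusing Spark stack trace. A minimal sketch, reusing the URL and credentials from the compose file (the PostgresSmokeCheck class itself is hypothetical, not part of this change):

    package com.baeldung.dataframes;

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.SQLException;

    public class PostgresSmokeCheck {
        public static void main(String[] args) throws SQLException {
            // customerdb is created by docker/init.sql on first container start;
            // assumes the stack was started with: docker-compose up -d
            try (Connection conn = DriverManager.getConnection(
                    "jdbc:postgresql://localhost:5432/customerdb", "postgres", "postgres")) {
                // reaching this line proves the JDBC URL used by the pipeline is valid
                System.out.println("Connected to " + conn.getMetaData().getDatabaseProductName());
            }
        }
    }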
diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/CustomerDataAggregationPipeline.java b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerDataAggregationPipeline.java new file mode 100644 index 0000000000..869645624a --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerDataAggregationPipeline.java @@ -0,0 +1,129 @@ +package com.baeldung.dataframes; + +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.column; +import static org.apache.spark.sql.functions.concat; +import static org.apache.spark.sql.functions.lit; + +import java.util.Properties; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.functions; + +public class CustomerDataAggregationPipeline { + private static final SparkSession SPARK_SESSION = SparkDriver.getSparkSession(); + + private final Properties dbProperties; + + public CustomerDataAggregationPipeline(Properties properties) { + dbProperties = properties; + } + + public static void main(String[] args) { + // replace with actual DB properties + Properties dbProps = new Properties(); + dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb"); + dbProps.setProperty("driver", "org.postgresql.Driver"); + dbProps.setProperty("user", "postgres"); + dbProps.setProperty("password", "postgres"); + + new CustomerDataAggregationPipeline(dbProps).run(); + } + + public void run() { + Dataset<Row> ebayDFRaw = ingestCustomerDataFromEbay(); + Dataset<Row> ebayDf = normalizeCustomerDataFromEbay(ebayDFRaw); + + Dataset<Row> amazonDFRaw = ingestCustomerDataFromAmazon(); + Dataset<Row> amazonDf = normalizeCustomerDataFromAmazon(amazonDFRaw); + + Dataset<Row> combinedDataFrame = combineDataframes(ebayDf, amazonDf); + + Dataset<Row> rowDataset = aggregateYearlySalesByGender(combinedDataFrame); + + exportData(rowDataset); + } + + private static Dataset<Row> ingestCustomerDataFromAmazon() { + return SPARK_SESSION.read() + .format("csv") + .option("header", "true") + .schema(SchemaFactory.customerSchema()) + .option("dateFormat", "d/M/yyyy") + .load("data/customerData.csv"); + } + + private static Dataset<Row> ingestCustomerDataFromEbay() { + return SPARK_SESSION.read() + .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat") + .option("multiline", true) + .load("data/customerData.json"); + } + + private static Dataset<Row> combineDataframes(Dataset<Row> df1, Dataset<Row> df2) { + return df1.unionByName(df2); + } + + private static Dataset<Row> normalizeCustomerDataFromEbay(Dataset<Row> rawDataset) { + Dataset<Row> transformedDF = rawDataset.withColumn("id", concat(rawDataset.col("zoneId"), lit("-"), rawDataset.col("customerId"))) + .drop(column("customerId")) + .withColumn("source", lit("ebay")) + .withColumn("city", rawDataset.col("contact.customer_city")) + .drop(column("contact")) + .drop(column("zoneId")) + .withColumn("year", functions.year(col("transaction_date"))) + .drop("transaction_date") + .withColumn("firstName", functions.split(column("name"), " ") + .getItem(0)) + .withColumn("lastName", functions.split(column("name"), " ") + .getItem(1)) + .drop(column("name")); + + print(transformedDF); + return transformedDF; + } + + private static Dataset<Row> normalizeCustomerDataFromAmazon(Dataset<Row> rawDataset) { + + Dataset<Row> transformedDF = rawDataset.withColumn("id", concat(rawDataset.col("zoneId"), lit("-"), rawDataset.col("id"))) + .withColumn("source", lit("amazon")) + .withColumnRenamed("CITY",
"city") + .withColumnRenamed("PHONE_NO", "contactNo") + .withColumnRenamed("POSTCODE", "postCode") + .withColumnRenamed("FIRST_NAME", "firstName") + .drop(column("MIDDLE_NAME")) + .drop(column("zoneId")) + .withColumnRenamed("LAST_NAME", "lastName") + .withColumn("year", functions.year(col("transaction_date"))) + .drop("transaction_date"); + + print(transformedDF); + return transformedDF; + } + + private static Dataset aggregateYearlySalesByGender(Dataset dataset) { + + Dataset aggDF = dataset.groupBy(column("year"), column("source"), column("gender")) + .sum("transaction_amount") + .withColumnRenamed("sum(transaction_amount)", "annual_spending") + .orderBy(col("year").asc(), col("annual_spending").desc()); + + print(aggDF); + return aggDF; + } + + private static void print(Dataset aggDs) { + aggDs.show(); + aggDs.printSchema(); + } + + private void exportData(Dataset dataset) { + String connectionURL = dbProperties.getProperty("connectionURL"); + dataset.write() + .mode(SaveMode.Overwrite) + .jdbc(connectionURL, "customer", dbProperties); + } +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToDataFrameConverterApp.java b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToDataFrameConverterApp.java new file mode 100644 index 0000000000..53799c1079 --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToDataFrameConverterApp.java @@ -0,0 +1,46 @@ +package com.baeldung.dataframes; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + +public class CustomerToDataFrameConverterApp { + + private static final List CUSTOMERS = Arrays.asList( + aCustomerWith("01", "jo", "Female", 2000), + aCustomerWith("02", "jack", "Male", 1200) + ); + + public static void main(String[] args) { + Dataset dataFrame = convertAfterMappingRows(CUSTOMERS); + print(dataFrame); + Dataset customerDF = convertToDataFrameWithNoChange(); + print(customerDF); + } + + public static Dataset convertToDataFrameWithNoChange() { + return SparkDriver.getSparkSession().createDataFrame(CUSTOMERS, Customer.class); + } + + public static Dataset convertAfterMappingRows(List customer) { + List rows = customer.stream() + .map(c -> new CustomerToRowMapper().call(c)) + .collect(Collectors.toList()); + + return SparkDriver.getSparkSession() + .createDataFrame(rows, SchemaFactory.minimumCustomerDataSchema()); + } + + private static Customer aCustomerWith(String id, String name, String gender, int amount) { + return new Customer(id, name, gender, amount); + } + + private static void print(Dataset dataFrame) { + dataFrame.printSchema(); + dataFrame.show(); + } + +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToRowMapper.java b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToRowMapper.java new file mode 100644 index 0000000000..e54bceb3ad --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/CustomerToRowMapper.java @@ -0,0 +1,18 @@ +package com.baeldung.dataframes; + +import org.apache.commons.lang3.StringUtils; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; + +public class CustomerToRowMapper implements MapFunction { + + @Override + public Row call(Customer customer) { + Row row = RowFactory.create( + customer.getId(), customer.getName().toUpperCase(), + StringUtils.substring(customer.getGender(),0, 1), + 
customer.getTransaction_amount()); + return row; + } +} \ No newline at end of file diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/DataFrameToCustomerConverterApp.java b/apache-spark/src/main/java/com/baeldung/dataframes/DataFrameToCustomerConverterApp.java new file mode 100644 index 0000000000..31ad8de12b --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/DataFrameToCustomerConverterApp.java @@ -0,0 +1,21 @@ +package com.baeldung.dataframes; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; + +public class DataFrameToCustomerConverterApp { + + public static void main(String[] args) { + Dataset<Row> df = SparkDriver.getSparkSession() + .read() + .format("org.apache.spark.sql.execution.datasources.json.JsonFileFormat") + .option("multiline", true) + .load("data/minCustomerData.json"); + df.show(); + df.printSchema(); + Dataset<Customer> customerDS = df.map(new RowToCustomerMapper(), Encoders.bean(Customer.class)); + customerDS.show(); + customerDS.printSchema(); + } +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/DataSetToDataFrameConverterApp.java b/apache-spark/src/main/java/com/baeldung/dataframes/DataSetToDataFrameConverterApp.java new file mode 100644 index 0000000000..23db18dddf --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/DataSetToDataFrameConverterApp.java @@ -0,0 +1,57 @@ +package com.baeldung.dataframes; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; + +public class DataSetToDataFrameConverterApp { + + private static final SparkSession SPARK_SESSION = SparkDriver.getSparkSession(); + + public static void main(String[] args) { + + Dataset<Customer> customerDataset = convertToDataSetFromPOJO(); + Dataset<Row> customerDataFrame = customerDataset.toDF(); + print(customerDataFrame); + + List<String> names = getNames(); + Dataset<String> namesDataset = convertToDataSetFromStrings(names); + Dataset<Row> namesDataFrame = namesDataset.toDF(); + print(namesDataFrame); + } + + private static Dataset<String> convertToDataSetFromStrings(List<String> names) { + return SPARK_SESSION.createDataset(names, Encoders.STRING()); + } + + private static Dataset<Customer> convertToDataSetFromPOJO() { + return SPARK_SESSION.createDataset(CUSTOMERS, Encoders.bean(Customer.class)); + } + + private static final List<Customer> CUSTOMERS = Arrays.asList( + aCustomerWith("01", "jo", "Female", 2000), + aCustomerWith("02", "jack", "Female", 1200), + aCustomerWith("03", "ash", "male", 2000), + aCustomerWith("04", "emma", "Female", 2000) + ); + + private static List<String> getNames() { + return CUSTOMERS.stream() + .map(Customer::getName) + .collect(Collectors.toList()); + } + + private static void print(Dataset<Row> df) { + df.show(); + df.printSchema(); + } + + private static Customer aCustomerWith(String id, String name, String gender, int amount) { + return new Customer(id, name, gender, amount); + } +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/RowToCustomerMapper.java b/apache-spark/src/main/java/com/baeldung/dataframes/RowToCustomerMapper.java new file mode 100644 index 0000000000..02fde539c8 --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/RowToCustomerMapper.java @@ -0,0 +1,19 @@ +package com.baeldung.dataframes; + +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Row; + +class
RowToCustomerMapper implements MapFunction<Row, Customer> { + + @Override + public Customer call(Row row) { + + Customer customer = new Customer(); + customer.setId(row.getAs("id")); + customer.setName(row.getAs("name")); + customer.setGender(row.getAs("gender")); + customer.setTransaction_amount(Math.toIntExact(row.getAs("transaction_amount"))); + + return customer; + } +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/SchemaFactory.java b/apache-spark/src/main/java/com/baeldung/dataframes/SchemaFactory.java new file mode 100644 index 0000000000..6c298e4829 --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/SchemaFactory.java @@ -0,0 +1,31 @@ +package com.baeldung.dataframes; + +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class SchemaFactory { + + public static StructType customerSchema() { + return DataTypes.createStructType( + new StructField[] { DataTypes.createStructField("id", DataTypes.IntegerType, false), + DataTypes.createStructField("zoneId", DataTypes.StringType, false), + DataTypes.createStructField("FIRST_NAME", DataTypes.StringType, false), + DataTypes.createStructField("MIDDLE_NAME", DataTypes.StringType, false), + DataTypes.createStructField("LAST_NAME", DataTypes.StringType, false), + DataTypes.createStructField("CITY", DataTypes.StringType, false), + DataTypes.createStructField("gender", DataTypes.StringType, false), + DataTypes.createStructField("transaction_date", DataTypes.DateType, false), + DataTypes.createStructField("transaction_amount", DataTypes.IntegerType, false) + }); + } + + public static StructType minimumCustomerDataSchema() { + return DataTypes.createStructType(new StructField[] { + DataTypes.createStructField("id", DataTypes.StringType, true), + DataTypes.createStructField("name", DataTypes.StringType, true), + DataTypes.createStructField("gender", DataTypes.StringType, true), + DataTypes.createStructField("transaction_amount", DataTypes.IntegerType, true) + }); + } +} diff --git a/apache-spark/src/main/java/com/baeldung/dataframes/SparkDriver.java b/apache-spark/src/main/java/com/baeldung/dataframes/SparkDriver.java new file mode 100644 index 0000000000..adc25170a7 --- /dev/null +++ b/apache-spark/src/main/java/com/baeldung/dataframes/SparkDriver.java @@ -0,0 +1,16 @@ +package com.baeldung.dataframes; + +import java.io.Serializable; + +import org.apache.spark.sql.SparkSession; + +public class SparkDriver implements Serializable { + + public static SparkSession getSparkSession() { + return SparkSession.builder() + .appName("Customer Aggregation pipeline") + .master("local") + .getOrCreate(); + + } +} diff --git a/apache-spark/src/test/java/com/baeldung/dataframes/CustomerDataAggregationPipelineLiveTest.java b/apache-spark/src/test/java/com/baeldung/dataframes/CustomerDataAggregationPipelineLiveTest.java new file mode 100644 index 0000000000..52a7b1451f --- /dev/null +++ b/apache-spark/src/test/java/com/baeldung/dataframes/CustomerDataAggregationPipelineLiveTest.java @@ -0,0 +1,52 @@ +package com.baeldung.dataframes; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class CustomerDataAggregationPipelineLiveTest { + + private
static Connection conn; + + @BeforeAll + static void beforeAll() throws SQLException { + DriverManager.registerDriver(new org.postgresql.Driver()); + String dbURL1 = "jdbc:postgresql://localhost:5432/customerdb"; + conn = DriverManager.getConnection(dbURL1, "postgres", "postgres"); + + String sql = "drop table if exists customer"; + + PreparedStatement statement = conn.prepareStatement(sql); + statement.executeUpdate(); + } + + @Test + void givenCSVAndJSON_whenRun_thenStoresAggregatedDataFrameInDB() throws Exception { + Properties dbProps = new Properties(); + dbProps.setProperty("connectionURL", "jdbc:postgresql://localhost:5432/customerdb"); + dbProps.setProperty("driver", "org.postgresql.Driver"); + dbProps.setProperty("user", "postgres"); + dbProps.setProperty("password", "postgres"); + + CustomerDataAggregationPipeline pipeline = new CustomerDataAggregationPipeline(dbProps); + pipeline.run(); + + String allCustomersSql = "Select count(*) from customer"; + + Statement statement = conn.createStatement(); + ResultSet resultSet = statement.executeQuery(allCustomersSql); + resultSet.next(); + int count = resultSet.getInt(1); + assertEquals(7, count); + } + +} diff --git a/apache-spark/src/test/java/com/baeldung/dataframes/CustomerToDataFrameConverterAppUnitTest.java b/apache-spark/src/test/java/com/baeldung/dataframes/CustomerToDataFrameConverterAppUnitTest.java new file mode 100644 index 0000000000..06c8f66bcd --- /dev/null +++ b/apache-spark/src/test/java/com/baeldung/dataframes/CustomerToDataFrameConverterAppUnitTest.java @@ -0,0 +1,62 @@ +package com.baeldung.dataframes; + +import static com.baeldung.dataframes.CustomerToDataFrameConverterApp.convertAfterMappingRows; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.Test; + +class CustomerToDataFrameConverterAppUnitTest { + + @Test + void givenCustomers_whenConvertAfterMappingRows_thenConvertsToDataSet() { + + List<Customer> customers = Arrays.asList( + new Customer("01", "jo", "Female", 2000), + new Customer("02", "jack", "Male", 1200) + ); + + Dataset<Row> customerDF = convertAfterMappingRows(customers); + List<Row> rows = customerDF.collectAsList(); + Row row1 = rows.get(0); + Row row2 = rows.get(1); + + assertEquals("01", row1.get(0)); + assertEquals("JO", row1.get(1)); + assertEquals("F", row1.get(2)); + assertEquals(2000, row1.get(3)); + + assertEquals("02", row2.get(0)); + assertEquals("JACK", row2.get(1)); + assertEquals("M", row2.get(2)); + assertEquals(1200, row2.get(3)); + } + + @Test + void givenCustomers_whenConvertWithNoChange_thenConvertsToDataSet() { + + List<Customer> customers = Arrays.asList( + new Customer("01", "jo", "Female", 2000), + new Customer("02", "jack", "Male", 1200) + ); + + Dataset<Row> customerDF = CustomerToDataFrameConverterApp.convertToDataFrameWithNoChange(); + List<Row> rows = customerDF.collectAsList(); + Row row1 = rows.get(0); + Row row2 = rows.get(1); + + assertEquals("01", row1.getAs("id")); + assertEquals("jo", row1.getAs("name")); + assertEquals("Female", row1.getAs("gender")); + assertEquals(2000, (int) row1.getAs("transaction_amount")); + + assertEquals("02", row2.getAs("id")); + assertEquals("jack", row2.getAs("name")); + assertEquals("Male", row2.getAs("gender")); + assertEquals(1200, (int) row2.getAs("transaction_amount")); + } +}
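For reference, the assertEquals(7, count) in the live test follows directly from the two sample files: reading the CSV dates day-first (d/M/yyyy) and grouping by (year, source, gender) yields exactly seven aggregate rows. A worked tally, derived from customerData.csv and customerData.json:

    (2018, amazon, Female) -> 5200 + 1000        = 6200
    (2018, amazon, Male)   -> 9600 + 1000        = 10600
    (2021, amazon, Female) -> 2000               = 2000
    (2021, amazon, Male)   -> 3000 + 500 + 500   = 4000
    (2018, ebay,   Male)   -> 5000 + 500         = 5500
    (2021, ebay,   Female) -> 15000 + 1000       = 16000
    (2021, ebay,   Male)   -> 12000 + 1000 + 500 = 13500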