|
2 | 2 |
|
3 | 3 | importorg.apache.log4j.Level;
|
4 | 4 | importorg.apache.log4j.Logger;
|
5 |
| -importorg.apache.spark.SparkConf; |
6 |
| -importorg.apache.spark.api.java.JavaRDD; |
7 |
| -importorg.apache.spark.api.java.JavaSparkContext; |
8 |
| -importorg.apache.spark.sql.Dataset; |
9 |
| -importorg.apache.spark.sql.Encoders; |
10 |
| -importorg.apache.spark.sql.SparkSession; |
| 5 | +importorg.apache.spark.sql.*; |
11 | 6 |
|
12 | 7 | importstaticorg.apache.spark.sql.functions.avg;
|
| 8 | +importstaticorg.apache.spark.sql.functions.col; |
13 | 9 | importstaticorg.apache.spark.sql.functions.max;
|
14 | 10 |
|
15 | 11 |
|
16 | 12 | publicclassTypedDataset {
|
17 | 13 | privatestaticfinalStringAGE_MIDPOINT ="ageMidpoint";
|
18 | 14 | privatestaticfinalStringSALARY_MIDPOINT ="salaryMidPoint";
|
19 | 15 | privatestaticfinalStringSALARY_MIDPOINT_BUCKET ="salaryMidpointBucket";
|
20 |
| -privatestaticfinalfloatNULL_VALUE = -1.0f; |
21 |
| -privatestaticfinalStringCOMMA_DELIMITER =",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"; |
22 | 16 |
|
23 | 17 | publicstaticvoidmain(String[]args)throwsException {
|
24 | 18 |
|
25 | 19 | Logger.getLogger("org").setLevel(Level.ERROR);
|
26 |
| -SparkConfconf =newSparkConf().setAppName("StackOverFlowSurvey").setMaster("local[1]"); |
| 20 | +SparkSessionsession =SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
27 | 21 |
|
28 |
| -JavaSparkContextsc =newJavaSparkContext(conf); |
| 22 | +DataFrameReaderdataFrameReader =session.read(); |
29 | 23 |
|
30 |
| -SparkSessionsession =SparkSession.builder().appName("StackOverFlowSurvey").master("local[1]").getOrCreate(); |
| 24 | +Dataset<Row>responses =dataFrameReader.option("header","true").csv("in/2016-stack-overflow-survey-responses.csv"); |
31 | 25 |
|
32 |
| -JavaRDD<String>lines =sc.textFile("in/2016-stack-overflow-survey-responses.csv"); |
| 26 | +Dataset<Row>responseWithSelectedColumns =responses.select(col("country"),col("age_midpoint").as("ageMidPoint").cast("integer"),col("occupation"),col("salary_midpoint").as("salaryMidPoint").cast("integer")); |
33 | 27 |
|
34 |
| -JavaRDD<Response>responseRDD =lines |
35 |
| - .filter(line -> !line.split(COMMA_DELIMITER, -1)[2].equals("country")) |
36 |
| - .map(line -> { |
37 |
| -String[]splits =line.split(COMMA_DELIMITER, -1); |
38 |
| -returnnewResponse(splits[2],convertStringToFloat(splits[6]),splits[9],convertStringToFloat(splits[14])); |
39 |
| - }); |
40 |
| -Dataset<Response>responseDataset =session.createDataset(responseRDD.rdd(),Encoders.bean(Response.class)); |
| 28 | +Dataset<Response>typedDataset =responseWithSelectedColumns.as(Encoders.bean(Response.class)); |
41 | 29 |
|
42 | 30 | System.out.println("=== Print out schema ===");
|
43 |
| -responseDataset.printSchema(); |
| 31 | +typedDataset.printSchema(); |
44 | 32 |
|
45 | 33 | System.out.println("=== Print 20 records of responses table ===");
|
46 |
| -responseDataset.show(20); |
| 34 | +typedDataset.show(20); |
47 | 35 |
|
48 | 36 | System.out.println("=== Print records where the response is from Afghanistan ===");
|
49 |
| -responseDataset.filter(response ->response.getCountry().equals("Afghanistan")).show(); |
| 37 | +typedDataset.filter(response ->response.getCountry().equals("Afghanistan")).show(); |
50 | 38 |
|
51 | 39 | System.out.println("=== Print the count of occupations ===");
|
52 |
| -responseDataset.groupBy(responseDataset.col("occupation")).count().show(); |
53 |
| - |
| 40 | +typedDataset.groupBy(typedDataset.col("occupation")).count().show(); |
54 | 41 |
|
55 | 42 | System.out.println("=== Print records with average mid age less than 20 ===");
|
56 |
| -responseDataset.filter(response ->response.getAgeMidPoint() !=NULL_VALUE &&response.getAgeMidPoint() <20).show(); |
| 43 | +typedDataset.filter(response ->response.getAgeMidPoint() !=null &&response.getAgeMidPoint() <20).show(); |
57 | 44 |
|
58 | 45 | System.out.println("=== Print the result with salary middle point in descending order ===");
|
59 |
| -responseDataset.orderBy(responseDataset.col(SALARY_MIDPOINT ).desc()).show(); |
| 46 | +typedDataset.orderBy(typedDataset.col(SALARY_MIDPOINT ).desc()).show(); |
60 | 47 |
|
61 | 48 | System.out.println("=== Group by country and aggregate by average salary middle point and max age middle point ===");
|
62 |
| -responseDataset |
63 |
| - .filter(response ->response.getSalaryMidPoint() !=NULL_VALUE) |
64 |
| - .groupBy("country") |
65 |
| - .agg(avg(SALARY_MIDPOINT),max(AGE_MIDPOINT)) |
66 |
| - .show(); |
| 49 | +typedDataset.filter(response ->response.getSalaryMidPoint() !=null) |
| 50 | + .groupBy("country") |
| 51 | + .agg(avg(SALARY_MIDPOINT),max(AGE_MIDPOINT)) |
| 52 | + .show(); |
67 | 53 |
|
68 | 54 | System.out.println("=== Group by salary bucket ===");
|
69 |
| - |
70 |
| -responseDataset |
71 |
| - .map(response ->Math.round(response.getSalaryMidPoint()/20000) *20000,Encoders.INT()) |
72 |
| - .withColumnRenamed("value",SALARY_MIDPOINT_BUCKET) |
73 |
| - .groupBy(SALARY_MIDPOINT_BUCKET) |
74 |
| - .count() |
75 |
| - .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
| 55 | +typedDataset.filter(response ->response.getSalaryMidPoint() !=null) |
| 56 | + .map(response ->Math.round(response.getSalaryMidPoint()/20000) *20000,Encoders.INT()) |
| 57 | + .withColumnRenamed("value",SALARY_MIDPOINT_BUCKET) |
| 58 | + .groupBy(SALARY_MIDPOINT_BUCKET) |
| 59 | + .count() |
| 60 | + .orderBy(SALARY_MIDPOINT_BUCKET).show(); |
76 | 61 | }
|
77 |
| - |
78 |
| -privatestaticfloatconvertStringToFloat(Stringsplit) { |
79 |
| -returnsplit.isEmpty() ?NULL_VALUE :Float.valueOf(split); |
80 |
| - } |
81 |
| - |
82 | 62 | }
|