wandermonk

Reputation: 7356

Testing a utility function by writing a unit test in Apache Spark (Scala)

I have a utility function written in Scala that reads Parquet files from an S3 bucket. Could someone help me write unit test cases for it?

Below is the function which needs to be tested.

  def readParquetFile(spark: SparkSession,
                      locationPath: String): DataFrame = {
    spark.read
      .parquet(locationPath)
  }

So far, I have created a SparkSession whose master is local:

import org.apache.spark.sql.SparkSession


trait SparkSessionTestWrapper {

  lazy val spark: SparkSession = {
    SparkSession.builder().master("local").appName("Test App").getOrCreate()
  }

}

I am stuck on actually testing the function. Here is the code where I am stuck. The question is: should I create a real Parquet file and load it to check that the DataFrame is created, or is there a mocking framework I can use to test this?

import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSpec

class ReadAndWriteSpec extends FunSpec with DataFrameComparer with SparkSessionTestWrapper {

  import spark.implicits._

  it("reads a parquet file and creates a dataframe") {

  }

}

Edit:

Based on the inputs from the comments, I came up with the code below, but I am still not able to understand how it can be leveraged.

I am using https://github.com/findify/s3mock

import com.amazonaws.auth.{AWSStaticCredentialsProvider, AnonymousAWSCredentials}
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.github.mrpowers.spark.fast.tests.DataFrameComparer
import io.findify.s3mock.S3Mock
import org.scalatest.FunSpec

class ReadAndWriteSpec extends FunSpec with DataFrameComparer with SparkSessionTestWrapper {

  import spark.implicits._

  it("reads a parquet file and creates a dataframe") {

    val api = S3Mock(port = 8001, dir = "/tmp/s3")
    api.start

    val endpoint = new EndpointConfiguration("http://localhost:8001", "us-west-2")
    val client = AmazonS3ClientBuilder
      .standard
      .withPathStyleAccessEnabled(true)
      .withEndpointConfiguration(endpoint)
      .withCredentials(new AWSStaticCredentialsProvider(new AnonymousAWSCredentials()))
      .build

    /** Use it as usual. */
    client.createBucket("foo")
    client.putObject("foo", "bar", "baz")
    val url = client.getUrl("foo","bar")

    println(url.getFile())

    val df = ReadAndWrite.readParquetFile(spark,url.getPath())
    df.printSchema()

  }

}
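
My rough understanding is that I would then have to point Spark's s3a connector at the mock endpoint before calling readParquetFile, and also put a real Parquet file into the mock bucket first. Something like the untested sketch below (assuming hadoop-aws is on the test classpath):

// Untested sketch: make Spark's s3a filesystem talk to the local S3Mock endpoint.
// The fs.s3a.* keys are standard hadoop-aws settings; "foo"/"bar" are dummy credentials.
spark.sparkContext.hadoopConfiguration.set("fs.s3a.endpoint", "http://localhost:8001")
spark.sparkContext.hadoopConfiguration.set("fs.s3a.path.style.access", "true")
spark.sparkContext.hadoopConfiguration.set("fs.s3a.connection.ssl.enabled", "false")
spark.sparkContext.hadoopConfiguration.set("fs.s3a.access.key", "foo")
spark.sparkContext.hadoopConfiguration.set("fs.s3a.secret.key", "bar")

// A real Parquet file would still have to end up in the mock bucket, e.g. by writing a
// DataFrame there first and then reading it back with the function under test:
// someDF.write.parquet("s3a://foo/people.parquet")
// val df = ReadAndWrite.readParquetFile(spark, "s3a://foo/people.parquet")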

Upvotes: 4

Views: 4154

Answers (1)

wandermonk

Reputation: 7356

I figured it out and kept it simple. I was able to complete some basic test cases.

Here is my solution. I hope this will help someone.

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.scalatest.{BeforeAndAfterEach, FunSuite}
import loaders.ReadAndWrite

class ReadAndWriteTestSpec extends FunSuite with BeforeAndAfterEach {

  private val master = "local"

  private val appName = "ReadAndWrite-Test"

  var spark : SparkSession = _

  override def beforeEach(): Unit = {
    spark = SparkSession.builder().appName(appName).master(master).getOrCreate()
  }

  test("creating data frame from parquet file") {
    val sparkSession = spark
    import sparkSession.implicits._
    val peopleDF = spark.read.json("src/test/resources/people.json")
    peopleDF.write.mode(SaveMode.Overwrite).parquet("src/test/resources/people.parquet")

    val df = ReadAndWrite.readParquetFile(sparkSession,"src/test/resources/people.parquet")
    df.printSchema()

  }


  test("creating data frame from text file") {
    val sparkSession = spark
    import sparkSession.implicits._
    val peopleDF = ReadAndWrite.readTextfileToDataSet(sparkSession,"src/test/resources/people.txt").map(_.split(","))
      .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
      .toDF()
    peopleDF.printSchema()
  }

  test("counts should match with number of records in a text file") {
    val sparkSession = spark
    import sparkSession.implicits._
    val peopleDF = ReadAndWrite.readTextfileToDataSet(sparkSession,"src/test/resources/people.txt").map(_.split(","))
      .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
      .toDF()
    peopleDF.printSchema()

    assert(peopleDF.count() == 3)
  }

  test("data should match with sample records in a text file") {
    val sparkSession = spark
    import sparkSession.implicits._
    val peopleDF = ReadAndWrite.readTextfileToDataSet(sparkSession,"src/test/resources/people.txt").map(_.split(","))
      .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
      .toDF()
    peopleDF.printSchema()

    assert(peopleDF.take(1)(0)(0).equals("Michael"))
  }

  test("Write a data frame as csv file") {
    val sparkSession = spark
    import sparkSession.implicits._
    val peopleDF = ReadAndWrite.readTextfileToDataSet(sparkSession,"src/test/resources/people.txt").map(_.split(","))
      .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
      .toDF()

    // the header argument should be exposed to the user as a Boolean to avoid confusion
    ReadAndWrite.writeDataframeAsCSV(peopleDF,"src/test/resources/out.csv",java.time.Instant.now().toString,",","true")
  }

  override def afterEach(): Unit = {
    spark.stop()
  }

}

case class Person(name: String, age: Int)
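
As a possible refinement (just a sketch, not part of the suite above), the Parquet test could assert the round trip instead of only printing the schema, by also mixing com.github.mrpowers.spark.fast.tests.DataFrameComparer from the question into the suite:

// Sketch only: assumes the suite also extends DataFrameComparer (spark-fast-tests).
test("parquet round trip preserves the data") {
  val sparkSession = spark
  import sparkSession.implicits._

  // Small in-memory sample; any DataFrame would do.
  val expectedDF = Seq(("Michael", 29), ("Andy", 30), ("Justin", 19)).toDF("name", "age")
  expectedDF.write.mode(SaveMode.Overwrite).parquet("src/test/resources/people_roundtrip.parquet")

  val actualDF = ReadAndWrite.readParquetFile(sparkSession, "src/test/resources/people_roundtrip.parquet")

  // assertSmallDataFrameEquality comes from the DataFrameComparer trait.
  assertSmallDataFrameEquality(actualDF.orderBy("name"), expectedDF.orderBy("name"))
}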

Upvotes: 3
