Zohar Meir

Reputation: 595

Apache Spark: how to cancel job in code and kill running tasks?

I am running a Spark application (version 1.6.0) on a Hadoop cluster with Yarn (version 2.6.0) in client mode. I have a piece of code that runs a long computation, and I want to kill it if it takes too long (and then run some other function instead).
Here is an example:

import java.util.concurrent.TimeUnit
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf().setAppName("TIMEOUT_TEST")
val sc = new SparkContext(conf)
val lst = List(1,2,3)
// setting up an infinite action
val future = sc.parallelize(lst).map(while (true) _).collectAsync()

try {
    Await.result(future, Duration(30, TimeUnit.SECONDS))
    println("success!")
} catch {
    case _:Throwable =>
        future.cancel()
        println("timeout")
}

// sleep for 1 hour to allow inspecting the application in yarn
Thread.sleep(60*60*1000)
sc.stop()

The timeout is set to 30 seconds, but of course the computation is infinite, so awaiting the result of the future throws an exception, which is caught; the future is then canceled and the backup function executes.
This all works perfectly well, except that the canceled job doesn't terminate completely: when looking at the web UI for the application, the job is marked as failed, but I can see there are still running tasks inside.

The same thing happens when I use SparkContext.cancelAllJobs or SparkContext.cancelJobGroup. The problem is that even though I manage to get on with my program, the running tasks of the canceled job are still hogging valuable resources (which will eventually slow me down to a near stop).

To sum things up: How do I kill a Spark job in a way that will also terminate all running tasks of that job? (as opposed to what happens now, which is stopping the job from running new tasks, but letting the currently running tasks finish)

UPDATE:
After a long time ignoring this problem, we found a messy but efficient little workaround. Instead of trying to kill the appropriate Spark job/stage from within the Spark application, we simply logged the stage IDs of all active stages when the timeout occurred, and issued an HTTP GET request to the URL the Spark Web UI exposes for killing those stages.
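Roughly, the workaround looks like the sketch below (killActiveStages is just an illustrative helper; the UI base URL and the /stages/stage/kill endpoint are assumptions based on the Spark 1.6 Web UI, and spark.ui.killEnabled must be true):

import scala.io.Source

// On timeout: ask the Spark Web UI to kill every stage that is still active.
// Assumes the UI is reachable at uiBase, e.g. "http://<driver-host>:4040".
def killActiveStages(uiBase: String): Unit = {
  sc.statusTracker.getActiveStageIds().foreach { stageId =>
    val url = s"$uiBase/stages/stage/kill/?id=$stageId&terminate=true"
    println(s"Killing stage $stageId via $url")
    Source.fromURL(url).mkString // the GET itself triggers the kill; the response body is ignored
  }
}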

Upvotes: 5

Views: 12288

Answers (4)

Vishwajeet Pol

Reputation: 65

With a Unity Catalog-enabled cluster in shared access mode, the Spark context is no longer available, because the Spark Connect feature is used instead (see the Spark Connect documentation for details).

For interruption, there are two approaches going forward: you can use addTag, removeTag and interruptTag, or interruptOperation.

Sharing code below for your reference, which you can modify based on your requirements.

def runQueryWithTag(query: String, tag: String): Unit = {
  try {
    spark.addTag(tag)
    val df = spark.sql(query)
    println(df.count)
  } finally {
    spark.removeTag(tag)
  }
}

import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}

val queriesWithTags = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
)

val futures = queriesWithTags.map { case (query, tag) =>
  Future { runQueryWithTag(query, tag) }
}

Thread.sleep(30000)
println("Interrupting tag1")
spark.interruptTag("tag1")

OR

import scala.collection.mutable.ListBuffer
val list1 = ListBuffer[String]()

def runQuery(query: String): Unit = {
  val df = spark.sql(query).collectResult()
  val opid = df.operationId
  list1 += opid
}

import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}

val queries = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e")
)

val futures = queries.map { case (query) =>
  Future { runQuery(query) }
}

Thread.sleep(20000)
println("Interrupting query 1 !!!!")
println(list1)
spark.interruptOperation(list1(0))

Update (21 June 2024): adding a better version that tracks the tags and interrupts each tag individually when its timeout expires.

val timeoutQueue = new TimeoutQueue()

def runQueryWithTag(query: String, tag: String): Unit = {
  try {
    spark.addTag(tag)
    val df = spark.sql(query)
    println(df.count)
    val a = spark.getTags()
    a.foreach(println)
  } finally {
    println(s"Done with $tag")
    spark.removeTag(tag)
    timeoutQueue.remove(tag)
  }
}

import scala.concurrent.{Future, ExecutionContext}
import scala.concurrent.duration._
import ExecutionContext.Implicits.global
import scala.util.{Success, Failure}
import java.util.concurrent.{Executors, ConcurrentLinkedQueue, PriorityBlockingQueue}

val queriesWithTags = Seq(
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c", "tag3"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b", "tag2"),
  ("SELECT * FROM system.information_schema.columns a , system.information_schema.columns b, system.information_schema.columns c, system.information_schema.columns d, system.information_schema.columns e, system.information_schema.columns f, system.information_schema.columns g", "tag1")
  // ("SELECT * FROM system.information_schema.columns a", "tag1")
)

// Order them based on the expiration time.
case class TimedTag(tag: String, expirationTime: Long) extends Comparable[TimedTag] {
  override def compareTo(other: TimedTag): Int = expirationTime.compareTo(other.expirationTime)
}

class TimeoutQueue {
  private val queue = new PriorityBlockingQueue[TimedTag]()

  private def popExpired(): Option[String] = {
    val currentTime = System.currentTimeMillis()
    val headOption = Option(queue.peek())
    headOption match {
      case Some(head) if head.expirationTime <= currentTime =>
        queue.poll() // remove the head
        Some(head.tag)
      case _ => None
    }
  }

  def add(tag: String, timeout: FiniteDuration): Unit = {
    val expirationTime = System.currentTimeMillis() + timeout.toMillis
    queue.put(TimedTag(tag, expirationTime))
  }

  def remove(tag: String): Unit = {
    queue.removeIf(_.tag == tag)
  }

  def loop(checkInterval: FiniteDuration): Unit = {
    while (queue.size() > 0) { // while we have tags, wait for them
      popExpired() match {
        case Some(tag) => { println(s"Interrupting $tag"); spark.interruptTag(tag) }
        case None => // No item ready to process
      }
      // println(s"sleeping for $checkInterval")
      Thread.sleep(checkInterval.toMillis)
    }
  }
}

val futures = queriesWithTags.map { case (query, tag) =>
  Future { runQueryWithTag(query, tag) }
}

timeoutQueue.add("tag1", 30000.milliseconds)
timeoutQueue.add("tag2", 55000.milliseconds)
timeoutQueue.add("tag3", 42000.milliseconds)
timeoutQueue.loop(500.milliseconds)

Upvotes: 1

Fabrizio Faber

Reputation: 51

I don't know if this answers your question. My need was to kill jobs that hang for too long (my jobs extract data from Oracle tables, but for some unknown reason, on rare occasions the connection hangs forever).

After some study, I came to this solution:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.JobExecutionStatus

val MAX_JOB_SECONDS = 100
val statusTracker = sc.statusTracker;
val sparkListener = new SparkListener()  
{ 
    
    override def onJobStart(jobStart : SparkListenerJobStart)     
    {
        val jobId = jobStart.jobId
        val f = Future 
        {
            var c = MAX_JOB_SECONDS;
            var mustCancel = false;
            var running = true;
            while(!mustCancel && running)
            {
                Thread.sleep(1000);
                c = c - 1;
                mustCancel = c <= 0;
                // getJobInfo returns an Option (never null), so check isDefined
                val jobInfo = statusTracker.getJobInfo(jobId)
                if(jobInfo.isDefined)
                {
                    running = jobInfo.get.status() == JobExecutionStatus.RUNNING
                }
                else
                    running = false
            }
            if(mustCancel)
            {
              sc.cancelJob(jobId)
            }
        }
    }
}
sc.addSparkListener(sparkListener)
try
{
    val df = spark.sql("SELECT * FROM VERY_BIG_TABLE") //just an example of long-running-job
    println(df.count)
}
catch
{
    case exc: org.apache.spark.SparkException =>
    {
        if(exc.getMessage.contains("cancelled"))
            throw new Exception("Job forcibly cancelled")
        else
            throw exc
    }
    case ex : Throwable => 
    {
        println(s"Another exception: $ex")
    }
}
finally
{
    sc.removeSparkListener(sparkListener)
}

Upvotes: 5

Boris

Reputation: 491

For the sake of future visitors: since version 2.0.3, Spark has included a built-in task reaper, which addresses this scenario (more or less). Note that it can eventually kill an executor if a task remains unresponsive.

Moreover, some of Spark's built-in data sources have been refactored to be more responsive to cancellation.

For the 1.6.0 version, Zohar's solution is a "messy but efficient" one.
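For Spark 2.0.3+, a minimal sketch of enabling the reaper through SparkConf (the property names are from the Spark configuration docs; the values here are only illustrative):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("TIMEOUT_TEST")
  // Keep monitoring tasks after they have been killed/interrupted.
  .set("spark.task.reaper.enabled", "true")
  // How often the executor polls the status of killed-but-still-running tasks.
  .set("spark.task.reaper.pollingInterval", "10s")
  // Log a thread dump of tasks that refuse to die, to help find the blocking call.
  .set("spark.task.reaper.threadDump", "true")
  // Last resort: kill the whole executor JVM if a killed task is still running after this long.
  .set("spark.task.reaper.killTimeout", "120s")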

Upvotes: 3

echo

Reputation: 29

According to setJobGroup:

"If interruptOnCancel is set to true for the job group, then job cancellation will result in Thread.interrupt() being called on the job's executor threads."

So the anonymous function in your map must be interruptible, like this:

val future = sc.parallelize(lst).map(while (!Thread.interrupted) _).collectAsync()
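
For completeness, a minimal sketch built on the question's sc and lst (the group name "timeout-group" is arbitrary): set the job group with interruptOnCancel = true before kicking off the action, then cancel the whole group on timeout so the executor threads get interrupted:

import java.util.concurrent.{TimeUnit, TimeoutException}
import scala.concurrent.Await
import scala.concurrent.duration.Duration

sc.setJobGroup("timeout-group", "interruptible long computation", interruptOnCancel = true)
// The task body keeps checking the interrupt flag, so Thread.interrupt() actually stops it.
val future = sc.parallelize(lst).map(_ => while (!Thread.interrupted()) {}).collectAsync()

try {
  Await.result(future, Duration(30, TimeUnit.SECONDS))
  println("success!")
} catch {
  case _: TimeoutException =>
    sc.cancelJobGroup("timeout-group") // interrupts the running tasks of every job in the group
    println("timeout")
}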

Upvotes: 0
