Lin
Lin

Reputation: 2565

AspectJ - Get the thread id of the parent thread that generated a lambda function using aspectJ

I have the following code:

@RequestMapping(method = RequestMethod.GET, path = "/execute")
public @ResponseBody
String execute(@RequestParam("name") String input) throws Exception {

    // A fresh executor is created per request, so it must also be shut down per request:
    // otherwise its idle worker thread stays alive after the task completes and one thread
    // leaks with every call.
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
        executor.execute(() -> {
            try {
                // Do something... (placeholder for work that may throw IOException)
            } catch (IOException e) {
                // No printStackTrace(): an exception escaping a task passed to execute() reaches
                // the worker thread's uncaught-exception handler, which reports it with this cause.
                throw new RuntimeException(e);
            }
        });
    } finally {
        // Orderly shutdown: the already submitted task still runs, then the worker thread exits.
        executor.shutdown();
    }
    return "lambda call";
}

I want to use AspectJ to intercept the lambda's execution and identify the ID of the thread that created it, i.e. the thread in which my "execute" method runs. I know how to intercept the lambda itself:

execution(void lambda$*(..))

But this is too late for me. Because the lambda runs in a new thread, at that point I can no longer identify the thread that created it (the one that called "execute"). How can I get the ID of that "parent" thread, i.e. the thread that ran "execute"?

Upvotes: 0

Views: 382

Answers (3)

kriegaex
kriegaex

Reputation: 67417

Okay, after some discussion with the OP I decided to publish what I had come up with a few months ago. I first wanted to wait for the OP's own solution, but now that I have seen it I still do not understand what exactly it does and how it does it. Maybe back in March when I was still fully immersed in the topic I would have understood better. Anyway, here is my solution.

Driver application:

package de.scrum_master.app;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

/**
 * Demo driver that hands work to other threads in several ways: nested single-thread
 * executors, nested raw threads and a fixed thread pool, each once with lambdas and once with
 * anonymous Runnables. It finishes with a few lambdas/Runnables that run directly on the
 * calling thread as negative tests.
 */
public class Application {
  private static final String SEPARATOR_LINE = "\n----- main -----\n";

  public static void main(String[] args) {
    nestedSingleThreadExecutorLambda(3);
    nestedSingleThreadExecutorRunnable(3);
    nestedThreadsLambda(3);
    nestedThreadsRunnable(3);
    fixedThreadPoolExecutorLambda(7, 3);
    fixedThreadPoolExecutorRunnable(7, 3);

    negativeTestLambda(3);
    negativeTestRunnable(3);
    negativeTestStreamLambda(3);

    // No System.gc() hack needed: every executor created above is shut down right after use,
    // so its worker threads end once their tasks have run.
  }

  /** Three levels of single-thread executors, each level submitted from the previous one. */
  private static void nestedSingleThreadExecutorLambda(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      executeOnNewSingleThreadExecutor(() -> {
        println("  executor level 1");
        executeOnNewSingleThreadExecutor(() -> {
          println("    executor level 2");
          executeOnNewSingleThreadExecutor(() -> {
            println("      executor level 3");
          });
          sleep(100);
        });
        sleep(100);
      });
      sleep(100);
    }
  }

  /** Same as {@link #nestedSingleThreadExecutorLambda(int)}, but with anonymous Runnables. */
  private static void nestedSingleThreadExecutorRunnable(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      executeOnNewSingleThreadExecutor(new Runnable() {
        @Override public void run() {
          println("  executor level 1");
          executeOnNewSingleThreadExecutor(new Runnable() {
            @Override public void run() {
              println("    executor level 2");
              executeOnNewSingleThreadExecutor(new Runnable() {
                @Override public void run() {
                  println("      executor level 3");
                }
              });
              sleep(100);
            }
          });
          sleep(100);
        }
      });
      sleep(100);
    }
  }

  /** Three levels of raw threads, each level started from the previous one. */
  private static void nestedThreadsLambda(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      new Thread(() -> {
        println("  thread level 1");
        new Thread(() -> {
          println("    thread level 2");
          new Thread(() -> {
            println("      thread level 3");
          }).start();
          sleep(100);
        }).start();
        sleep(100);
      }).start();
      sleep(100);
    }
  }

  /** Same as {@link #nestedThreadsLambda(int)}, but with anonymous Runnables. */
  private static void nestedThreadsRunnable(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      new Thread(new Runnable() {
        @Override public void run() {
          println("  thread level 1");
          new Thread(new Runnable() {
            @Override public void run() {
              println("    thread level 2");
              new Thread(new Runnable() {
                @Override public void run() {
                  println("      thread level 3");
                }
              }).start();
              sleep(100);
            }
          }).start();
          sleep(100);
        }
      }).start();
      sleep(100);
    }
  }

  /** Submits {@code iterations} lambdas to one fixed pool of {@code poolSize} threads. */
  private static void fixedThreadPoolExecutorLambda(int iterations, int poolSize) {
    println(SEPARATOR_LINE);
    ExecutorService executor = Executors.newFixedThreadPool(poolSize);
    try {
      for (int i = 0; i < iterations; i++) {
        executor.execute(() ->
          println("  fixed thread pool (size " + poolSize + ") executor")
        );
        sleep(100);
      }
    } finally {
      shutdownAndAwait(executor, 100);
    }
  }

  /** Same as {@link #fixedThreadPoolExecutorLambda(int, int)}, but with anonymous Runnables. */
  private static void fixedThreadPoolExecutorRunnable(int iterations, int poolSize) {
    println(SEPARATOR_LINE);
    ExecutorService executor = Executors.newFixedThreadPool(poolSize);
    try {
      for (int i = 0; i < iterations; i++) {
        executor.execute(new Runnable() {
          @Override public void run() {
            println("  fixed thread pool (size " + poolSize + ") executor");
          }
        });
        sleep(100);
      }
    } finally {
      shutdownAndAwait(executor, 100);
    }
  }

  /** Negative test: a lambda invoked directly on the calling thread. */
  private static void negativeTestLambda(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      ((Runnable) (() -> println("  Lambda negative test"))).run();
    }
  }

  /** Negative test: an anonymous Runnable invoked directly on the calling thread. */
  private static void negativeTestRunnable(int iterations) {
    println(SEPARATOR_LINE);
    for (int i = 0; i < iterations; i++) {
      new Runnable() {
        @Override public void run() {
          println("  Runnable negative test");
        }
      }.run();
    }
  }

  /** Negative test: a lambda used in a sequential stream's forEach. */
  private static void negativeTestStreamLambda(int iterations) {
    println(SEPARATOR_LINE);
    // Lambda used in 'forEach'
    IntStream.range(0, iterations)
      .forEach((int number) -> println("  Stream forEach negative test"));
  }

  /**
   * Runs {@code task} on a fresh single-thread executor and immediately requests an orderly
   * shutdown: the already submitted task still runs, after which the worker thread ends.
   */
  private static void executeOnNewSingleThreadExecutor(Runnable task) {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      executor.execute(task);
    } finally {
      executor.shutdown();
    }
  }

  /** Shuts {@code executor} down and waits up to {@code timeoutMillis} for it to terminate. */
  private static void shutdownAndAwait(ExecutorService executor, long timeoutMillis) {
    executor.shutdown();
    try {
      executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS);
    } catch (InterruptedException e) {
      // Restore the interrupt status so callers can still observe the interruption.
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Sleeps for {@code millis} ms. If interrupted, it returns early and restores the thread's
   * interrupt status instead of swallowing it.
   */
  private static void sleep(long millis) {
    try {
      Thread.sleep(millis);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }

  private static void println(String message) {
    System.out.println(message);
  }
}

Thread tracker helper class:

package de.scrum_master.aspect;

import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;

/**
 * Inheritable thread-local whose value in every thread is that thread itself. It also records,
 * for each thread, the thread that created it.
 *
 * <p>A new thread is tracked only if its creating thread already holds a value, i.e. the creator
 * called {@link #get()} or itself inherited a value. Safe for use from multiple threads.
 */
public class ThreadTracker extends InheritableThreadLocal<Thread> {

  // ----------------------------
  // Keep track of parent threads
  // ----------------------------

  // Child thread -> thread that created it (null: no known parent, e.g. the main thread).
  // Shared by all threads, so every access is guarded by the map's own monitor. Weak keys let
  // terminated threads be garbage-collected instead of being pinned here forever; a lookup for
  // a thread the caller still references always finds its entry.
  private final Map<Thread, Thread> parentThreads = new WeakHashMap<>();

  /** Returns the thread that created the current thread, or {@code null} if unknown. */
  public Thread getParent() {
    return getParent(Thread.currentThread());
  }

  /**
   * Returns the thread that created {@code childThread}, or {@code null} if unknown.
   *
   * <p>A thread's parent is recorded the first time that thread itself calls {@link #get()} or
   * {@link #getParent()}. Querying some other thread that never did so returns {@code null}.
   */
  public Thread getParent(Thread childThread) {
    if (childThread == Thread.currentThread())
      registerCurrentThread();
    synchronized (parentThreads) {
      return parentThreads.get(childThread);
    }
  }

  /**
   * Records the current thread's parent on first use. Until {@link #get()} overwrites it, the
   * value seen by this thread is either the one inherited via {@link #childValue(Thread)} (its
   * creator) or, if nothing was inherited, {@link #initialValue()} (the thread itself, meaning
   * "no parent").
   */
  private void registerCurrentThread() {
    Thread current = Thread.currentThread();
    synchronized (parentThreads) {
      if (!parentThreads.containsKey(current)) {
        Thread inherited = super.get();
        parentThreads.put(current, inherited == current ? null : inherited);
      }
    }
  }

  // ----------------------------------
  // InheritableThreadLocal customising
  // ----------------------------------

  /** A thread that inherited nothing starts out with itself as value. */
  @Override
  protected Thread initialValue() {
    return Thread.currentThread();
  }

  /**
   * Invoked by the JDK from within the <em>creating</em> thread while a new thread is
   * constructed, so the child starts out with its creator as value. This holds even if the
   * creator only inherited a value and never called {@link #get()}.
   */
  @Override
  protected Thread childValue(Thread parentValue) {
    return Thread.currentThread();
  }

  /**
   * Returns the current thread, after first recording its parent (see
   * {@link #registerCurrentThread()}) and then storing the current thread as this thread's value.
   */
  @Override
  public Thread get() {
    registerCurrentThread();
    set(Thread.currentThread());
    return super.get();
  }

  // ----------------
  // toString() stuff
  // ----------------

  @Override
  public String toString() {
    return toShortString();
  }

  /** Formats the current thread's ID and its parent's ID (or {@code null}). */
  public String toShortString() {
    Thread parentThread = getParent();
    return "ThreadTracker[" +
      "child=" + get().getId() + ", " +
      "parent=" + (parentThread == null ? null : parentThread.getId()) +
    "]";
  }

  /** Like {@link #toShortString()}, but prints the full {@code Thread} objects. */
  public String toLongString() {
    return "ThreadTracker[" +
      "child=" + get() + ", " +
      "parent=" + getParent() +
    "]";
  }

}

Aspect:

package de.scrum_master.aspect;

import de.scrum_master.app.Application;

// Prints the ThreadTracker (current thread ID and its recorded parent thread ID, via its
// toString()) next to each intercepted join point: lambda bodies and Runnable.run()
// executions that pass the isStartedThread() heuristic, plus all public Application methods.
public aspect MyAspect {
  // Single tracker instance shared by all advice in this aspect.
  private static ThreadTracker threadTracker = new ThreadTracker();
  
  // Make sure this aspect's ThreadTracker instance gets initialised
  // before it is used the first time. The easiest way to do that is
  // to access its value once before the main class gets statically
  // initialised (i.e. during class-loading already).
  before() : staticinitialization(Application) {
    threadTracker.get();
  }

  // Intercept lambda executions
  // (only those for which isStartedThread() returns true)
  before() : if(isStartedThread()) && execution(private * lambda$*(..)) {
    System.out.println(threadTracker + " | " + thisJoinPoint);
  }

  // Intercept Runnable executions
  // (only those for which isStartedThread() returns true)
  before() : if(isStartedThread()) && execution(public void Runnable.run()) {
    System.out.println(threadTracker + " | " + thisJoinPoint);
  }

  // Heuristic for matching runnables and lambdas being used to start threads.
  // Creates a throwaway Exception on every evaluation just to obtain the stack trace, and
  // accepts the join point only if stack frame [3] is in java.util.concurrent.* or is
  // java.lang.Thread.run(.
  // NOTE(review): index 3 presumably means [0] this method, [1] the code ajc generates for the
  // if() test, [2] the advised lambda$/run method, [3] its caller. This depends on how ajc weaves
  // the if() test and on the JDK's internal call chain for Thread.run()/executor workers, so
  // confirm it on the target JDK and weaver version.
  protected static boolean isStartedThread() {
    StackTraceElement[] stackTrace = new Exception("dummy").getStackTrace();
    if (stackTrace.length < 4)
      return false;
    String targetMethodName = stackTrace[3].toString();
    return (
      targetMethodName.startsWith("java.util.concurrent.") ||
      targetMethodName.startsWith("java.lang.Thread.run(")
    );
  }
  // Intercept public Application method executions
  before() : execution(public * Application.*(..)) {
    System.out.println(threadTracker + " | " + thisJoinPoint);
  }

  // Debug ThreadTracker class
  // (uncomment to print the current thread ID for every ThreadTracker method execution)
  /*
  before() : within(ThreadTracker) && execution(* *(..)) {
    System.out.println("  >> " + Thread.currentThread().getId() + " | " + thisJoinPoint);
  }
  */
}

Console log:

ThreadTracker[child=1, parent=null] | execution(void de.scrum_master.app.Application.main(String[]))

----- main -----

ThreadTracker[child=10, parent=1] | execution(void de.scrum_master.app.Application.lambda$0())
  executor level 1
ThreadTracker[child=11, parent=10] | execution(void de.scrum_master.app.Application.lambda$1())
    executor level 2
ThreadTracker[child=12, parent=11] | execution(void de.scrum_master.app.Application.lambda$2())
      executor level 3
ThreadTracker[child=13, parent=1] | execution(void de.scrum_master.app.Application.lambda$0())
  executor level 1
ThreadTracker[child=14, parent=13] | execution(void de.scrum_master.app.Application.lambda$1())
    executor level 2
ThreadTracker[child=15, parent=14] | execution(void de.scrum_master.app.Application.lambda$2())
      executor level 3
ThreadTracker[child=16, parent=1] | execution(void de.scrum_master.app.Application.lambda$0())
  executor level 1
ThreadTracker[child=17, parent=16] | execution(void de.scrum_master.app.Application.lambda$1())
    executor level 2
ThreadTracker[child=18, parent=17] | execution(void de.scrum_master.app.Application.lambda$2())
      executor level 3

----- main -----

ThreadTracker[child=19, parent=1] | execution(void de.scrum_master.app.Application.1.run())
  executor level 1
ThreadTracker[child=20, parent=19] | execution(void de.scrum_master.app.Application.1.1.run())
    executor level 2
ThreadTracker[child=21, parent=20] | execution(void de.scrum_master.app.Application.1.1.1.run())
      executor level 3
ThreadTracker[child=22, parent=1] | execution(void de.scrum_master.app.Application.1.run())
  executor level 1
ThreadTracker[child=23, parent=22] | execution(void de.scrum_master.app.Application.1.1.run())
    executor level 2
ThreadTracker[child=24, parent=23] | execution(void de.scrum_master.app.Application.1.1.1.run())
      executor level 3
ThreadTracker[child=25, parent=1] | execution(void de.scrum_master.app.Application.1.run())
  executor level 1
ThreadTracker[child=26, parent=25] | execution(void de.scrum_master.app.Application.1.1.run())
    executor level 2
ThreadTracker[child=27, parent=26] | execution(void de.scrum_master.app.Application.1.1.1.run())
      executor level 3

----- main -----

ThreadTracker[child=28, parent=1] | execution(void de.scrum_master.app.Application.lambda$3())
  thread level 1
ThreadTracker[child=29, parent=28] | execution(void de.scrum_master.app.Application.lambda$4())
    thread level 2
ThreadTracker[child=30, parent=29] | execution(void de.scrum_master.app.Application.lambda$5())
      thread level 3
ThreadTracker[child=31, parent=1] | execution(void de.scrum_master.app.Application.lambda$3())
  thread level 1
ThreadTracker[child=32, parent=31] | execution(void de.scrum_master.app.Application.lambda$4())
    thread level 2
ThreadTracker[child=33, parent=32] | execution(void de.scrum_master.app.Application.lambda$5())
      thread level 3
ThreadTracker[child=34, parent=1] | execution(void de.scrum_master.app.Application.lambda$3())
  thread level 1
ThreadTracker[child=35, parent=34] | execution(void de.scrum_master.app.Application.lambda$4())
    thread level 2
ThreadTracker[child=36, parent=35] | execution(void de.scrum_master.app.Application.lambda$5())
      thread level 3

----- main -----

ThreadTracker[child=37, parent=1] | execution(void de.scrum_master.app.Application.2.run())
  thread level 1
ThreadTracker[child=38, parent=37] | execution(void de.scrum_master.app.Application.2.1.run())
    thread level 2
ThreadTracker[child=39, parent=38] | execution(void de.scrum_master.app.Application.2.1.1.run())
      thread level 3
ThreadTracker[child=40, parent=1] | execution(void de.scrum_master.app.Application.2.run())
  thread level 1
ThreadTracker[child=41, parent=40] | execution(void de.scrum_master.app.Application.2.1.run())
    thread level 2
ThreadTracker[child=42, parent=41] | execution(void de.scrum_master.app.Application.2.1.1.run())
      thread level 3
ThreadTracker[child=43, parent=1] | execution(void de.scrum_master.app.Application.2.run())
  thread level 1
ThreadTracker[child=44, parent=43] | execution(void de.scrum_master.app.Application.2.1.run())
    thread level 2
ThreadTracker[child=45, parent=44] | execution(void de.scrum_master.app.Application.2.1.1.run())
      thread level 3

----- main -----

ThreadTracker[child=46, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=47, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=48, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=46, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=47, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=48, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor
ThreadTracker[child=46, parent=1] | execution(void de.scrum_master.app.Application.lambda$6(int))
  fixed thread pool (size 3) executor

----- main -----

ThreadTracker[child=49, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=50, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=51, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=49, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=50, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=51, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor
ThreadTracker[child=49, parent=1] | execution(void de.scrum_master.app.Application.3.run())
  fixed thread pool (size 3) executor

----- main -----

  Lambda negative test
  Lambda negative test
  Lambda negative test

----- main -----

  Runnable negative test
  Runnable negative test
  Runnable negative test

----- main -----

  Stream forEach negative test
  Stream forEach negative test
  Stream forEach negative test

Upvotes: 1

Lin
Lin

Reputation: 2565

In order to solve this I used InheritableThreadLocal. I have no access to the actual application, so everything is done via AspectJ. First I save the ID of the thread that is going to hand the lambda over for execution. Then I intercept the lambda itself and, if it was handed over by a different thread, read that thread's ID from the InheritableThreadLocal context. To make this work not just for one level but also for a thread that starts a thread that starts a thread (and so on), the solution needs to store two values in InheritedTraceContext (the previous and the current parent ID) and remember to update them.

  1. A class that defines the inherited trace context data

     /**
      * Trace data that a thread shares with the threads it starts, through
      * {@code TraceContextFactory.inheritedTraceContext}.
      */
     public class InheritedTraceContext {

         /** Sentinel meaning "no parent thread has been recorded yet". */
         public static final long NO_PARENT = -1;

         /**
          * ID of the first (root) thread that handed work to another thread, or
          * {@link #NO_PARENT}. To support a thread that starts a thread that starts a thread,
          * etc., add a second field such as {@code public long currThreadParentId = NO_PARENT;}
          * holding the most recent parent.
          */
         public long threadParentId = NO_PARENT;
     }
    
  2. A singleton class that stores the relevant inherited trace context instance

     /** Holds the process-wide inheritable trace context. */
     public class TraceContextFactory {

         /**
          * Per-thread trace context. {@code childValue} is not overridden, so a thread created by
          * a thread whose context has already been initialised inherits the very same
          * {@link InheritedTraceContext} instance (not a copy). Updates made by either thread are
          * therefore visible to the other.
          * NOTE(review): the shared instance's fields are not synchronized; confirm that this is
          * acceptable for the intended use.
          */
         public static final ThreadLocal<InheritedTraceContext> inheritedTraceContext =
             new InheritableThreadLocal<InheritedTraceContext>() {
                 @Override
                 protected InheritedTraceContext initialValue() {
                     return new InheritedTraceContext();
                 }
             };

         /** Shared instance for callers that reach the context through an instance. */
         public static final TraceContextFactory factory = new TraceContextFactory();

         /**
          * Returns the shared {@link #factory} instance (the accessor used by
          * {@code ThreadContextAspect}).
          */
         public static TraceContextFactory getFactory() {
             return factory;
         }
     }
    
  3. An aspect class that does the actual work. First the call to execute (or to another method that starts a thread or submits a task) is intercepted, then the execution of the lambda itself. An 'after' advice can also be added for any operations that need to be done afterwards:

    /**
     * Records in the inheritable trace context the ID of the first thread that hands work to
     * another thread, and reads it back when a lambda body executes.
     */
    public aspect ThreadContextAspect {

        // Runs in the submitting thread just before it starts a thread or hands a task to an
        // executor. The Thread.start() pattern does not require the 'synchronized' modifier, so it
        // matches whether or not the running JDK declares start() as synchronized.
        before(): (call(public void java.lang.Thread+.start())
                || call(* java.util.concurrent.Executor+.execute(Runnable))
                || call(* java.util.concurrent.ExecutorService+.submit(Runnable, ..))
                || call(* java.util.concurrent.ExecutorService+.submit(java.util.concurrent.Callable, ..))) {
            long parentThreadId = Thread.currentThread().getId();
            InheritedTraceContext inheritedTraceContext = TraceContextFactory.inheritedTraceContext.get();

            // -1 means no parent has been recorded yet, so this thread is the first (root)
            // submitter. Otherwise the context already holds the root's ID (for example when two
            // lambdas are submitted one after the other) and is left unchanged. Consequently, if
            // threadParentId != parentThreadId, this thread is not the one that started the whole
            // process. To track that too, store a second ID (e.g. currThreadParentId) in
            // InheritedTraceContext and update it here unconditionally.
            if (inheritedTraceContext.threadParentId == -1) {
                inheritedTraceContext.threadParentId = parentThreadId;
            }
        }

        // Runs in the thread that executes the lambda body.
        before(): execution(void lambda$*(..)) {
            InheritedTraceContext inheritedTraceContext = TraceContextFactory.inheritedTraceContext.get();
            long parentThreadId = inheritedTraceContext.threadParentId;

            // -1: no submitting thread was recorded in this thread's context (per the original
            // note, the lambda then runs on the same thread that created it).
            if (parentThreadId != -1) {
                // TODO: the lambda originates from a different thread; parentThreadId is the ID of
                // the root submitting thread. Act on it here.
            }
        }
    }

Upvotes: 1

kriegaex
kriegaex

Reputation: 67417

You have several problems here:

  • AspectJ currently cannot weave into lambdas via an execution() pointcut. This is mainly due to the fact that the JVM instruction invokedynamic is being ignored by the AspectJ compiler/weaver. See also AspectJ tickets #471347 (created by myself) and #364886. Besides, if you use an anonymous Runnable class instead, you can easily intercept it. (Note that the console log below does show the execution of the compiler-generated lambda$0() method being matched.)

  • You are not creating and starting the thread by yourself but deferring that to JDK classes and methods like ExecutorService.execute(Runnable), i.e. you also cannot weave into their execution(), only into their call() made from your own (aspect-woven) code.

  • In Java there is no general concept like "parent threads" which you could easily determine from an executing thread via a fictitious method like Thread.getParent() or similar. There is some parent stuff implemented for thread groups, but that does not help you here.

So what you are left with is an indirect way like this:

Driver application:

package de.scrum_master.app;

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Minimal reproduction of the question's controller method, without Spring. */
public class Application {
  /**
   * Hands a task to a fresh single-thread executor and returns immediately. The executor is
   * shut down right after submission, so its worker thread ends once the task has run. Without
   * this, the idle non-daemon worker may keep the JVM alive after {@code main} returns.
   *
   * @param input unused; kept to match the question's signature
   * @return always {@code "lambda call"}
   */
  String execute(String input) throws Exception {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      executor.execute(() -> {
        try {
          doSomething();
        } catch (IOException e) {
          // No printStackTrace(): an exception escaping a task passed to execute() reaches the
          // worker thread's uncaught-exception handler, which reports it with this cause.
          throw new RuntimeException(e);
        }
      });
    } finally {
      // Orderly shutdown: the submitted task still runs, then the worker thread exits.
      executor.shutdown();
    }

    return "lambda call";
  }

  /** Placeholder for the real work. */
  private void doSomething() throws IOException {}

  public static void main(String[] args) throws Exception {
    new Application().execute("dummy");
  }
}

Aspect:

package de.scrum_master.aspect;

import java.util.concurrent.ExecutorService;

public aspect MyAspect {
  // Any join point whose code is not inside this aspect, so the aspect's own println calls
  // are never traced themselves.
  pointcut outsideThisAspect() : !within(MyAspect);

  // A call to ExecutorService.execute(..), exposing the task being handed over.
  pointcut taskHandOver(Runnable task) : call(void ExecutorService.execute(*)) && args(task);

  // The execution of a compiler-generated private lambda body method.
  pointcut lambdaBodyExecution() : execution(private void lambda$*(..));

  // Trace every join point (catch-all, for logging purposes).
  before() : outsideThisAspect() {
    System.out.println("  " + thisJoinPoint);
  }

  // Trace the submitting thread together with the task it hands to the executor.
  before(Runnable task) : taskHandOver(task) {
    System.out.println(Thread.currentThread() + " | " + thisJoinPoint + " -> " + task);
  }

  // Trace the thread that actually runs the lambda body.
  before() : lambdaBodyExecution() {
    System.out.println(Thread.currentThread() + " | " + thisJoinPoint);
  }
}

Console log:

  staticinitialization(de.scrum_master.app.Application.<clinit>)
  execution(void de.scrum_master.app.Application.main(String[]))
  call(de.scrum_master.app.Application())
  preinitialization(de.scrum_master.app.Application())
  initialization(de.scrum_master.app.Application())
  execution(de.scrum_master.app.Application())
  call(String de.scrum_master.app.Application.execute(String))
  execution(String de.scrum_master.app.Application.execute(String))
  call(ExecutorService java.util.concurrent.Executors.newSingleThreadExecutor())
  call(void java.util.concurrent.ExecutorService.execute(Runnable))
Thread[main,5,main] | call(void java.util.concurrent.ExecutorService.execute(Runnable)) -> de.scrum_master.app.Application$$Lambda$1/2046562095@2dda6444
  execution(void de.scrum_master.app.Application.lambda$0())
Thread[pool-1-thread-1,5,main] | execution(void de.scrum_master.app.Application.lambda$0())
  call(void de.scrum_master.app.Application.doSomething())
  execution(void de.scrum_master.app.Application.doSomething())

Upvotes: 1

Related Questions