Sunday, December 16, 2018

A brief overview of the Fork/Join Framework in Java

Introduction

The Fork/Join framework solves problems using a concurrent divide-and-conquer approach. It was introduced to complement the existing concurrency API. Before its introduction, the existing ExecutorService implementations were the popular choice for running asynchronous tasks, but they work best when the tasks are homogeneous and independent. Running dependent tasks and combining their results using those implementations was not easy. The Fork/Join framework was an attempt to address this shortcoming. In this post, we will take a brief look at the API and solve a couple of simple problems to understand how it works.

Solving a non-blocking task

Let's jump directly into code. Let's create a task which would return the sum of all elements of a List. The following steps represent our algorithm in pseudo-code:

01. Find the middle index of the list
02. Divide the list in the middle
03. Recursively create a new task which will compute the sum of the left part
04. Recursively create a new task which will compute the sum of the right part
05. Add the result of the left sum, the middle element, and the right sum

Here is the code -

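import java.util.List;
import java.util.concurrent.RecursiveTask;

import lombok.extern.slf4j.Slf4j;

@Slf4j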
public class ListSummer extends RecursiveTask<Integer> {
  private final List<Integer> listToSum;

  ListSummer(List<Integer> listToSum) {
    this.listToSum = listToSum;
  }

  @Override
  protected Integer compute() {
    if (listToSum.isEmpty()) {
      log.info("Found empty list, sum is 0");
      return 0;
    }

    int middleIndex = listToSum.size() / 2;
    log.info("List {}, middle Index: {}", listToSum, middleIndex);

    List<Integer> leftSublist = listToSum.subList(0, middleIndex);
    List<Integer> rightSublist = listToSum.subList(middleIndex + 1, listToSum.size());

    ListSummer leftSummer = new ListSummer(leftSublist);
    ListSummer rightSummer = new ListSummer(rightSublist);

    leftSummer.fork();
    rightSummer.fork();

    Integer leftSum = leftSummer.join();
    Integer rightSum = rightSummer.join();
    int total = leftSum + listToSum.get(middleIndex) + rightSum;
    log.info("Left sum is {}, right sum is {}, total is {}", leftSum, rightSum, total);

    return total;
  }
}

Firstly, we extend the RecursiveTask subtype of ForkJoinTask. This is the type to extend when we expect our concurrent task to return a result. When a task does not return a result but only performs a side effect, we extend the RecursiveAction subtype. For most practical tasks, these two subtypes are sufficient.
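
For comparison, here is a minimal sketch of the RecursiveAction case, where the task only has a side effect and returns nothing. The ArrayDoubler class, its THRESHOLD value, and the doubleAll helper are hypothetical names made up for this illustration; they are not part of the example project.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical example: doubles every element of an array slice in place.
class ArrayDoubler extends RecursiveAction {
  // Assumed cut-off below which the slice is processed sequentially.
  private static final int THRESHOLD = 4;

  private final int[] values;
  private final int from; // inclusive
  private final int to;   // exclusive

  ArrayDoubler(int[] values, int from, int to) {
    this.values = values;
    this.from = from;
    this.to = to;
  }

  @Override
  protected void compute() {
    if (to - from <= THRESHOLD) {
      for (int i = from; i < to; i++) {
        values[i] *= 2;
      }
      return;
    }
    int middle = from + (to - from) / 2;
    // invokeAll forks the subtasks and returns once both have completed.
    invokeAll(new ArrayDoubler(values, from, middle), new ArrayDoubler(values, middle, to));
  }

  // Convenience helper: works on a copy so the caller's array stays untouched.
  static int[] doubleAll(int[] input) {
    int[] copy = Arrays.copyOf(input, input.length);
    ForkJoinPool pool = new ForkJoinPool();
    try {
      pool.invoke(new ArrayDoubler(copy, 0, copy.length));
    } finally {
      pool.shutdown();
    }
    return copy;
  }
}
```

Since there is no result to combine, compute() simply returns once both halves have been processed.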

Secondly, both RecursiveTask and RecursiveAction define an abstract compute method. This is where we put our computation.

Thirdly, inside our compute method, we check the size of the list that is passed through the constructor. If it is empty, we already know the result of the sum, which is zero, and we return immediately. Otherwise, we divide our list into two sublists and create two instances of our ListSummer type. We then call the fork() method (defined in ForkJoinTask) on these two instances -

leftSummer.fork();
rightSummer.fork();

This causes these tasks to be scheduled for asynchronous execution. The exact mechanism used for this purpose is explained later in this post.

After that, we invoke the join() method (also defined in ForkJoinTask) to wait for the result of these two parts -

Integer leftSum = leftSummer.join();
Integer rightSum = rightSummer.join();

These are then summed with the middle element of the list to get the final result.

Plenty of log messages have been added to make the example easier to understand. However, when we process a list containing thousands of entries, it might not be a good idea to have this detailed logging, especially logging the entire list.

That's pretty much it. Let's create a test class now for a test run -

public class ListSummerTest {

  @Test
  public void shouldSumEmptyList() {
    ListSummer summer = new ListSummer(List.of());
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isZero();
  }

  @Test
  public void shouldSumListWithOneElement() {
    ListSummer summer = new ListSummer(List.of(5));
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isEqualTo(5);
  }

  @Test
  public void shouldSumListWithMultipleElements() {
    ListSummer summer = new ListSummer(List.of(
        1, 2, 3, 4, 5, 6, 7, 8, 9
    ));
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isEqualTo(45);
  }
}

In the test, we create an instance of the ForkJoinPool. A ForkJoinPool is a special ExecutorService implementation for running ForkJoinTasks. It employs an algorithm known as work-stealing. In contrast to ExecutorService implementations such as ThreadPoolExecutor, where a single queue holds all the tasks to be executed, in a work-stealing implementation each worker thread gets its own work queue, and each thread starts by executing tasks from its own queue. When we detect that a ForkJoinTask can be broken down into smaller subtasks, we split it and invoke the fork() method on those subtasks. This invocation causes the subtasks to be pushed into the executing thread's queue. During execution, when one thread exhausts its queue and has no tasks left to execute, it can "steal" tasks from another thread's queue (hence the name "work-stealing"). For workloads made of many small, recursively generated tasks, this stealing behaviour typically results in better throughput than the other ExecutorService implementations.

Earlier, when we invoked fork() on our leftSummer and rightSummer task instances, they were pushed into the work queue of the executing thread, from where they could be "stolen" by other threads in the pool that had nothing else to do at that point (and so on for their own subtasks).
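
A common variation on this pattern is to fork only one half and compute the other half directly on the current thread, and to stop splitting below some size. The sketch below illustrates that idea. It is a different technique from the ListSummer above, and the ThresholdListSummer class, its THRESHOLD value, and the sum helper are hypothetical names chosen for this illustration.

```java
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of ListSummer, shown only to illustrate the pattern.
class ThresholdListSummer extends RecursiveTask<Long> {
  // Assumed cut-off; a sensible value has to be found by measuring.
  private static final int THRESHOLD = 1_000;

  private final List<Integer> values;

  ThresholdListSummer(List<Integer> values) {
    this.values = values;
  }

  @Override
  protected Long compute() {
    if (values.size() <= THRESHOLD) {
      // Small enough: sum sequentially instead of creating more tasks.
      long sum = 0;
      for (int value : values) {
        sum += value;
      }
      return sum;
    }
    int middle = values.size() / 2;
    ThresholdListSummer left = new ThresholdListSummer(values.subList(0, middle));
    ThresholdListSummer right = new ThresholdListSummer(values.subList(middle, values.size()));

    left.fork();                     // pushed to this worker's queue, may be stolen
    long rightSum = right.compute(); // computed directly on the current thread
    return left.join() + rightSum;
  }

  // invoke() submits the task and waits for its result in one call.
  static long sum(List<Integer> values) {
    ForkJoinPool pool = new ForkJoinPool();
    try {
      return pool.invoke(new ThresholdListSummer(values));
    } finally {
      pool.shutdown();
    }
  }
}
```

Computing one half in place keeps the current worker busy instead of leaving it waiting in join() for work it could have done itself, and the threshold avoids creating a task per element.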

Pretty cool, right?

Solving a blocking task

The problem we solved just now is non-blocking in nature. If we want to solve a problem that involves a blocking operation, we need to change our strategy to get better throughput.

Let's examine this with another example. Let's say we want to create a very simple web crawler. This crawler will receive a list of HTTP links, execute GET requests to fetch the response bodies, and then calculate the response length. Here is the code -

@Slf4j
public class ResponseLengthCalculator extends RecursiveTask<Map<String, Integer>> {
  private final List<String> links;

  ResponseLengthCalculator(List<String> links) {
    this.links = links;
  }

  @Override
  protected Map<String, Integer> compute() {
    if (links.isEmpty()) {
      log.info("No more links to fetch");
      return Collections.emptyMap();
    }

    int middle = links.size() / 2;
    log.info("Links {}, middle index: {}", links, middle);
    ResponseLengthCalculator leftPartition = new ResponseLengthCalculator(links.subList(0, middle));
    ResponseLengthCalculator rightPartition = new ResponseLengthCalculator(links.subList(middle + 1, links.size()));

    log.info("Forking left partition");
    leftPartition.fork();
    log.info("Left partition forked, now forking right partition");
    rightPartition.fork();
    log.info("Right partition forked");

    String middleLink = links.get(middle);
    HttpRequester httpRequester = new HttpRequester(middleLink);
    String response;
    try {
      log.info("Calling managedBlock for {}", middleLink);
      ForkJoinPool.managedBlock(httpRequester);
      response = httpRequester.response;
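    } catch (InterruptedException ex) {
      // Restore the interrupt flag so that callers can still observe the interruption
      Thread.currentThread().interrupt();
      log.error("Interrupted while fetching link {}", middleLink, ex);
      response = "";
    }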

    Map<String, Integer> responseMap = new HashMap<>(links.size());

    Map<String, Integer> leftLinks = leftPartition.join();
    responseMap.putAll(leftLinks);
    responseMap.put(middleLink, response.length());
    Map<String, Integer> rightLinks = rightPartition.join();
    responseMap.putAll(rightLinks);

    log.info("Left map {}, middle length {}, right map {}", leftLinks, response.length(), rightLinks);

    return responseMap;
  }

  private static class HttpRequester implements ForkJoinPool.ManagedBlocker {
    private final String link;
    private String response;

    private HttpRequester(String link) {
      this.link = link;
    }

    @Override
    public boolean block() {
      HttpGet getRequest = new HttpGet(link);
      CloseableHttpClient client = HttpClientBuilder
          .create()
          .disableRedirectHandling()
          .build();

      log.info("Executing blocking request for {}", link);

      try (client; CloseableHttpResponse response = client.execute(getRequest)) {
        log.info("HTTP request for link {} has been executed", link);
        this.response = EntityUtils.toString(response.getEntity());
      } catch (IOException e) {
        log.error("Error while trying to fetch response from link {}: {}", link, e.getMessage());
        this.response = "";
      }
      return true;
    }

    @Override
    public boolean isReleasable() {
      return false;
    }
  }
}

We create an implementation of ForkJoinPool.ManagedBlocker where we put the blocking HTTP call. This interface defines two methods: block() and isReleasable(). The block() method is where we put our blocking call. After we are done with our blocking operation, we return true to indicate that no further blocking is necessary. We return false from the isReleasable() implementation to tell the fork-join worker thread that blocking is still necessary, that is, that block() has to be invoked. The isReleasable() implementation is invoked by the fork-join worker thread before it invokes the block() method. Finally, we pass our HttpRequester instance to the static ForkJoinPool.managedBlock() method, which runs the blocking task on the calling thread. While the thread is blocked on the HTTP request, ForkJoinPool.managedBlock() may also arrange for a spare thread to be activated, if necessary, to ensure sufficient parallelism.
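
To look at the ManagedBlocker contract without the HTTP details, here is a minimal sketch that wraps an arbitrary blocking call. The SupplierBlocker class and its callBlocking helper are hypothetical names used only for this illustration. Unlike the HttpRequester above, this variant returns true from isReleasable() once the result is available, which is one possible way to satisfy the contract.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.function.Supplier;

// Hypothetical generic blocker: runs one blocking call and remembers its result.
class SupplierBlocker<T> implements ForkJoinPool.ManagedBlocker {
  private final Supplier<T> blockingCall;
  private T result;
  private boolean done;

  SupplierBlocker(Supplier<T> blockingCall) {
    this.blockingCall = blockingCall;
  }

  @Override
  public boolean block() {
    if (!done) {
      result = blockingCall.get(); // the potentially blocking operation
      done = true;
    }
    return true; // no further blocking is necessary
  }

  @Override
  public boolean isReleasable() {
    return done; // true once the result is available, so block() can be skipped
  }

  T result() {
    return result;
  }

  // Runs the supplier through managedBlock on the calling thread and returns its result.
  static <T> T callBlocking(Supplier<T> blockingCall) {
    SupplierBlocker<T> blocker = new SupplierBlocker<>(blockingCall);
    try {
      ForkJoinPool.managedBlock(blocker);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Interrupted while blocking", e);
    }
    return blocker.result();
  }
}
```

Inside a compute() method this could be used as `String body = SupplierBlocker.callBlocking(() -> fetch(link));`, where fetch stands for whatever blocking call is needed.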

Let's take this implementation for a test drive then! Here's the code -

public class ResponseLengthCalculatorTest {

  @Test
  public void shouldReturnEmptyMapForEmptyList() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(Collections.emptyList());
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result).isEmpty();
  }

  @Test
  public void shouldHandle200Ok() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(List.of(
        "http://httpstat.us/200"
    ));
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result)
        .hasSize(1)
        .containsKeys("http://httpstat.us/200")
        .containsValue(0);
  }

  @Test
  public void shouldFetchResponseForDifferentResponseStatus() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(List.of(
        "http://httpstat.us/200",
        "http://httpstat.us/302",
        "http://httpstat.us/404",
        "http://httpstat.us/502"
    ));
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result)
        .hasSize(4);
  }
}

That's it for today, folks! As always, any feedback/improvement suggestions/comments are highly appreciated!

All the examples discussed here can be found on GitHub.

A big shout out to the awesome http://httpstat.us service, it was quite helpful for developing the simple tests.

Tuesday, May 1, 2018

Java Tips: Creating a Monitoring-friendly ExecutorService

In this article we will extend an ExecutorService implementation with monitoring capabilities. These capabilities will help us measure a number of pool parameters, e.g., active threads and work queue size, in a live production environment. They will also enable us to measure task execution time, successful task count, and failed task count.

Monitoring Library

As for the monitoring library, we will be using Metrics. For the sake of simplicity we will use a ConsoleReporter, which reports our metrics to the console. For production-grade applications, we should use a more advanced reporter (e.g., the Graphite reporter). If you are unfamiliar with Metrics, I recommend going through the getting started guide.

Let's get started.

Extending the ThreadPoolExecutor

We will be using ThreadPoolExecutor as the base class for our new type. Let's call it MonitoredThreadPoolExecutor. This class will accept a MetricRegistry as one of its constructor parameters -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      RejectedExecutionHandler handler,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      RejectedExecutionHandler handler,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
    this.metricRegistry = metricRegistry;
  }
}

Registering Gauges to measure pool-specific parameters

A Gauge is an instantaneous measurement of a value. We will be using it to measure different pool parameters, such as the number of active threads and the task queue size.

Before we can register a Gauge, we need to decide how to calculate a metric name for our thread pool. Each metric, whether it's a Gauge, or a Timer, or simply a Meter, has a unique name. This name is used to identify the metric source. The convention here is to use a dotted string which is often constructed from the fully qualified name of the class being monitored.

For our thread pool, we will be using its fully qualified name as a prefix to our metrics names. Additionally we will add another constructor parameter called poolName, which will be used by the clients to specify instance-specific identifiers.

After implementing these changes the class looks like below -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;
  private final String metricsPrefix;

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      MetricRegistry metricRegistry,
      String poolName
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    this.metricRegistry = metricRegistry;
    this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
  }

  // Rest of the constructors
}

Now we are ready to register our Gauges. For this purpose we will define a private method -

private void registerGauges() {
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "corePoolSize"), (Gauge<Integer>) this::getCorePoolSize);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "activeThreads"), (Gauge<Integer>) this::getActiveCount);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "maxPoolSize"), (Gauge<Integer>) this::getMaximumPoolSize);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "queueSize"), (Gauge<Integer>) () -> getQueue().size());
}

For our example we are measuring core pool size, number of active threads, maximum pool size, and task queue size. Depending on monitoring requirements, we can register more or fewer Gauges to measure different properties.

This private method will now be invoked from all constructors -

public MonitoredThreadPoolExecutor(
    int corePoolSize,
    int maximumPoolSize,
    long keepAliveTime,
    TimeUnit unit,
    BlockingQueue<Runnable> workQueue,
    MetricRegistry metricRegistry,
    String poolName
) {
  super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
  this.metricRegistry = metricRegistry;
  this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
  registerGauges();
}

Measuring Task Execution Time

To measure the task execution time, we will override two life-cycle methods that ThreadPoolExecutor provides - beforeExecute and afterExecute.

As the name implies, the beforeExecute callback is invoked prior to executing a task, by the thread that will execute the task. The default implementation of this callback does nothing.

Similarly, the afterExecute callback is invoked after each task is executed, by the thread that executed the task. The default implementation of this callback also does nothing. Even if the task throws an uncaught RuntimeException or Error, this callback will be invoked.

We will start a Timer in our beforeExecute override, which will then be used in our afterExecute override to get the total task execution time. To store a reference to the running Timer.Context we will introduce a new ThreadLocal field in our class.

The implementations of the callbacks are given below -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;
  private final String metricsPrefix;
  private ThreadLocal<Timer.Context> taskExecutionTimer = new ThreadLocal<>();

  // Constructors

  @Override
  protected void beforeExecute(Thread thread, Runnable task) {
    super.beforeExecute(thread, task);
    Timer timer = metricRegistry.timer(MetricRegistry.name(metricsPrefix, "task-execution"));
    taskExecutionTimer.set(timer.time());
  }

  @Override
  protected void afterExecute(Runnable task, Throwable throwable) {
    Timer.Context context = taskExecutionTimer.get();
    context.stop();
    super.afterExecute(task, throwable);
  }
}
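
The same before/after pattern can be tried out without the Metrics library. The sketch below is a stdlib-only stand-in that uses System.nanoTime() and LongAdder instead of a Timer; the TimingThreadPoolExecutor class and its helper methods are hypothetical names for this illustration. It also follows the ordering suggested in the ThreadPoolExecutor JavaDoc, which is to call super.beforeExecute at the end of the override and super.afterExecute at the beginning.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.LongAdder;

// Hypothetical stdlib-only executor that accumulates task run time.
class TimingThreadPoolExecutor extends ThreadPoolExecutor {
  // Start time of the task currently running on each worker thread.
  private final ThreadLocal<Long> startNanos = new ThreadLocal<>();
  private final LongAdder totalNanos = new LongAdder();
  private final LongAdder completedTasks = new LongAdder();

  TimingThreadPoolExecutor(int poolSize) {
    super(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
  }

  @Override
  protected void beforeExecute(Thread thread, Runnable task) {
    startNanos.set(System.nanoTime());
    super.beforeExecute(thread, task);
  }

  @Override
  protected void afterExecute(Runnable task, Throwable throwable) {
    super.afterExecute(task, throwable);
    Long start = startNanos.get();
    if (start != null) {
      totalNanos.add(System.nanoTime() - start);
      completedTasks.increment();
      startNanos.remove(); // do not leak the value to the next task on this thread
    }
  }

  long totalNanos() {
    return totalNanos.sum();
  }

  long completedTasks() {
    return completedTasks.sum();
  }

  // Stops accepting tasks and waits briefly for the queued ones to finish.
  boolean shutdownAndWait() {
    shutdown();
    try {
      return awaitTermination(5, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }
}
```

The ThreadLocal works here because beforeExecute and afterExecute for a given task are both invoked by the worker thread that runs that task.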

Recording number of failed tasks due to uncaught exceptions

The second parameter to the afterExecute callback is a Throwable. If non-null, this Throwable refers to the uncaught RuntimeException or Error that caused the execution to terminate. We can use this information to partially count the total number of tasks that were terminated abruptly due to uncaught exceptions.

To get the total number of failed tasks, we must consider another case. For tasks submitted using the execute method, any uncaught exception propagates, and it is available as the second argument to the afterExecute callback. However, exceptions thrown by tasks submitted using the submit method are captured by the executor service. This is explained in the JavaDoc -

"Note: When actions are enclosed in tasks (such as FutureTask) either explicitly or via methods such as submit, these task objects catch and maintain computational exceptions, and so they do not cause abrupt termination, and the internal exceptions are not passed to this method. If you would like to trap both kinds of failures in this method, you can further probe for such cases, as in this sample subclass that prints either the direct cause or the underlying exception if a task has been aborted"

Fortunately, the same doc also offers a solution for this, which is to examine the runnable to see if it's a Future, and then get the underlying exception.
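
The difference between execute and submit can be observed with a small stdlib-only experiment. The sketch below records what afterExecute sees for a failing task submitted both ways; the ObservingExecutor class and its demo method are hypothetical names used only for this illustration.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical executor that records what afterExecute observes for each task.
class ObservingExecutor extends ThreadPoolExecutor {
  final List<String> observations = new CopyOnWriteArrayList<>();

  ObservingExecutor() {
    // Single worker so tasks run in submission order; the handler only keeps
    // the demo quiet when the execute()d task fails.
    super(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), runnable -> {
      Thread thread = new Thread(runnable);
      thread.setUncaughtExceptionHandler((t, e) -> { });
      return thread;
    });
  }

  @Override
  protected void afterExecute(Runnable runnable, Throwable throwable) {
    super.afterExecute(runnable, throwable);
    String kind = (runnable instanceof Future<?>) ? "future" : "plain";
    observations.add(kind + ":" + (throwable == null ? "no throwable" : "throwable"));
  }

  static List<String> demo() {
    ObservingExecutor executor = new ObservingExecutor();
    Runnable failing = () -> {
      throw new IllegalStateException("boom");
    };
    executor.submit(failing);  // wrapped in a FutureTask, the exception is captured
    executor.execute(failing); // the exception reaches afterExecute directly
    executor.shutdown();
    try {
      executor.awaitTermination(5, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return executor.observations;
  }
}
```

For the submitted task, afterExecute receives a Future and a null throwable, which is why the extra probing through Future.get() is needed to detect the failure.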

Combining these approaches, we can modify our afterExecute method as follows -

@Override
protected void afterExecute(Runnable runnable, Throwable throwable) {
  Timer.Context context = taskExecutionTimer.get();
  context.stop();

  super.afterExecute(runnable, throwable);
  if (throwable == null && runnable instanceof Future<?> && ((Future<?>) runnable).isDone()) {
    try {
      ((Future<?>) runnable).get();
    } catch (CancellationException ce) {
      throwable = ce;
    } catch (ExecutionException ee) {
      throwable = ee.getCause();
    } catch (InterruptedException ie) {
      Thread.currentThread().interrupt();
    }
  }
  if (throwable != null) {
    Counter failedTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "failed-tasks"));
    failedTasksCounter.inc();
  }
}


Counting total number of successful tasks

The previous approach can also be used to count the total number of successful tasks: tasks that were completed without throwing any exceptions or errors -

@Override
protected void afterExecute(Runnable runnable, Throwable throwable) {
  // Rest of the method body .....

  if (throwable != null) {
    Counter failedTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "failed-tasks"));
    failedTasksCounter.inc();
  } else {
    Counter successfulTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "successful-tasks"));
    successfulTasksCounter.inc();
  }
}


Conclusion

In this article we have looked at a few monitoring-friendly customizations of an ExecutorService implementation. As always, any suggestions, improvements, or bug fixes will be highly appreciated. The example source code has been uploaded to GitHub.

Sunday, April 22, 2018

JPA Tips: Avoiding the N + 1 select problem

Introduction

ORM frameworks like JPA simplify our development process by helping us avoid lots of boilerplate code during the object <-> relational data mapping. However, they also bring some additional problems to the table, and the N + 1 select problem is one of them. In this article we will take a short look at the problem along with some ways to avoid it.

The Problem

As an example I will use a simplified version of an online book ordering application. In such an application I might create an entity like the one below to represent a Purchase Order -

@Entity
public class PurchaseOrder {

    @Id
    private String id;
    private String customerId;

    @OneToMany(cascade = ALL, fetch = EAGER)
    @JoinColumn(name = "purchase_order_id")
    private List<PurchaseOrderItem> purchaseOrderItems = new ArrayList<>();
}


A purchase order consists of an order id, a customer id, and one or more items that are being bought. The PurchaseOrderItem entity might have the following structure -

@Entity
public class PurchaseOrderItem {

    @Id
    private String id;

    private String bookId;
}


These entities have been simplified a lot, but for the purpose of this article, this will do.

Now suppose that we need to find the orders of a customer to display them in their purchase order history. The following query will serve this purpose -

SELECT
    P
FROM
    PurchaseOrder P
WHERE
    P.customerId = :customerId

which when translated to SQL looks something like below -

select
    purchaseor0_.id as id1_1_,
    purchaseor0_.customer_id as customer2_1_ 
from
    purchase_order purchaseor0_ 
where
    purchaseor0_.customer_id = ?

This one query will return all purchase orders that a customer has. However, in order to fetch the order items, JPA will issue a separate query for each individual order. If, for example, a customer has 5 orders, then JPA will issue 5 additional queries to fetch the order items included in those orders. This is known as the N + 1 problem: 1 query to fetch all N purchase orders, and N queries to fetch the order items of each order.
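
The arithmetic can be made concrete with a small in-memory stand-in that only counts queries. This is not JPA code; the CountingOrderStore class and its methods are hypothetical and merely mimic the access pattern described above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// In-memory stand-in for a database that only counts the queries it receives.
class CountingOrderStore {
  private final int orderCount;
  int queriesIssued = 0;

  CountingOrderStore(int orderCount) {
    this.orderCount = orderCount;
  }

  // Stands in for: select ... from purchase_order where customer_id = ?
  List<String> findOrderIds() {
    queriesIssued++;
    List<String> ids = new ArrayList<>();
    for (int i = 1; i <= orderCount; i++) {
      ids.add("order-" + i);
    }
    return ids;
  }

  // Stands in for the separate select that loads the items of one order.
  List<String> findItemIds(String orderId) {
    queriesIssued++;
    return Collections.singletonList(orderId + "-item-1");
  }

  // Mimics the N + 1 access pattern: one query for the orders, one per order for its items.
  static int queriesForCustomerWith(int orderCount) {
    CountingOrderStore store = new CountingOrderStore(orderCount);
    for (String orderId : store.findOrderIds()) {
      store.findItemIds(orderId);
    }
    return store.queriesIssued;
  }
}
```

With 5 orders the counter ends at 6, and it grows linearly with the number of orders, which is exactly what makes the pattern expensive.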

This behavior creates a scalability problem for us when our data grows. Even a moderate number of orders and items can create significant performance issues.

The Solution


Avoiding Eager Fetching

This is the main reason behind the issue. We should get rid of all the eager fetching from our mappings. It has almost no benefits that justify its use in a production-grade application. We should mark all relationships as Lazy instead.

One important point to note: marking a relationship mapping as Lazy does not guarantee that the underlying persistence provider will also treat it as such. The JPA specification does not guarantee lazy fetching; it is a hint to the persistence provider at best. However, considering Hibernate, I have never seen it do otherwise.

Only fetching the data that are actually needed

This is always recommended regardless of the decision to go for eager/lazy fetching.

I remember one N + 1 optimization that I did which improved the maximum response time of a REST endpoint from 17 minutes to 1.5 seconds. The endpoint was fetching a single entity based on some criteria, which for our current example would be something along the lines of -

TypedQuery<PurchaseOrder> jpaQuery = entityManager.createQuery("SELECT P FROM PurchaseOrder P WHERE P.customerId = :customerId", PurchaseOrder.class);
jpaQuery.setParameter("customerId", "Sayem");
PurchaseOrder purchaseOrder = jpaQuery.getSingleResult();

// after some calculation
anotherRepository.findSomeStuff(purchaseOrder.getId());

The id is the only data from the result that was needed for subsequent calculations.

There were a few customers who had more than a thousand orders. Each one of the orders in turn had a few thousand additional children of a few different types. Needless to say, as a result, thousands of queries were being executed in the database whenever requests for those orders were received at this endpoint.

To improve the performance, all I did was -

TypedQuery<String> jpaQuery = entityManager.createQuery("SELECT P.id FROM PurchaseOrder P WHERE P.customerId = :customerId", String.class);
jpaQuery.setParameter("customerId", "Sayem");
String orderId = jpaQuery.getSingleResult();

// after some calculation
anotherRepository.findSomeStuff(orderId);

Just this change resulted in a 680x improvement.

If we want to fetch more than one property, then we can make use of the Constructor Expression that JPA provides -

TypedQuery<PurchaseOrderDTO> jpaQuery = entityManager.createQuery(
        "SELECT " +
                "NEW com.codesod.example.jpa.nplusone.dto.PurchaseOrderDTO(P.id, P.orderDate) " +
        "FROM " +
                "PurchaseOrder P " +
        "WHERE " +
                "P.customerId = :customerId",
        PurchaseOrderDTO.class);
jpaQuery.setParameter("customerId", "Sayem");
List<PurchaseOrderDTO> orders = jpaQuery.getResultList();

A few caveats of using the constructor expression -
  1. The target DTO must have a constructor whose parameter list matches the columns being selected
  2. The fully qualified name of the DTO class must be specified
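
As a sketch of what such a DTO might look like for the query above: the field types are assumptions, in particular LocalDate for orderDate, since the entity's orderDate attribute is not shown in this post. In a real project the class would be declared public in its own file under the package named in the query.

```java
import java.time.LocalDate;

// Hypothetical DTO matching "NEW ...PurchaseOrderDTO(P.id, P.orderDate)".
// The LocalDate type of orderDate is an assumption made for this sketch.
final class PurchaseOrderDTO {
  private final String id;
  private final LocalDate orderDate;

  // Parameter order and types have to line up with the selected expressions.
  public PurchaseOrderDTO(String id, LocalDate orderDate) {
    this.id = id;
    this.orderDate = orderDate;
  }

  public String getId() {
    return id;
  }

  public LocalDate getOrderDate() {
    return orderDate;
  }
}
```

Because the DTO is a plain object and not a managed entity, loading it does not trigger any relationship fetching.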

Using Join Fetch / Entity Graphs

We can use JOIN FETCH in our queries whenever we need to fetch an entity with all of its children at the same time. This results in much less database traffic and therefore improved performance.
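
For the PurchaseOrder entity from this post, such queries might look like the JPQL strings below. The PurchaseOrderQueries holder class and its constant names are hypothetical; the strings would be passed to entityManager.createQuery in the same way as the earlier examples.

```java
// Hypothetical holder for JPQL strings that load orders together with their items.
final class PurchaseOrderQueries {

  // Inner fetch join: orders without any items are not returned.
  // DISTINCT removes the duplicate parent rows that the join produces.
  static final String ORDERS_WITH_ITEMS =
      "SELECT DISTINCT P FROM PurchaseOrder P "
          + "JOIN FETCH P.purchaseOrderItems "
          + "WHERE P.customerId = :customerId";

  // Left fetch join: orders without items are returned as well.
  static final String ORDERS_WITH_OPTIONAL_ITEMS =
      "SELECT DISTINCT P FROM PurchaseOrder P "
          + "LEFT JOIN FETCH P.purchaseOrderItems "
          + "WHERE P.customerId = :customerId";

  private PurchaseOrderQueries() {
  }
}
```

Either query loads the orders and their items in a single select instead of 1 + N.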

The JPA 2.1 specification introduced Entity Graphs, which allow us to create static/dynamic query load plans. Thorben Janssen has written a couple of posts (here and here) detailing their usage, which are worth checking out.

Some example code for this post can be found on GitHub.