Monday, January 14, 2019

Different States of Java Threads

Introduction

In Java, a thread is in exactly one state at any given time. The Thread.State enum defines the states that a Java thread can be in. It has the following values -
  1. NEW
  2. RUNNABLE
  3. BLOCKED
  4. WAITING
  5. TIMED_WAITING
  6. TERMINATED
In the subsequent sections, I provide a brief overview of these states along with possible transitions between them.

States of a Java Thread

NEW

This is the state of a thread that has been created but not yet started.

RUNNABLE

As soon as a thread starts executing, it moves to the RUNNABLE state. Note that a thread that is waiting to acquire a CPU for execution is still in this state.

BLOCKED

A thread moves to the BLOCKED state as soon as it gets blocked waiting for a monitor lock. This can happen in one of the following two ways -
  1. It's waiting to acquire a lock to enter a synchronised block/method.
  2. It's waiting to reacquire the monitor lock of an object on which it invoked the Object.wait method.
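
The first case can be observed with a small experiment. The following is a minimal sketch rather than a definitive test: the class and method names are made up for illustration, and it polls with a timeout because another thread's state can only be sampled, not captured at an exact instant.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; the names are illustrative only.
final class BlockedStateDemo {

  // Polls until the thread reaches the expected state or the timeout elapses.
  static boolean awaitState(Thread thread, Thread.State expected, long timeoutMillis) {
    long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
    try {
      while (thread.getState() != expected) {
        if (System.nanoTime() - deadline > 0) {
          return false;
        }
        Thread.sleep(5);
      }
      return true;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }

  // Returns true if the contender thread was observed in the BLOCKED state.
  static boolean observeBlocked() {
    Object lock = new Object();
    Thread contender = new Thread(() -> {
      synchronized (lock) {
        // Nothing to do; only the attempt to enter the block matters.
      }
    });

    boolean blocked;
    synchronized (lock) {
      // The calling thread holds the monitor, so the contender cannot enter.
      contender.start();
      blocked = awaitState(contender, Thread.State.BLOCKED, 2_000);
    }

    try {
      contender.join(); // the monitor is free now, so the contender can finish
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return blocked;
  }

  public static void main(String[] args) {
    System.out.println("Observed BLOCKED: " + observeBlocked());
  }
}
```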

WAITING

A thread moves to this state as a result of invoking one of the following methods -
  1. Object.wait without a timeout
  2. Thread.join without a timeout
  3. LockSupport.park
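
As an illustration of the first method in the list, the sketch below parks a thread in Object.wait() without a timeout and samples its state from outside. The class name, the released flag, and the polling helper are assumptions made for this example only.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; the names are illustrative only.
final class WaitingStateDemo {

  // Polls until the thread reaches the expected state or the timeout elapses.
  static boolean awaitState(Thread thread, Thread.State expected, long timeoutMillis) {
    long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
    try {
      while (thread.getState() != expected) {
        if (System.nanoTime() - deadline > 0) {
          return false;
        }
        Thread.sleep(5);
      }
      return true;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }

  // Returns true if the waiter thread was observed in the WAITING state.
  static boolean observeWaiting() {
    Object monitor = new Object();
    boolean[] released = {false}; // guarded by monitor

    Thread waiter = new Thread(() -> {
      synchronized (monitor) {
        while (!released[0]) { // the loop guards against spurious wakeups
          try {
            monitor.wait(); // no timeout, so the state is WAITING
          } catch (InterruptedException e) {
            return;
          }
        }
      }
    });
    waiter.start();

    boolean waiting = awaitState(waiter, Thread.State.WAITING, 2_000);

    synchronized (monitor) {
      released[0] = true;
      monitor.notifyAll(); // let the waiter finish
    }
    try {
      waiter.join();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return waiting;
  }

  public static void main(String[] args) {
    System.out.println("Observed WAITING: " + observeWaiting());
  }
}
```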

TIMED_WAITING

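A thread moves to this state as a result of invoking one of the following methods with a positive waiting time -
  1. Thread.sleep
  2. Object.wait with a timeout
  3. Thread.join with a timeout
  4. LockSupport.parkNanos
  5. LockSupport.parkUntil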

TERMINATED

As soon as a thread terminates, it moves to this state.

Possible state transitions

The following diagram shows the possible transitions between different states -

Java Thread Transitions 

As soon as a thread gets scheduled for execution, it moves to the RUNNABLE state. This transition has been shown with the first arrow (marked as 1).

From the RUNNABLE state, a thread can move to any of the BLOCKED, WAITING, TIMED_WAITING, or TERMINATED states. Theoretically speaking, if a thread does not wait to acquire any lock, does not sleep, and does not invoke any of the methods which make it wait, it just finishes its execution and goes directly to the TERMINATED state (marked as 2d).

Of course in a practical application, the above scenario is highly unlikely. Often a thread tries to acquire a lock, in which case it moves to the BLOCKED (marked as 2a) state if it has to wait for the lock. Threads also explicitly wait for some preconditions to be true/actions from other threads, in which case they move to the WAITING (marked as 2b) or the TIMED_WAITING (marked as 2c) state, depending on whether the waits were timed or not.

Once a thread moves to the BLOCKED state, the only possible transition that is allowed next is to move to the RUNNABLE state (marked as 3d).

Similarly, the only possible transition from the WAITING state is to move to the BLOCKED state (marked as 3c).

Please note that some articles on the internet incorrectly add a transition from the WAITING to the RUNNABLE state. This is not correct. A thread can never move directly from the WAITING state to the RUNNABLE state. We can understand the reason for this with an example.

Suppose that we have a thread T which is currently in the RUNNABLE state and holds the monitor lock of three objects a, b, and c, as shown in the diagram below -

Before invoking c.wait()

At this point, T invokes c.wait(), after which it no longer holds the monitor lock of object c -

After invoking c.wait()

As soon as T is notified using an invocation of notify/notifyAll, it stops waiting and competes with other threads (let's say, X and Y) to acquire the monitor lock of c -

After T has been notified with notify/notifyAll

which, according to the definitions above, is the BLOCKED state. Only after acquiring the monitor lock of c, T moves to the RUNNABLE state. Similar reasoning can be applied for the Thread.join() (which internally uses Object.wait()) and LockSupport.park().

Let's get back to our original state transition diagram. As we can see, a thread can move to either the RUNNABLE (marked as 3b) or the BLOCKED (marked as 3a) state from the TIMED_WAITING state. The transition to RUNNABLE is possible in this case because a thread can enter the TIMED_WAITING state after invoking the Thread.sleep method, in which case it retains all the monitor locks it currently holds.
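
The fact that a sleeping thread keeps its monitor locks can be checked with a rough experiment: one thread sleeps inside a synchronized block while a second thread tries to enter a block on the same lock. This is only a sketch under stated assumptions; the class name, the helper, and the timeouts are invented for illustration.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical demo class; the names are illustrative only.
final class SleepKeepsLockDemo {

  // Polls until the thread reaches the expected state or the timeout elapses.
  static boolean awaitState(Thread thread, Thread.State expected, long timeoutMillis) {
    long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
    try {
      while (thread.getState() != expected) {
        if (System.nanoTime() - deadline > 0) {
          return false;
        }
        Thread.sleep(5);
      }
      return true;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }

  // Returns true if the sleeper was TIMED_WAITING while a contender for the
  // same monitor was BLOCKED, i.e. the sleeper kept the lock while asleep.
  static boolean sleeperKeepsLock() {
    Object lock = new Object();

    Thread sleeper = new Thread(() -> {
      synchronized (lock) {
        try {
          Thread.sleep(60_000); // interrupted below, long before this elapses
        } catch (InterruptedException e) {
          // Expected: the demo interrupts the sleeper once it is done observing.
        }
      }
    });
    Thread contender = new Thread(() -> {
      synchronized (lock) {
        // Entering is only possible after the sleeper leaves its block.
      }
    });

    sleeper.start();
    boolean sleeping = awaitState(sleeper, Thread.State.TIMED_WAITING, 2_000);
    contender.start();
    boolean blocked = awaitState(contender, Thread.State.BLOCKED, 2_000);

    sleeper.interrupt(); // wake the sleeper so both threads can finish
    try {
      sleeper.join();
      contender.join();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return sleeping && blocked;
  }

  public static void main(String[] args) {
    System.out.println("Sleeper kept its lock: " + sleeperKeepsLock());
  }
}
```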

When a thread finishes execution, after moving back and forth between the RUNNABLE, BLOCKED, WAITING, and TIMED_WAITING states, it moves to the TERMINATED state once and for all.

How do we get the current state of a Thread?

We can use the Thread.getState() method to retrieve the current state of a thread. We can use this value to monitor or debug any concurrency issues that our application might face in production.
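
A minimal sketch of Thread.getState() in use is shown below. It records the state before start(), from inside the running thread, and after join(). The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class; the names are illustrative only.
final class ThreadLifecycleDemo {

  // Records the state before start(), while running, and after termination.
  static List<Thread.State> observeLifecycle() {
    List<Thread.State> observed = new ArrayList<>();
    Thread.State[] whileRunning = new Thread.State[1];

    Thread worker = new Thread(() -> {
      // A thread that inspects its own state is necessarily executing.
      whileRunning[0] = Thread.currentThread().getState();
    });

    observed.add(worker.getState()); // created, not yet started
    worker.start();
    try {
      worker.join(); // join() also makes whileRunning[0] visible to this thread
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    observed.add(whileRunning[0]);
    observed.add(worker.getState()); // finished
    return observed;
  }

  public static void main(String[] args) {
    System.out.println(observeLifecycle()); // prints [NEW, RUNNABLE, TERMINATED]
  }
}
```

For diagnosing a live application, a thread dump (for example via jstack) reports the same Thread.State values for every thread.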

Conclusion

In this article we briefly reviewed the different states a Java thread can have, and how a thread moves between these states. As always, any feedback/improvement suggestions/comments are highly appreciated!

Sunday, December 16, 2018

A brief overview of the Fork/Join Framework in Java

Introduction

The Fork/Join framework solves problems using a concurrent divide-and-conquer approach. It was introduced to complement the existing concurrency API. Before its introduction, the existing ExecutorService implementations were the popular choice for running asynchronous tasks, but they work best when the tasks are homogeneous and independent. Running dependent tasks and combining their results using those implementations was not easy. The Fork/Join framework is an attempt to address this shortcoming. In this post, we will take a brief look at the API and solve a couple of simple problems to understand how it works.

Solving a non-blocking task

Let's jump directly into code. Let's create a task which would return the sum of all elements of a List. The following steps represent our algorithm in pseudo-code:

01. Find the middle index of the list
02. Divide the list in the middle
03. Recursively create a new task which will compute the sum of the left part
04. Recursively create a new task which will compute the sum of the right part
05. Add the result of the left sum, the middle element, and the right sum

Here is the code -

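import java.util.List;
import java.util.concurrent.RecursiveTask;
import lombok.extern.slf4j.Slf4j;

@Slf4j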
public class ListSummer extends RecursiveTask<Integer> {
  private final List<Integer> listToSum;

  ListSummer(List<Integer> listToSum) {
    this.listToSum = listToSum;
  }

  @Override
  protected Integer compute() {
    if (listToSum.isEmpty()) {
      log.info("Found empty list, sum is 0");
      return 0;
    }

    int middleIndex = listToSum.size() / 2;
    log.info("List {}, middle Index: {}", listToSum, middleIndex);

    List<Integer> leftSublist = listToSum.subList(0, middleIndex);
    List<Integer> rightSublist = listToSum.subList(middleIndex + 1, listToSum.size());

    ListSummer leftSummer = new ListSummer(leftSublist);
    ListSummer rightSummer = new ListSummer(rightSublist);

    leftSummer.fork();
    rightSummer.fork();

    Integer leftSum = leftSummer.join();
    Integer rightSum = rightSummer.join();
    int total = leftSum + listToSum.get(middleIndex) + rightSum;
    log.info("Left sum is {}, right sum is {}, total is {}", leftSum, rightSum, total);

    return total;
  }
}

Firstly, we extend the RecursiveTask subtype of ForkJoinTask. This is the type to extend when we expect our concurrent task to return a result. When a task does not return a result but only performs an effect, we extend the RecursiveAction subtype instead. For most of the practical tasks that we solve, these two subtypes are sufficient.
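
For contrast, here is a minimal sketch of a RecursiveAction, which returns no result and only performs an effect. The ArrayDoubler class, its threshold, and the doubleAll helper are hypothetical and exist only for this illustration.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical example: doubles every element of an array in place.
final class ArrayDoubler extends RecursiveAction {
  private static final int THRESHOLD = 4; // arbitrary cut-off for this sketch

  private final int[] values;
  private final int from; // inclusive
  private final int to;   // exclusive

  ArrayDoubler(int[] values, int from, int to) {
    this.values = values;
    this.from = from;
    this.to = to;
  }

  @Override
  protected void compute() {
    if (to - from <= THRESHOLD) {
      // Small enough: do the work directly instead of splitting further.
      for (int i = from; i < to; i++) {
        values[i] *= 2;
      }
      return;
    }
    int middle = (from + to) >>> 1;
    // invokeAll forks the subtasks and waits for both of them to complete.
    invokeAll(
        new ArrayDoubler(values, from, middle),
        new ArrayDoubler(values, middle, to));
  }

  // Convenience helper: works on a copy so the caller's array is untouched.
  static int[] doubleAll(int[] input) {
    int[] copy = input.clone();
    ForkJoinPool.commonPool().invoke(new ArrayDoubler(copy, 0, copy.length));
    return copy;
  }
}
```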

Secondly, both RecursiveTask and RecursiveAction define an abstract compute method. This is where we put our computation.

Thirdly, inside our compute method, we check the size of the list that is passed through the constructor. If it is empty, we already know the result of the sum which is zero, and we return immediately. Otherwise, we divide our lists into two sublists and create two instances of our ListSummer type. We then call the fork() method (defined in ForkJoinTask) on these two instances -

leftSummer.fork();
rightSummer.fork();

This causes the tasks to be scheduled for asynchronous execution. The exact mechanism used for this purpose is explained later in this post.

After that, we invoke the join() method (also defined in ForkJoinTask) to wait for the result of these two parts -

Integer leftSum = leftSummer.join();
Integer rightSum = rightSummer.join();

The two results are then summed with the middle element of the list to get the final result.
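
A common variation on this fork/join pattern forks only one half and computes the other half in the current thread, and stops splitting once the input is small. The sketch below illustrates that idiom. It is not the version used in this post; ThresholdListSummer and its threshold value are assumptions made for the example.

```java
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of ListSummer, shown only to illustrate the idiom.
final class ThresholdListSummer extends RecursiveTask<Integer> {
  private static final int THRESHOLD = 3; // arbitrary cut-off for this sketch

  private final List<Integer> listToSum;

  ThresholdListSummer(List<Integer> listToSum) {
    this.listToSum = listToSum;
  }

  @Override
  protected Integer compute() {
    if (listToSum.size() <= THRESHOLD) {
      // Small lists are summed sequentially; forking would cost more than it saves.
      int sum = 0;
      for (int value : listToSum) {
        sum += value;
      }
      return sum;
    }

    int middle = listToSum.size() / 2;
    ThresholdListSummer left = new ThresholdListSummer(listToSum.subList(0, middle));
    ThresholdListSummer right =
        new ThresholdListSummer(listToSum.subList(middle, listToSum.size()));

    left.fork();                    // schedule the left half asynchronously
    int rightSum = right.compute(); // compute the right half in this thread
    return left.join() + rightSum;  // wait for the left half and combine
  }

  static int sum(List<Integer> values) {
    return ForkJoinPool.commonPool().invoke(new ThresholdListSummer(values));
  }
}
```

Computing one half directly keeps the current worker busy instead of leaving it idle in join() while both halves sit in the queue.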

Plenty of log messages have been added to make the example easier to understand. However, when we process a list containing thousands of entries, it might not be a good idea to have this detailed logging, especially logging the entire list.

That's pretty much it. Let's create a test class now for a test run -

public class ListSummerTest {

  @Test
  public void shouldSumEmptyList() {
    ListSummer summer = new ListSummer(List.of());
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isZero();
  }

  @Test
  public void shouldSumListWithOneElement() {
    ListSummer summer = new ListSummer(List.of(5));
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isEqualTo(5);
  }

  @Test
  public void shouldSumListWithMultipleElements() {
    ListSummer summer = new ListSummer(List.of(
        1, 2, 3, 4, 5, 6, 7, 8, 9
    ));
    ForkJoinPool forkJoinPool = new ForkJoinPool();
    forkJoinPool.submit(summer);

    int result = summer.join();

    assertThat(result).isEqualTo(45);
  }
}

In the test, we create an instance of ForkJoinPool. A ForkJoinPool is a specialised ExecutorService implementation for running ForkJoinTasks. It employs a scheduling technique known as work-stealing. In contrast to ExecutorService implementations such as ThreadPoolExecutor, where a single queue holds all the tasks to be executed, in a work-stealing implementation each worker thread gets its own work queue, and each thread starts by executing tasks from its own queue. When we detect that a ForkJoinTask can be broken down into smaller subtasks, we split it and invoke the fork() method on those subtasks. This invocation pushes the subtasks onto the executing thread's queue. During execution, when a thread exhausts its queue and has no tasks left to execute, it can "steal" tasks from another thread's queue (hence the name "work-stealing"). For recursive, divide-and-conquer workloads, this stealing behaviour typically results in better throughput than the other ExecutorService implementations offer.

Earlier, when we invoked fork() on our leftSummer and rightSummer task instances, they got pushed into the work queue of the executing thread, after which they were "stolen" by other active threads in the pool (and so on) since they did not have anything else to do at that point.

Pretty cool, right?

Solving a blocking task

The problem we solved just now is non-blocking in nature. If we want to solve a problem which does some blocking operation, then to have a better throughput we will need to change our strategy.

Let's examine this with another example. Let's say we want to create a very simple web crawler. This crawler will receive a list of HTTP links, execute GET requests to fetch the response bodies, and then calculate the response length. Here is the code -

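import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;

@Slf4j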
public class ResponseLengthCalculator extends RecursiveTask<Map<String, Integer>> {
  private final List<String> links;

  ResponseLengthCalculator(List<String> links) {
    this.links = links;
  }

  @Override
  protected Map<String, Integer> compute() {
    if (links.isEmpty()) {
      log.info("No more links to fetch");
      return Collections.emptyMap();
    }

    int middle = links.size() / 2;
    log.info("Links {}, middle index: {}", links, middle);
    ResponseLengthCalculator leftPartition = new ResponseLengthCalculator(links.subList(0, middle));
    ResponseLengthCalculator rightPartition = new ResponseLengthCalculator(links.subList(middle + 1, links.size()));

    log.info("Forking left partition");
    leftPartition.fork();
    log.info("Left partition forked, now forking right partition");
    rightPartition.fork();
    log.info("Right partition forked");

    String middleLink = links.get(middle);
    HttpRequester httpRequester = new HttpRequester(middleLink);
    String response;
    try {
      log.info("Calling managedBlock for {}", middleLink);
      ForkJoinPool.managedBlock(httpRequester);
      response = httpRequester.response;
    } catch (InterruptedException ex) {
      log.error("Error occurred while trying to implement blocking link fetcher", ex);
      response = "";
    }

    Map<String, Integer> responseMap = new HashMap<>(links.size());

    Map<String, Integer> leftLinks = leftPartition.join();
    responseMap.putAll(leftLinks);
    responseMap.put(middleLink, response.length());
    Map<String, Integer> rightLinks = rightPartition.join();
    responseMap.putAll(rightLinks);

    log.info("Left map {}, middle length {}, right map {}", leftLinks, response.length(), rightLinks);

    return responseMap;
  }

  private static class HttpRequester implements ForkJoinPool.ManagedBlocker {
    private final String link;
    private String response;

    private HttpRequester(String link) {
      this.link = link;
    }

    @Override
    public boolean block() {
      HttpGet getRequest = new HttpGet(link);
      CloseableHttpClient client = HttpClientBuilder
          .create()
          .disableRedirectHandling()
          .build();

      log.info("Executing blocking request for {}", link);

      try (client; CloseableHttpResponse httpResponse = client.execute(getRequest)) {
        log.info("HTTP request for link {} has been executed", link);
        this.response = EntityUtils.toString(httpResponse.getEntity());
      } catch (IOException e) {
        log.error("Error while trying to fetch response from link {}: {}", link, e.getMessage());
        this.response = "";
      }
      return true;
    }

    @Override
    public boolean isReleasable() {
      return false;
    }
  }
}

We create an implementation of ForkJoinPool.ManagedBlocker where we put the blocking HTTP call. This interface defines two methods - block() and isReleasable(). The block() method is where we put our blocking call. After we are done with our blocking operation, we return true to indicate that no further blocking is necessary. We return false from the isReleasable() implementation to indicate to a fork-join worker thread that the block() method implementation is potentially blocking in nature. A fork-join worker thread invokes isReleasable() first, before it invokes the block() method. Finally, we pass our HttpRequester instance to the static ForkJoinPool.managedBlock() method, after which our blocking task starts executing. When it blocks on the HTTP request, ForkJoinPool.managedBlock() also arranges for a spare thread to be activated if necessary to ensure sufficient parallelism.
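
The ManagedBlocker contract can also be tried without any network access. The sketch below replaces the HTTP call with Thread.sleep, and, unlike HttpRequester, reports completion through isReleasable() once the simulated call has finished. SleepingBlocker and runInPool are hypothetical names introduced for this example.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;

// Hypothetical blocker that simulates a slow call with a sleep.
final class SleepingBlocker implements ForkJoinPool.ManagedBlocker {
  private final long millis;
  private volatile boolean done;

  SleepingBlocker(long millis) {
    this.millis = millis;
  }

  @Override
  public boolean block() throws InterruptedException {
    if (!done) {
      TimeUnit.MILLISECONDS.sleep(millis); // stands in for blocking I/O
      done = true;
    }
    return true; // no further blocking is necessary
  }

  @Override
  public boolean isReleasable() {
    return done; // true once the simulated call has completed
  }

  // Runs the blocker on a fork-join worker thread and reports whether it completed.
  static boolean runInPool(long millis) {
    SleepingBlocker blocker = new SleepingBlocker(millis);
    Runnable task = () -> {
      try {
        // May activate a spare worker while this one is blocked.
        ForkJoinPool.managedBlock(blocker);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    };

    ForkJoinPool pool = new ForkJoinPool(2);
    try {
      pool.submit(task).join();
    } finally {
      pool.shutdown();
    }
    return blocker.isReleasable();
  }
}
```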

Let's take this implementation for a test drive then! Here's the code -

public class ResponseLengthCalculatorTest {

  @Test
  public void shouldReturnEmptyMapForEmptyList() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(Collections.emptyList());
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result).isEmpty();
  }

  @Test
  public void shouldHandle200Ok() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(List.of(
        "http://httpstat.us/200"
    ));
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result)
        .hasSize(1)
        .containsKeys("http://httpstat.us/200")
        .containsValue(0);
  }

  @Test
  public void shouldFetchResponseForDifferentResponseStatus() {
    ResponseLengthCalculator responseLengthCalculator = new ResponseLengthCalculator(List.of(
        "http://httpstat.us/200",
        "http://httpstat.us/302",
        "http://httpstat.us/404",
        "http://httpstat.us/502"
    ));
    ForkJoinPool pool = new ForkJoinPool();

    pool.submit(responseLengthCalculator);

    Map<String, Integer> result = responseLengthCalculator.join();
    assertThat(result)
        .hasSize(4);
  }
}

That's it for today, folks! As always, any feedback/improvement suggestions/comments are highly appreciated!

All the examples discussed here can be found on GitHub.

A big shout out to the awesome http://httpstat.us service, it was quite helpful for developing the simple tests.

Tuesday, May 1, 2018

Java Tips: Creating a Monitoring-friendly ExecutorService

In this article we will extend an ExecutorService implementation with monitoring capabilities. These capabilities will help us measure a number of pool parameters (e.g., active threads, work queue size) in a live production environment. They will also enable us to measure task execution time, the successful task count, and the failed task count.

Monitoring Library

As for the monitoring library, we will be using Metrics. For the sake of simplicity we will use a ConsoleReporter, which reports our metrics to the console. For production-grade applications, we should use a more advanced reporter (e.g., the Graphite reporter). If you are unfamiliar with Metrics, I recommend going through its getting started guide.

Let's get started.

Extending the ThreadPoolExecutor

We will be using ThreadPoolExecutor as the base class for our new type. Let's call it MonitoredThreadPoolExecutor. This class will accept a MetricRegistry as one of its constructor parameters -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      RejectedExecutionHandler handler,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
    this.metricRegistry = metricRegistry;
  }

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      ThreadFactory threadFactory,
      RejectedExecutionHandler handler,
      MetricRegistry metricRegistry
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
    this.metricRegistry = metricRegistry;
  }
}

Registering Gauges to measure pool-specific parameters

A Gauge is an instantaneous measurement of a value. We will be using it to measure different pool parameters like number of active threads, task queue size etc.

Before we can register a Gauge, we need to decide how to calculate a metric name for our thread pool. Each metric, whether it's a Gauge, or a Timer, or simply a Meter, has a unique name. This name is used to identify the metric source. The convention here is to use a dotted string which is often constructed from the fully qualified name of the class being monitored.

For our thread pool, we will be using its fully qualified name as a prefix to our metrics names. Additionally we will add another constructor parameter called poolName, which will be used by the clients to specify instance-specific identifiers.

After implementing these changes the class looks like below -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;
  private final String metricsPrefix;

  public MonitoredThreadPoolExecutor(
      int corePoolSize,
      int maximumPoolSize,
      long keepAliveTime,
      TimeUnit unit,
      BlockingQueue<Runnable> workQueue,
      MetricRegistry metricRegistry,
      String poolName
  ) {
    super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    this.metricRegistry = metricRegistry;
    this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
  }

  // Rest of the constructors
}

Now we are ready to register our Gauges. For this purpose we will define a private method -

private void registerGauges() {
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "corePoolSize"), (Gauge<Integer>) this::getCorePoolSize);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "activeThreads"), (Gauge<Integer>) this::getActiveCount);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "maxPoolSize"), (Gauge<Integer>) this::getMaximumPoolSize);
  metricRegistry.register(MetricRegistry.name(metricsPrefix, "queueSize"), (Gauge<Integer>) () -> getQueue().size());
}

For our example we are measuring core pool size, number of active threads, maximum pool size, and task queue size. Depending on monitoring requirements we can register more/less Gauges to measure different properties.

This private method will now be invoked from all constructors -

public MonitoredThreadPoolExecutor(
    int corePoolSize,
    int maximumPoolSize,
    long keepAliveTime,
    TimeUnit unit,
    BlockingQueue<Runnable> workQueue,
    MetricRegistry metricRegistry,
    String poolName
) {
  super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
  this.metricRegistry = metricRegistry;
  this.metricsPrefix = MetricRegistry.name(getClass(), poolName);
  registerGauges();
}

Measuring Task Execution Time

To measure the task execution time, we will override two life-cycle methods that ThreadPoolExecutor provides - beforeExecute and afterExecute.

As the name implies, beforeExecute callback is invoked prior to executing a task, by the thread that will execute the task. The default implementation of this callback does nothing.

Similarly, the afterExecute callback is invoked after each task is executed, by the thread that executed the task. The default implementation of this callback also does nothing. Even if the task throws an uncaught RuntimeException or Error, this callback will be invoked.

We will be starting a Timer in our beforeExecute override, which will then be used in our afterExecute override to get the total task execution time. To store a reference to the Timer we will introduce a new ThreadLocal field in our class.

The implementation of the callbacks is given below -

public class MonitoredThreadPoolExecutor extends ThreadPoolExecutor {
  private final MetricRegistry metricRegistry;
  private final String metricsPrefix;
  private final ThreadLocal<Timer.Context> taskExecutionTimer = new ThreadLocal<>();

  // Constructors

  @Override
  protected void beforeExecute(Thread thread, Runnable task) {
    super.beforeExecute(thread, task);
    Timer timer = metricRegistry.timer(MetricRegistry.name(metricsPrefix, "task-execution"));
    taskExecutionTimer.set(timer.time());
  }

  @Override
  protected void afterExecute(Runnable task, Throwable throwable) {
    Timer.Context context = taskExecutionTimer.get();
    context.stop();
    super.afterExecute(task, throwable);
  }
}
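
The same before/after pattern can be tried without the Metrics dependency. The following sketch uses only the JDK: a ThreadLocal holds the start time, and two AtomicLong fields stand in for the Timer. TimingThreadPoolExecutor and runDemo are hypothetical names, and this is an approximation of the approach rather than the implementation from this post.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical JDK-only stand-in for the Metrics-based version.
final class TimingThreadPoolExecutor extends ThreadPoolExecutor {
  private final ThreadLocal<Long> startNanos = new ThreadLocal<>();
  private final AtomicLong totalNanos = new AtomicLong();
  private final AtomicLong timedTasks = new AtomicLong();

  TimingThreadPoolExecutor(int poolSize) {
    super(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
  }

  @Override
  protected void beforeExecute(Thread thread, Runnable task) {
    super.beforeExecute(thread, task);
    startNanos.set(System.nanoTime()); // runs on the thread that executes the task
  }

  @Override
  protected void afterExecute(Runnable task, Throwable throwable) {
    try {
      Long start = startNanos.get();
      if (start != null) {
        totalNanos.addAndGet(System.nanoTime() - start);
        timedTasks.incrementAndGet();
      }
    } finally {
      startNanos.remove(); // avoid leaking the value into the next task
      super.afterExecute(task, throwable);
    }
  }

  // Runs a few short tasks and returns {number of timed tasks, total nanoseconds}.
  static long[] runDemo(int tasks) {
    TimingThreadPoolExecutor executor = new TimingThreadPoolExecutor(2);
    for (int i = 0; i < tasks; i++) {
      executor.execute(() -> {
        try {
          Thread.sleep(10);
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
        }
      });
    }
    executor.shutdown();
    try {
      // Waiting for termination ensures every afterExecute call has finished.
      executor.awaitTermination(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return new long[] {executor.timedTasks.get(), executor.totalNanos.get()};
  }
}
```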

Recording number of failed tasks due to uncaught exceptions

The second parameter to the afterExecute callback is a Throwable. If non-null, this Throwable refers to the uncaught RuntimeException or Error that caused the execution to terminate. We can use this information to partially count the total number of tasks that were terminated abruptly due to uncaught exceptions.

To get the total number of failed tasks, we must consider another case. When a task submitted using the execute method throws an uncaught exception, that exception is available as the second argument to the afterExecute callback. However, exceptions thrown by tasks submitted using the submit method are swallowed by the executor service. This is clearly explained in the JavaDoc -
Note: When actions are enclosed in tasks (such as FutureTask) either explicitly or via methods such as submit, these task objects catch and maintain computational exceptions, and so they do not cause abrupt termination, and the internal exceptions are not passed to this method. If you would like to trap both kinds of failures in this method, you can further probe for such cases, as in this sample subclass that prints either the direct cause or the underlying exception if a task has been aborted
Fortunately, the same doc also offers a solution for this, which is to examine the runnable to see if it's a Future, and then get the underlying exception.

Combining these approaches, we can modify our afterExecute method as follows -

@Override
protected void afterExecute(Runnable runnable, Throwable throwable) {
  Timer.Context context = taskExecutionTimer.get();
  context.stop();

  super.afterExecute(runnable, throwable);
  if (throwable == null && runnable instanceof Future<?> && ((Future<?>) runnable).isDone()) {
    try {
      ((Future<?>) runnable).get();
    } catch (CancellationException ce) {
      throwable = ce;
    } catch (ExecutionException ee) {
      throwable = ee.getCause();
    } catch (InterruptedException ie) {
      Thread.currentThread().interrupt();
    }
  }
  if (throwable != null) {
    Counter failedTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "failed-tasks"));
    failedTasksCounter.inc();
  }
}


Counting total number of successful tasks

The previous approach can also be used to count the total number of successful tasks: tasks that were completed without throwing any exceptions or errors -

@Override
protected void afterExecute(Runnable runnable, Throwable throwable) {
  // Rest of the method body .....

  if (throwable != null) {
    Counter failedTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "failed-tasks"));
    failedTasksCounter.inc();
  } else {
    Counter successfulTasksCounter = metricRegistry.counter(MetricRegistry.name(metricsPrefix, "successful-tasks"));
    successfulTasksCounter.inc();
  }
}
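
To see both failure paths (execute and submit) in one runnable piece, here is a JDK-only sketch in which AtomicLong counters replace the Metrics counters. CountingThreadPoolExecutor, runDemo, and the silent uncaught-exception handler are assumptions made for this illustration.

```java
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical JDK-only stand-in for the Metrics-based counters.
final class CountingThreadPoolExecutor extends ThreadPoolExecutor {
  private final AtomicLong failed = new AtomicLong();
  private final AtomicLong successful = new AtomicLong();

  CountingThreadPoolExecutor() {
    super(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), runnable -> {
      Thread thread = new Thread(runnable);
      // Keep the demo quiet: failures are counted instead of printed.
      thread.setUncaughtExceptionHandler((t, e) -> { });
      return thread;
    });
  }

  @Override
  protected void afterExecute(Runnable runnable, Throwable throwable) {
    super.afterExecute(runnable, throwable);
    if (throwable == null && runnable instanceof Future<?> && ((Future<?>) runnable).isDone()) {
      try {
        ((Future<?>) runnable).get(); // surfaces the exception captured by submit()
      } catch (CancellationException ce) {
        throwable = ce;
      } catch (ExecutionException ee) {
        throwable = ee.getCause();
      } catch (InterruptedException ie) {
        Thread.currentThread().interrupt();
      }
    }
    if (throwable != null) {
      failed.incrementAndGet();
    } else {
      successful.incrementAndGet();
    }
  }

  // Runs two successful and two failing tasks; returns {successful, failed}.
  static long[] runDemo() {
    Runnable ok = () -> { };
    Runnable failing = () -> {
      throw new IllegalStateException("boom");
    };

    CountingThreadPoolExecutor executor = new CountingThreadPoolExecutor();
    executor.execute(ok);
    executor.submit(ok);
    executor.submit(failing);  // exception is captured inside the FutureTask
    executor.execute(failing); // exception reaches afterExecute directly
    executor.shutdown();
    try {
      executor.awaitTermination(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
    return new long[] {executor.successful.get(), executor.failed.get()};
  }
}
```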


Conclusion

In this article we have looked at a few monitoring-friendly customizations to an ExecutorService implementation. As always, any suggestions/improvements/bug fixes will be highly appreciated. The example source code has been uploaded to GitHub.

Sunday, April 22, 2018

JPA Tips: Avoiding the N + 1 select problem

Introduction

ORM frameworks like JPA simplify our development process by helping us avoid a lot of boilerplate code in the object <-> relational data mapping. However, they also bring some additional problems to the table, and N + 1 is one of them. In this article we will take a short look at the problem along with some ways to avoid it.

The Problem

As an example I will use a simplified version of an online book ordering application. In such an application I might create an entity like the one below to represent a Purchase Order -

@Entity
public class PurchaseOrder {

    @Id
    private String id;
    private String customerId;

    @OneToMany(cascade = ALL, fetch = EAGER)
    @JoinColumn(name = "purchase_order_id")
    private List<PurchaseOrderItem> purchaseOrderItems = new ArrayList<>();
}


A purchase order consists of an order id, a customer id, and one or more items that are being bought. The PurchaseOrderItem entity might have the following structure -

@Entity
public class PurchaseOrderItem {

    @Id
    private String id;

    private String bookId;
}


These entities have been simplified a lot, but for the purpose of this article, this will do.

Now suppose that we need to find the orders of a customer to display them in his purchase order history. The following query will serve this purpose -

SELECT
    P
FROM
    PurchaseOrder P
WHERE
    P.customerId = :customerId

which when translated to SQL looks something like below -

select
    purchaseor0_.id as id1_1_,
    purchaseor0_.customer_id as customer2_1_ 
from
    purchase_order purchaseor0_ 
where
    purchaseor0_.customer_id = ?

This one query will return all purchase orders that a customer has. However, in order to fetch the order items, JPA will issue a separate query for each individual order. If, for example, a customer has 5 orders, then JPA will issue 5 additional queries to fetch the order items included in those orders. This is known as the N + 1 problem - 1 query to fetch all N purchase orders, and N queries to fetch their order items.

This behavior creates a scalability problem for us when our data grows. Even a moderate number of orders and items can create significant performance issues.

The Solution


Avoiding Eager Fetching

This is the main reason behind the issue. We should get rid of all the eager fetching from our mapping. It has almost no benefits that justify its use in a production-grade application. We should mark all relationships as Lazy instead.

One important point to note - marking a relationship mapping as Lazy does not guarantee that the underlying persistence provider will also treat it as such. The JPA specification does not guarantee the lazy fetch. It's a hint to the persistence provider at best. However, considering Hibernate, I have never seen it do otherwise.

Only fetching the data that are actually needed

This is always recommended regardless of the decision to go for eager/lazy fetching.

I remember one N + 1 optimization that I did which improved the maximum response time of a REST endpoint from 17 minutes to 1.5 seconds. The endpoint was fetching a single entity based on some criteria, which for our current example will be something along the lines of -

TypedQuery<PurchaseOrder> jpaQuery = entityManager.createQuery("SELECT P FROM PurchaseOrder P WHERE P.customerId = :customerId", PurchaseOrder.class);
jpaQuery.setParameter("customerId", "Sayem");
PurchaseOrder purchaseOrder = jpaQuery.getSingleResult();

// after some calculation
anotherRepository.findSomeStuff(purchaseOrder.getId());

The id is the only data from the result that was needed for subsequent calculations.

There were a few customers who had more than a thousand orders. Each of those orders in turn had a few thousand additional children of a few different types. Needless to say, as a result, thousands of queries were being executed in the database whenever requests for those orders were received at this endpoint.

To improve the performance, all I did was -

TypedQuery<String> jpaQuery = entityManager.createQuery("SELECT P.id FROM PurchaseOrder P WHERE P.customerId = :customerId", String.class);
jpaQuery.setParameter("customerId", "Sayem");
String orderId = jpaQuery.getSingleResult();

// after some calculation
anotherRepository.findSomeStuff(orderId);

Just this change resulted in a 680x improvement.

If we want to fetch more than one property, then we can make use of the Constructor Expression that JPA provides -

TypedQuery<PurchaseOrderDTO> jpaQuery = entityManager.createQuery(
        "SELECT " +
                "NEW com.codesod.example.jpa.nplusone.dto.PurchaseOrderDTO(P.id, P.orderDate) " +
        "FROM " +
                "PurchaseOrder P " +
        "WHERE " +
                "P.customerId = :customerId",
        PurchaseOrderDTO.class);
jpaQuery.setParameter("customerId", "Sayem");
List<PurchaseOrderDTO> orders = jpaQuery.getResultList();

A few caveats of using the constructor expression -
  1. The target DTO must have a constructor whose parameter list matches the columns being selected
  2. The fully qualified name of the DTO class must be specified
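
For the query above, the DTO could look something like the following minimal sketch. The field types are assumptions: id is taken to be a String (as in the earlier query) and orderDate is assumed to be a java.time.LocalDate, which may not match the real entity.

```java
import java.time.LocalDate;

// Hypothetical DTO for the constructor expression; the field types are assumed.
// In a real project this would be a public class living in the package
// that the query names (com.codesod.example.jpa.nplusone.dto).
final class PurchaseOrderDTO {
  private final String id;
  private final LocalDate orderDate;

  // Parameter order and types must line up with
  // "NEW ...PurchaseOrderDTO(P.id, P.orderDate)" in the query.
  public PurchaseOrderDTO(String id, LocalDate orderDate) {
    this.id = id;
    this.orderDate = orderDate;
  }

  public String getId() {
    return id;
  }

  public LocalDate getOrderDate() {
    return orderDate;
  }
}
```

If a third column were added to the SELECT clause, the constructor would need a matching third parameter, otherwise the query would fail to resolve it.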

Using Join Fetch / Entity Graphs

We can use JOIN FETCH in our queries whenever we need to fetch an entity with all of its children at the same time. This results in much less database traffic and, in turn, improved performance.

The JPA 2.1 specification introduced Entity Graphs, which allow us to create static/dynamic query load plans. Thorben Janssen has written a couple of posts (here and here) detailing their usage which are worth checking out.

Some example code for this post can be found at Github.

Wednesday, November 1, 2017

Replacing exceptions with error notifications during input validation in Java

In my previous article I wrote about an input validation design which replaces hard-to-maintain-and-test if-else blocks. However, as some readers pointed out, it has a drawback - if the input data has more than one validation error, then the user will have to submit the request multiple times to find all of them. From a usability perspective this is not a good design.

An alternative to throwing exceptions when we find a validation error is to return a Notification object containing the error(s). This will enable us to run all the validation rules on the user input, and catch all violations at the same time. Martin Fowler wrote an article detailing the approach. I highly recommend giving it a read, if you haven't done so already.

In this article I will refactor my previous implementation to use an Error Notification object to validate user inputs.

As a first step, I will create an ErrorNotification object which encapsulates my application errors -
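import static java.util.stream.Collectors.joining;

import java.util.ArrayList;
import java.util.List;

public class ErrorNotification {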
  private List<String> errors = new ArrayList<>();

  public void addError(String message) {
    this.errors.add(message);
  }

  public boolean hasError() {
    return !this.errors.isEmpty();
  }

  public String getAllErrors() {
    return this.errors.stream()
        .collect(joining(", "));
  }
}
I will then change the OrderItemValidator interface to return an ErrorNotification object -
public interface OrderItemValidator {
  ErrorNotification validate(OrderItem orderItem);
}
and then change all the implementations to adapt to the new return type as well.

Initially, I will change all the implementations to return an empty error object, so that I can get rid of the compilation errors. For example, I will change the ItemDescriptionValidator in the following way -
class ItemDescriptionValidator implements OrderItemValidator {

  @Override
  public ErrorNotification validate(OrderItem orderItem) {
    ErrorNotification errorNotification = new ErrorNotification();
    Optional.ofNullable(orderItem)
        .map(OrderItem::getDescription)
        .map(String::trim)
        .filter(description -> !description.isEmpty())
        .orElseThrow(() -> new IllegalArgumentException("Item description should be provided"));
    return errorNotification;
  }
}
After fixing the compilation errors, I will now start replacing the exceptions with notification messages in each validator. To do this, I will first modify the related tests to reflect my intent, and then modify the validators to pass the tests.

Let's start with the ItemDescriptionValidatorTest class -
public class ItemDescriptionValidatorTest {

  @Test
  public void validate_descriptionIsNull_invalid() {
    ItemDescriptionValidator validator = new ItemDescriptionValidator();

    ErrorNotification errorNotification = validator.validate(new OrderItem());

    assertThat(errorNotification.getAllErrors()).isEqualTo("Item description should be provided");
  }

  @Test
  public void validate_descriptionIsBlank_invalid() {
    OrderItem orderItem = new OrderItem();
    orderItem.setDescription("     ");
    ItemDescriptionValidator validator = new ItemDescriptionValidator();

    ErrorNotification errorNotification = validator.validate(orderItem);

    assertThat(errorNotification.getAllErrors()).isEqualTo("Item description should be provided");
  }

  @Test
  public void validate_descriptionGiven_valid() {
    OrderItem orderItem = new OrderItem();
    orderItem.setDescription("dummy description");
    ItemDescriptionValidator validator = new ItemDescriptionValidator();

    ErrorNotification errorNotification = validator.validate(orderItem);

    assertThat(errorNotification.getAllErrors()).isEmpty();
  }
}
When I run these tests, only one of them passes, and two of them fail, which is expected. I will now modify the validator code to pass the tests -
class ItemDescriptionValidator implements OrderItemValidator {
  static final String MISSING_ITEM_DESCRIPTION = "Item description should be provided";

  @Override
  public ErrorNotification validate(OrderItem orderItem) {
    ErrorNotification errorNotification = new ErrorNotification();
    Optional.ofNullable(orderItem)
        .map(OrderItem::getDescription)
        .map(String::trim)
        .filter(description -> !description.isEmpty())
        .ifPresentOrElse(
            description -> {},
            () -> errorNotification.addError(MISSING_ITEM_DESCRIPTION)
        );
    return errorNotification;
  }
}
I am a bit uncomfortable with the use of the ifPresentOrElse method above. The main reason I am using it here is that Optional doesn't have something like an ifNotPresent method, which would have allowed me to take an action only when the value is not present (request to my readers - if you know a better way to do this, please comment!).
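
One possible alternative, offered as a sketch rather than a recommendation, is to end the chain with isPresent() and branch on the result (on Java 11 and later, Optional.isEmpty() would read even more directly). To keep the example self-contained, the description String stands in for orderItem.getDescription() and a plain List<String> stands in for ErrorNotification; DescriptionCheckSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

class DescriptionCheckSketch {
  static final String MISSING_ITEM_DESCRIPTION = "Item description should be provided";

  // 'description' plays the role of orderItem.getDescription();
  // the returned list plays the role of ErrorNotification.
  static List<String> validate(String description) {
    List<String> errors = new ArrayList<>();
    boolean descriptionGiven = Optional.ofNullable(description)
        .map(String::trim)
        .filter(value -> !value.isEmpty())
        .isPresent();
    // Act only when the value is absent, without an empty "present" branch.
    if (!descriptionGiven) {
      errors.add(MISSING_ITEM_DESCRIPTION);
    }
    return errors;
  }

  public static void main(String[] args) {
    System.out.println(validate(null));                // [Item description should be provided]
    System.out.println(validate("   "));               // [Item description should be provided]
    System.out.println(validate("dummy description")); // []
  }
}
```

The trade-off is an extra local variable and an if statement in exchange for dropping the no-op lambda.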

After this refactoring, all tests in the ItemDescriptionValidatorTest class pass with flying colors. Great!

Let's refactor the tests in the MenuValidatorTest class now -
public class MenuValidatorTest {

  @Test
  public void validate_menuIdInvalid_invalid() {
    OrderItem orderItem = new OrderItem();
    String menuId = "some menu id";
    orderItem.setMenuId(menuId);
    MenuRepository menuRepository = mock(MenuRepository.class);
    when(menuRepository.menuExists(any())).thenReturn(false);
    MenuValidator validator = new MenuValidator(menuRepository);

    ErrorNotification errorNotification = validator.validate(orderItem);

    assertThat(errorNotification.getAllErrors())
        .isEqualTo(String.format(MenuValidator.INVALID_MENU_ERROR_FORMAT, menuId));
  }

  @Test
  public void validate_menuIdNull_invalid() {
    MenuRepository menuRepository = mock(MenuRepository.class);
    when(menuRepository.menuExists(any())).thenReturn(true);
    MenuValidator validator = new MenuValidator(menuRepository);

    ErrorNotification errorNotification = validator.validate(new OrderItem());

    assertThat(errorNotification.getAllErrors())
        .isEqualTo(MenuValidator.MISSING_MENU_ERROR);
  }

  @Test
  public void validate_menuIdIsBlank_invalid() {
    OrderItem orderItem = new OrderItem();
    orderItem.setMenuId("       \t");
    MenuRepository menuRepository = mock(MenuRepository.class);
    when(menuRepository.menuExists(any())).thenReturn(true);
    MenuValidator validator = new MenuValidator(menuRepository);

    ErrorNotification errorNotification = validator.validate(orderItem);

    assertThat(errorNotification.getAllErrors())
        .isEqualTo(MenuValidator.MISSING_MENU_ERROR);
  }

  @Test
  public void validate_menuIdValid_validated() {
    OrderItem orderItem = new OrderItem();
    String menuId = "some menu id";
    orderItem.setMenuId(menuId);
    MenuRepository menuRepository = mock(MenuRepository.class);
    when(menuRepository.menuExists(menuId)).thenReturn(true);
    MenuValidator validator = new MenuValidator(menuRepository);

    ErrorNotification errorNotification = validator.validate(orderItem);

    assertThat(errorNotification.getAllErrors()).isEmpty();
  }
}

and then the MenuValidator class -
@RequiredArgsConstructor
class MenuValidator implements OrderItemValidator {
  private final MenuRepository menuRepository;

  static final String MISSING_MENU_ERROR = "A menu item must be specified.";
  static final String INVALID_MENU_ERROR_FORMAT = "Given menu [%s] does not exist.";

  @Override
  public ErrorNotification validate(OrderItem orderItem) {
    ErrorNotification errorNotification = new ErrorNotification();
    Optional.ofNullable(orderItem.getMenuId())
        .map(String::trim)
        .filter(menuId -> !menuId.isEmpty())
        .ifPresentOrElse(
            validateMenuExists(errorNotification),
            () -> errorNotification.addError(MISSING_MENU_ERROR)
        );
    return errorNotification;
  }

  private Consumer<String> validateMenuExists(ErrorNotification errorNotification) {
    return menuId -> {
      if (!menuRepository.menuExists(menuId)) {
        errorNotification.addError(String.format(INVALID_MENU_ERROR_FORMAT, menuId));
      }
    };
  }
}
and so on.

After modifying each of the individual validators, I will now modify the Composite to collect all errors for a single order item -
@RequiredArgsConstructor
class OrderItemValidatorComposite implements OrderItemValidator {
  private final List<OrderItemValidator> validators;

  @Override
  public ErrorNotification validate(OrderItem orderItem) {
    ErrorNotification errorNotification = new ErrorNotification();
    validators.stream()
        .map(validator -> validator.validate(orderItem))
        .forEach(errorNotification::addAll);
    return errorNotification;
  }
}
In order to do this, I have added a new method in the ErrorNotification class, called addAll, which basically copies all errors from another ErrorNotification object.
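
The addAll method itself is not shown in this post. A minimal sketch of what it might look like is below; the body of addAll is my assumption here, and getAllErrors uses String.join instead of a stream collector purely to keep the example short.

```java
import java.util.ArrayList;
import java.util.List;

class ErrorNotification {
  private final List<String> errors = new ArrayList<>();

  public void addError(String message) {
    this.errors.add(message);
  }

  // Assumed implementation: copies every error from 'other' into this object.
  public void addAll(ErrorNotification other) {
    this.errors.addAll(other.errors);
  }

  public boolean hasError() {
    return !this.errors.isEmpty();
  }

  public String getAllErrors() {
    return String.join(", ", this.errors);
  }
}
```

Merging an empty notification leaves the target unchanged, which is why the composite can call addAll for every validator without checking hasError first.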

Finally, I will now modify the service method to collect all error messages for all the order items of an order -
@Service
@Slf4j
@RequiredArgsConstructor
class OrderService {
  private final OrderItemValidator validator;

  void createOrder(OrderDTO orderDTO) {
    ErrorNotification errorNotification = new ErrorNotification();
    orderDTO.getOrderItems()
        .stream()
        .map(validator::validate)
        .forEach(errorNotification::addAll);
    if (errorNotification.hasError()) {
      throw new IllegalArgumentException(errorNotification.getAllErrors());
    }

    log.info("Order {} saved", orderDTO);
  }
}

Making this change causes one of the tests in OrderServiceIT to fail, as it was specifically looking for an exception with cause set to NumberFormatException when the price is invalid. After our refactoring, we can safely remove this check as it is no longer relevant.

The full source code for this article has been pushed to GitHub (specific commit URL is here).

Friday, May 12, 2017

Clean Code from the trenches - Validation

Let's start directly with an example. Consider a simple web service which allows clients to place an order with a shop. A very simplified version of the order controller could look something like below -

@RestController
@RequestMapping(value = "/",
    consumes = MediaType.APPLICATION_JSON_VALUE,
    produces = MediaType.APPLICATION_JSON_VALUE)
public class OrderController {
  private final OrderService orderService;

  public OrderController(OrderService orderService) {
    this.orderService = orderService;
  }

  @PostMapping
  public void doSomething(@Valid @RequestBody OrderDTO order) {
    orderService.createOrder(order);
  }
}

And the corresponding DTO class -

@Getter
@Setter
@ToString
public class OrderDTO {

  @NotNull
  private String customerId;

  @NotNull
  @Size(min = 1)
  private List<OrderItem> orderItems;

  @Getter
  @Setter
  @ToString
  public static class OrderItem {
    private String menuId;
    private String description;
    private String price;
    private Integer quantity;
  }
}


The most common approach for creating an order from this DTO is to pass it to a service, validate it as necessary, and then persist it in the database -


@Service
@Slf4j
class OrderService {
  private final MenuRepository menuRepository;

  OrderService(MenuRepository menuRepository) {
    this.menuRepository = menuRepository;
  }

  void createOrder(OrderDTO orderDTO) {
    orderDTO.getOrderItems()
        .forEach(this::validate);

    log.info("Order {} saved", orderDTO);
  }

  private void validate(OrderItem orderItem) {
    String menuId = orderItem.getMenuId();
    if (menuId == null || menuId.trim().isEmpty()) {
      throw new IllegalArgumentException("A menu item must be specified.");
    }
    if (!menuRepository.menuExists(menuId.trim())) {
      throw new IllegalArgumentException("Given menu " + menuId + " does not exist.");
    }

    String description = orderItem.getDescription();
    if (description == null || description.trim().isEmpty()) {
      throw new IllegalArgumentException("Item description should be provided");
    }

    String price = orderItem.getPrice();
    if (price == null || price.trim().isEmpty()) {
      throw new IllegalArgumentException("Price cannot be empty.");
    }
    try {
      new BigDecimal(price);
    } catch (NumberFormatException ex) {
      throw new IllegalArgumentException("Given price is not in valid format", ex);
    }
    if (orderItem.getQuantity() == null) {
      throw new IllegalArgumentException("Quantity must be given");
    }
    if (orderItem.getQuantity() <= 0) {
      throw new IllegalArgumentException("Given quantity "
          + orderItem.getQuantity()
          + " is not valid.");
    }
  }
}


The validate method is not well written. It is very hard to test. Introducing a new validation rule in the future is also hard, and so is removing/modifying any of the existing ones. In my experience, most people write a few generic assertions for this type of validation check, typically in an integration test class, touching only one or two (or more, but not all) of the validation rules. As a result, refactoring in the future can only be done in Edit and Pray mode.

We can improve the code structure if we use Polymorphism to replace these conditionals. Let's create a common super type for representing a single validation rule -

public interface OrderItemValidator {
  void validate(OrderItem orderItem);
}


Next step is to create validation rule implementations which will focus on separate validation areas for the DTO. Let's start with the menu validator -

public class MenuValidator implements OrderItemValidator {
  private final MenuRepository menuRepository;

  public MenuValidator(MenuRepository menuRepository) {
    this.menuRepository = menuRepository;
  }

  @Override
  public void validate(OrderItem orderItem) {
    String menuId = Optional.ofNullable(orderItem.getMenuId())
        .map(String::trim)
        .filter(id -> !id.isEmpty())
        .orElseThrow(() -> new IllegalArgumentException("A menu item must be specified."));

    if (!menuRepository.menuExists(menuId)) {
      throw new IllegalArgumentException("Given menu [" + menuId + "] does not exist.");
    }
  }
}


Then the item description validator -

public class ItemDescriptionValidator implements OrderItemValidator {

  @Override
  public void validate(OrderItem orderItem) {
    Optional.ofNullable(orderItem)
        .map(OrderItem::getDescription)
        .map(String::trim)
        .filter(description -> !description.isEmpty())
        .orElseThrow(() -> new IllegalArgumentException("Item description should be provided"));
  }
}


Price validator -

public class PriceValidator implements OrderItemValidator {

  @Override
  public void validate(OrderItem orderItem) {
    String price = Optional.ofNullable(orderItem)
        .map(OrderItem::getPrice)
        .map(String::trim)
        .filter(itemPrice -> !itemPrice.isEmpty())
        .orElseThrow(() -> new IllegalArgumentException("Price cannot be empty."));

    try {
      new BigDecimal(price);
    } catch (NumberFormatException ex) {
      throw new IllegalArgumentException("Given price [" + price + "] is not in valid format", ex);
    }
  }
}


And finally, the quantity validator -

public class QuantityValidator implements OrderItemValidator {

  @Override
  public void validate(OrderItem orderItem) {
    Integer quantity = Optional.ofNullable(orderItem)
        .map(OrderItem::getQuantity)
        .orElseThrow(() -> new IllegalArgumentException("Quantity must be given"));
    if (quantity <= 0) {
      throw new IllegalArgumentException("Given quantity " + quantity + " is not valid.");
    }
  }
}


Each of these validator implementations can now be easily tested, independently from each other. Reasoning about each one of them also becomes easier, and so does any future addition/modification/removal.
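
To illustrate the testability claim, here is a minimal, framework-free sketch that exercises the quantity validator on its own. OrderItem is a stripped-down stand-in with only the quantity field, and QuantityValidatorCheck is a hypothetical helper; in a real project this would more likely be a JUnit test class.

```java
import java.util.Optional;

// Stripped-down stand-in so the sketch compiles on its own.
class OrderItem {
  private Integer quantity;

  Integer getQuantity() {
    return quantity;
  }

  void setQuantity(Integer quantity) {
    this.quantity = quantity;
  }
}

interface OrderItemValidator {
  void validate(OrderItem orderItem);
}

class QuantityValidator implements OrderItemValidator {
  @Override
  public void validate(OrderItem orderItem) {
    Integer quantity = Optional.ofNullable(orderItem)
        .map(OrderItem::getQuantity)
        .orElseThrow(() -> new IllegalArgumentException("Quantity must be given"));
    if (quantity <= 0) {
      throw new IllegalArgumentException("Given quantity " + quantity + " is not valid.");
    }
  }
}

class QuantityValidatorCheck {
  // Returns the rejection message, or null when the item is accepted.
  static String rejectionMessage(Integer quantity) {
    OrderItem item = new OrderItem();
    item.setQuantity(quantity);
    try {
      new QuantityValidator().validate(item);
      return null;
    } catch (IllegalArgumentException ex) {
      return ex.getMessage();
    }
  }
}
```

No repository, Spring context, or other validator is needed to cover every branch of this rule.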

Now the wiring part. How can we integrate these validators with the order service?

One way would be to directly create a list in the OrderService constructor, and populate it with the validators. Or we could use Spring to inject a List into the OrderService -

@Service
@Slf4j
class OrderService {
  private final List<OrderItemValidator> validators;

  OrderService(List<OrderItemValidator> validators) {
    this.validators = validators;
  }

  void createOrder(OrderDTO orderDTO) {
    orderDTO.getOrderItems()
        .forEach(this::validate);

    log.info("Order {} saved", orderDTO);
  }

  private void validate(OrderItem orderItem) {
    validators.forEach(validator -> validator.validate(orderItem));
  }
}


In order for this to work we will have to declare each of the validator implementations as a Spring Bean.

We could improve our abstraction even further. The OrderService is now accepting a List of the validators. However, we can change it to be only aware of the OrderItemValidator type, and nothing else. This gives us the flexibility of injecting either a single validator or any composition of validators in the future.

So now our goal is to change the order service to treat a composition of order item validators in the same way as a single validator. There is a well-known design pattern called Composite which lets us do exactly that.

Let's create a new implementation of the validator interface, which will be the composite -

class OrderItemValidatorComposite implements OrderItemValidator {
  private final List<OrderItemValidator> validators;

  OrderItemValidatorComposite(List<OrderItemValidator> validators) {
    this.validators = validators;
  }

  @Override
  public void validate(OrderItem orderItem) {
    validators.forEach(validator -> validator.validate(orderItem));
  }
}


We then create a new Spring configuration class, which will instantiate and initialize this composite, and then expose it as a bean -

@Configuration
class ValidatorConfiguration {

  @Bean
  OrderItemValidator orderItemValidator(MenuRepository menuRepository) {
    return new OrderItemValidatorComposite(Arrays.asList(
        new MenuValidator(menuRepository),
        new ItemDescriptionValidator(),
        new PriceValidator(),
        new QuantityValidator()
    ));
  }
}


We then change the OrderService class in the following way -

@Service
@Slf4j
class OrderService {
  private final OrderItemValidator validator;

  OrderService(OrderItemValidator orderItemValidator) {
    this.validator = orderItemValidator;
  }

  void createOrder(OrderDTO orderDTO) {
    orderDTO.getOrderItems()
        .forEach(validator::validate);

    log.info("Order {} saved", orderDTO);
  }
}


And we are done!

The benefits of this approach are many. The whole validation logic has completely been abstracted away from the ordering service. Testing is easier. Future maintenance is easier. Clients only know about one validator type, and nothing else.

However, all of the above comes with some problems too. Sometimes people are not comfortable with this design. They may feel that this is just too much abstraction, or that they will not need this much flexibility or testability for future maintenance. I'd suggest adopting this approach based on the team culture. After all, there is no single right way of doing things in Software Development.

Note that for the sake of this article I have taken some shortcuts here as well. These include throwing a generic IllegalArgumentException when validation fails. You'd probably want a more specific/custom exception in a production-grade application to distinguish between different scenarios. The decimal parsing is also done naively. You might want to settle on a specific format, and then use DecimalFormat to parse it.
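
As a rough sketch of the DecimalFormat idea, the helper below accepts only plain decimal numbers with '.' as the decimal separator and no grouping separators. That format, the PriceParser name, and the choice to return an Optional rather than throw are all assumptions made for illustration.

```java
import java.math.BigDecimal;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.text.ParsePosition;
import java.util.Locale;
import java.util.Optional;

class PriceParser {
  // Returns the parsed price, or empty when the text is not a plain decimal number.
  static Optional<BigDecimal> parse(String text) {
    if (text == null || text.trim().isEmpty()) {
      return Optional.empty();
    }
    String trimmed = text.trim();
    // DecimalFormat is not thread-safe, so a fresh instance is created per call.
    DecimalFormat format = new DecimalFormat("0.##", DecimalFormatSymbols.getInstance(Locale.ROOT));
    format.setGroupingUsed(false);
    format.setParseBigDecimal(true);
    ParsePosition position = new ParsePosition(0);
    Number parsed = format.parse(trimmed, position);
    // Reject partial matches such as "12abc", where parsing stops early.
    boolean consumedAll = position.getIndex() == trimmed.length();
    if (!(parsed instanceof BigDecimal) || !consumedAll) {
      return Optional.empty();
    }
    return Optional.of((BigDecimal) parsed);
  }
}
```

The ParsePosition check matters because DecimalFormat.parse stops at the first character it cannot interpret and still returns the number parsed up to that point.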

The full code has been uploaded to Github.

If you find any typo/other errors please feel free to comment in!

Sunday, March 5, 2017

Dealing with Java's LocalDateTime in JPA

A few days ago I ran into a problem while dealing with a LocalDateTime attribute in JPA. In this blog post I will try to create a sample problem to explain the issue, along with the solution that I used.

Consider the following entity, which models an Employee of a certain company -

@Entity
@Getter
@Setter
public class Employee {

  @Id
  @GeneratedValue
  private Long id;
  private String name;
  private String department;
  private LocalDateTime joiningDate;
}

I was using Spring Data JPA, so I created the following repository -
@Repository
public interface EmployeeRepository 
    extends JpaRepository<Employee, Long> {

}

I wanted to find all employees who joined the company on a particular date. To do that I extended my repository from JpaSpecificationExecutor -
@Repository
public interface EmployeeRepository 
    extends JpaRepository<Employee, Long>,
    JpaSpecificationExecutor<Employee> {

}

and wrote a query like below -
@SpringBootTest
@RunWith(SpringRunner.class)
@Transactional
public class EmployeeRepositoryIT {

  @Autowired
  private EmployeeRepository employeeRepository;

  @Test
  public void findingEmployees_joiningDateIsZeroHour_found() {
    DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    LocalDateTime joiningDate = LocalDateTime.parse("2014-04-01 00:00:00", formatter);

    Employee employee = new Employee();
    employee.setName("Test Employee");
    employee.setDepartment("Test Department");
    employee.setJoiningDate(joiningDate);
    employeeRepository.save(employee);

    // Query to find employees
    List<Employee> employees = employeeRepository.findAll((root, query, cb) ->
        cb.and(
            cb.greaterThanOrEqualTo(root.get(Employee_.joiningDate), joiningDate),
            cb.lessThan(root.get(Employee_.joiningDate), joiningDate.plusDays(1)))
    );

    assertThat(employees).hasSize(1);
  }
}

The above test passed without any problem. However, the following test failed (which was supposed to pass) -
@Test
public void findingEmployees_joiningDateIsNotZeroHour_found() {
  DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
  LocalDateTime joiningDate = LocalDateTime.parse("2014-04-01 08:00:00", formatter);
  LocalDateTime zeroHour = LocalDateTime.parse("2014-04-01 00:00:00", formatter);

  Employee employee = new Employee();
  employee.setName("Test Employee");
  employee.setDepartment("Test Department");
  employee.setJoiningDate(joiningDate);
  employeeRepository.save(employee);

  List<Employee> employees = employeeRepository.findAll((root, query, cb) ->
      cb.and(
          cb.greaterThanOrEqualTo(root.get(Employee_.joiningDate), zeroHour),
          cb.lessThan(root.get(Employee_.joiningDate), zeroHour.plusDays(1))
      )
  );

  assertThat(employees).hasSize(1);
}

The only thing that is different from the previous test is that in the previous test I used the zero hour as the joining date, and here I used 8 AM.

At first it seemed weird to me. The tests seemed to pass whenever the joining date of an employee was set to a zero hour of a day, but failed whenever it was set to any other time.

In order to investigate the problem I turned on the hibernate logging to see the actual query and the values being sent to the database, and noticed something like this in the log -
2017-03-05 22:26:20.804 DEBUG 8098 --- [           main] org.hibernate.SQL:
    select
        employee0_.id as id1_0_,
        employee0_.department as departme2_0_,
        employee0_.joining_date as joining_3_0_,
        employee0_.name as name4_0_
    from
        employee employee0_
    where
        employee0_.joining_date>=?
        and employee0_.joining_date<?
Hibernate:
    select
        employee0_.id as id1_0_,
        employee0_.department as departme2_0_,
        employee0_.joining_date as joining_3_0_,
        employee0_.name as name4_0_
    from
        employee employee0_
    where
        employee0_.joining_date>=?
        and employee0_.joining_date<?
2017-03-05 22:26:20.806 TRACE 8098 --- [           main] o.h.type.descriptor.sql.BasicBinder      : binding parameter [1] as [VARBINARY] - [2014-04-01T00:00]
2017-03-05 22:26:20.807 TRACE 8098 --- [           main] o.h.type.descriptor.sql.BasicBinder      : binding parameter [2] as [VARBINARY] - [2014-04-02T00:00]
It was evident that JPA was NOT treating the joiningDate attribute as a date or time, but as a VARBINARY type. This is why the comparison to an actual date was failing.
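
One plausible reading of this log, and it is only an assumption on my part, is that the provider falls back to standard Java serialization for the unmapped type and that the database then compares the resulting VARBINARY values byte by byte. Neither assumption is confirmed here, but the small program below lets you check on your own JDK whether the byte order of serialized LocalDateTime values agrees with their chronological order. SerializedDateOrdering is a hypothetical name.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.time.LocalDateTime;
import java.util.Arrays;

class SerializedDateOrdering {
  // Serializes a value with standard Java serialization.
  static byte[] serialize(LocalDateTime value) {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
      out.writeObject(value);
    } catch (IOException ex) {
      throw new UncheckedIOException(ex);
    }
    return buffer.toByteArray();
  }

  // Negative when 'left' sorts before 'right' as unsigned bytes (requires Java 9+).
  static int compareAsBytes(LocalDateTime left, LocalDateTime right) {
    return Arrays.compareUnsigned(serialize(left), serialize(right));
  }

  public static void main(String[] args) {
    LocalDateTime zeroHour = LocalDateTime.of(2014, 4, 1, 0, 0);
    LocalDateTime eightAm = LocalDateTime.of(2014, 4, 1, 8, 0);
    System.out.println("chronologically after: " + eightAm.isAfter(zeroHour));
    System.out.println("byte-wise comparison:  " + compareAsBytes(eightAm, zeroHour));
  }
}
```

If the byte-wise comparison comes out negative even though 8 AM is chronologically later, a greater-than-or-equal check against the zero hour would reject the 8 AM row, which matches the failing test.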

In my opinion this is not a very good design. Rather than throwing something like UnsupportedAttributeException or whatever, it was silently trying to convert the value to something else, and thus failing the comparison at random (well, not exactly random). This type of bug is hard to find in an application unless you have a strong suite of automated tests, which was fortunately my case.

Back to the problem now. The reason JPA was failing to convert LocalDateTime appropriately was very simple. The latest version of the JPA specification at the time of writing (2.1) was released before Java 8, and as a result it does not know about the new Date and Time API.

To solve the problem, I created a custom converter implementation which converts the LocalDateTime to java.sql.Timestamp before saving it to the database, and vice versa. That solved the problem -
@Converter(autoApply = true)
public class LocalDateTimeConverter implements AttributeConverter<LocalDateTime, Timestamp> {

  @Override
  public Timestamp convertToDatabaseColumn(LocalDateTime localDateTime) {
    return Optional.ofNullable(localDateTime)
        .map(Timestamp::valueOf)
        .orElse(null);
  }

  @Override
  public LocalDateTime convertToEntityAttribute(Timestamp timestamp) {
    return Optional.ofNullable(timestamp)
        .map(Timestamp::toLocalDateTime)
        .orElse(null);
  }
}

The above converter will be automatically applied whenever I try to save a LocalDateTime attribute. I could also explicitly mark the attributes that I wanted to convert, using the javax.persistence.Convert annotation -
@Convert(converter = LocalDateTimeConverter.class)
private LocalDateTime joiningDate;

The full code is available at Github.