A Quick Start with DJL (Running Deep Learning Models in Pure Java)



1. Introduction

Server-side applications are often developed in Java, while deep learning models are predominantly written in Python. As a result, Java services frequently call Python APIs to perform tasks like inference. This approach is inefficient and inelegant, and it becomes a real problem if you want to run inference on Android devices, where Java is the primary programming language.

This article introduces a powerful tool: Deep Java Library (DJL), an open-source Java library for deep learning. With DJL, you can run model inference and even train models directly in Java. Many tutorials and articles cover DJL, but they often overlook a crucial aspect: inference involves more than calling the model. Preprocessing and postprocessing typically require nontrivial tensor operations, and most resources don't explain how to do them in DJL.

To better meet practical needs, this article focuses solely on inference with DJL, without diving into model training. Specifically, the content is structured as follows:

  • Overview of DJL’s core features
  • Loading PyTorch models with DJL
  • Tensor operations in DJL
  • A practical example: using a PyTorch model in DJL for image classification

2. Core Features of DJL

2.1 What is DJL?

DJL (Deep Java Library) is an open-source deep learning framework designed for Java (and Android). It supports building and training deep learning models, performing tensor operations, and leveraging pre-trained models from popular frameworks like MXNet, PyTorch, and TensorFlow. DJL works with Java 1.8 or higher and provides GPU support.

2.2 Core API of DJL

Before diving into practical examples, let's first go over the core APIs of DJL. This will help you better understand the functionality of the code in the examples later.

2.2.1 Criteria

The Criteria class is used to define the configuration of a model, such as its file path, input/output types, and other properties.

Here is an example of initializing a model in DJL:

```java
Criteria<Input, Output> criteria = Criteria.builder()
        .setTypes(Input.class, Output.class) // defines input and output data types
        .optTranslator(new InputOutputTranslator()) // your own Translator<Input, Output> implementation
        .optModelPath(Paths.get("/var/models/my_resnet50")) // search for models in this path
        .optModelName("model/resnet50") // specify the model file prefix
        .build();

// The ZooModel's type parameters must match those of the Criteria
ZooModel<Input, Output> model = criteria.loadModel();
```

In the above code, the Criteria class defines the configuration of the model, with the following key components:

  • Criteria<I, O>: Specifies the input (I) and output (O) types of the model. These can either be custom classes or classes provided by DJL.
  • setTypes(I.class, O.class): This is a required method call. Since the input and output class objects cannot be directly inferred from the generic types I and O, you need to explicitly set them.
  • optTranslator: The input and output of the model are represented as tensors. This method specifies how to convert between your I and O classes and tensor types. The details of the Translator will be explained later.
  • optModelPath: Sets the directory (or file) in which DJL searches for the model.
  • optModelName: Sets the name or prefix of the model file.

Once the model configuration is defined, call the loadModel method to load the model as a ZooModel instance.

The model zoo is a core component of DJL for finding and loading models. The ZooModel you get back wraps the loaded model and provides operations such as creating predictors and saving the model.

2.2.2 Translator

In the previous section, we discussed that the input and output classes of a model can be customized. However, PyTorch models only accept Tensor types—they cannot directly handle your custom-defined classes. This is where the Translator interface comes in: it allows you to define how your custom input and output classes are converted to and from Tensor types.

Here’s an example implementation of a Translator:

```java
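private Translator<Input, Output> translator = new Translator<Input, Output>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
        // Convert `input` into one or more NDArrays, e.g. with ctx.getNDManager().create(...)
        return null;
    }

    @Override
    public Output processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
        // Read the model's output tensors from `ndList` and build an Output object
        return null;
    }
};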
```

The Translator interface contains two key methods:

  • processInput: Converts an input class object into tensors. Here, Input represents the custom input class, while NDList is a collection of tensors (since a model's forward function might take multiple tensor arguments). In DJL, a tensor is represented by the NDArray class (similar to ndarray in NumPy), which will be explained in detail later.
  • processOutput: Converts the model's output tensors back into your custom-defined class. Since a model can output multiple tensors, this method also handles an NDList.

Both methods take an important parameter, TranslatorContext, which stores the context for the translator. You can use it to access objects such as the Model or the NDManager (ctx.getNDManager()), or to store and retrieve custom data using the setAttachment and getAttachment methods.

In the official examples, the Translator is primarily used for image processing. However, its usage is not limited to images. The Input and Output types can be any Java classes.
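Because Input and Output can be any Java classes, here is a minimal sketch of a Translator for a hypothetical model that takes a feature vector (float[]) and returns a single score (Float). The names FeatureTranslator, toNDList, and fromNDList are made up for illustration. Keeping the conversions in static helpers lets you exercise them with just an NDManager, without loading a model. Note that processInput does not add a batch dimension: as far as I know, the default Batchifier (STACK) adds it when the predictor runs and removes it again before processOutput.

```java
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

// Hypothetical translator: float[] feature vector in, single Float score out.
class FeatureTranslator implements Translator<float[], Float> {

    @Override
    public NDList processInput(TranslatorContext ctx, float[] input) {
        // Allocate with the context's manager so the tensors are freed after the prediction
        return toNDList(ctx.getNDManager(), input);
    }

    @Override
    public Float processOutput(TranslatorContext ctx, NDList list) {
        return fromNDList(list);
    }

    // float[] -> NDList holding one 1-D FLOAT32 tensor of shape (n)
    static NDList toNDList(NDManager manager, float[] input) {
        return new NDList(manager.create(input));
    }

    // NDList -> first element of the first output tensor, as a Java float
    static Float fromNDList(NDList list) {
        return list.get(0).toType(DataType.FLOAT32, false).toFloatArray()[0];
    }
}
```

You would pass an instance to Criteria with setTypes(float[].class, Float.class) and optTranslator(new FeatureTranslator()).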

2.2.3 NDArray

In Python we have NumPy; in Java, DJL (Deep Java Library) provides NDArray. With this class you can perform most of the common tensor operations available in NumPy. This section introduces some frequently used ones.

Before diving into examples, let’s first look at a few key classes related to NDArray:

  • NDArray: Similar to numpy.ndarray, you can retrieve its shape using the getShape() method.
  • NDManager: The manager class for creating and managing NDArrays. Typically, you initialize a global instance of NDManager to manage all NDArrays.
  • NDIndex: Used for slicing tensors.
  • Shape: When creating an NDArray, you need to specify its shape using this class. When querying the shape of an NDArray, the result is also an instance of this class.

Now, let's explore some common tensor operations with examples. Only a few are listed here; if you have questions about a specific operation, feel free to ask in the comments and I'll add more examples.


Creating an NDArray (Tensor)

Here’s how to create a tensor with the shape (1, 2, 3, 4):

```java
NDManager ndManager = NDManager.newBaseManager();
// create(Shape) allocates an uninitialized FLOAT32 tensor of shape (1, 2, 3, 4)
NDArray ndArray = ndManager.create(new Shape(1, 2, 3, 4));
```

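A common pattern is to create one NDManager and reuse it. NDManager implements AutoCloseable, and closing it releases every NDArray it created. Inside a Translator, use the manager returned by ctx.getNDManager() instead.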

Creating a tensor with specified values:

```java
ndManager.create(new int[]{1, 2, 3, 4});
```
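A few more ways to create tensors, as a sketch (CreateDemo is a hypothetical helper class): NDManager.create also accepts a flat Java array together with a Shape, and zeros/ones return tensors filled with a constant, which is handy when you need initialized values.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

class CreateDemo {
    // A 2x3 FLOAT32 tensor built from a flat, row-major Java array
    static NDArray matrix(NDManager manager) {
        float[] data = {1f, 2f, 3f, 4f, 5f, 6f};
        return manager.create(data, new Shape(2, 3));
    }

    // A (1, 2, 3, 4) FLOAT32 tensor filled with zeros
    static NDArray zeros(NDManager manager) {
        return manager.zeros(new Shape(1, 2, 3, 4));
    }
}
```

Call these inside try (NDManager manager = NDManager.newBaseManager()) { ... } so the arrays are released when the block ends.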

Changing Data Types

Convert the tensor to float type:

```java
ndManager.create(new int[]{1, 2, 3, 4}).toType(DataType.FLOAT32, false);
```

Convert to a float array:

```java
ndManager.create(new int[]{1, 2, 3, 4}).toType(DataType.FLOAT32, false)
                                    .toFloatArray();
```

Important: Before calling toFloatArray() (or another typed accessor such as toIntArray()), make sure the NDArray has the matching data type. For example, a Java float is 32 bits (4 bytes), so the NDArray must be FLOAT32, not FLOAT64; otherwise DJL throws an exception.
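As a sketch (TypeDemo.asFloats is a hypothetical helper), casting before reading avoids the type-mismatch error: an INT32 tensor must be converted to FLOAT32 before toFloatArray() will accept it.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;

class TypeDemo {
    // Reads any numeric NDArray as a Java float[] by casting to FLOAT32 first;
    // calling toFloatArray() directly on, say, an INT32 array throws an exception
    static float[] asFloats(NDArray array) {
        return array.toType(DataType.FLOAT32, false).toFloatArray();
    }
}
```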


Arithmetic Operations

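You can perform element-wise addition, subtraction, multiplication, and division as follows. Each call returns a new NDArray and leaves ndArray unchanged; the in-place variants are addi, subi, muli, and divi: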

```java
ndArray.add(1);
ndArray.sub(1);
ndArray.mul(1);
ndArray.div(1);
```

Alternatively, you can use NDArrays.add, which is similar to np.add() in Python:

```java
NDArrays.add(ndArray, ndArray);
```
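Arithmetic between two NDArrays broadcasts like NumPy. As a sketch, per-channel normalization of a (3, H, W) image tensor, the (x - mean) / std computation behind torchvision's transforms.Normalize, can be written by giving mean and std the shape (3, 1, 1). NormalizeDemo is a hypothetical helper.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

class NormalizeDemo {
    // Per-channel (x - mean) / std for a CHW tensor.
    // mean/std get shape (C, 1, 1) so they broadcast over height and width.
    static NDArray normalize(NDManager manager, NDArray chw, float[] mean, float[] std) {
        NDArray m = manager.create(mean, new Shape(mean.length, 1, 1));
        NDArray s = manager.create(std, new Shape(std.length, 1, 1));
        return chw.sub(m).div(s); // sub/div return new arrays; chw is unchanged
    }
}
```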

Slicing

Here’s how to slice an NDArray:

```java
NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray = ndArray.get(new NDIndex("1:, :"));
```

This is equivalent to Python slicing [1:, :].

Limitation: As far as I can tell, DJL (at least the version used here) does not directly support indexing with integer arrays, i.e. NumPy-style advanced indexing such as nums[[1, 2, 3], [2, 3, 4]]. I haven't found a way to do this in DJL yet, so I implemented it manually using loops. If anyone knows how to achieve this, please share in the comments!
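The loop workaround can be kept small. This sketch (GatherDemo.pick is a hypothetical helper, not a DJL API) emulates NumPy's a[rows, cols] for a 2-D FLOAT32 tensor by copying the data to a Java array once and computing row-major offsets.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

class GatherDemo {
    // Emulates NumPy's matrix[rows, cols] for a 2-D FLOAT32 tensor
    static NDArray pick(NDManager manager, NDArray matrix, long[] rows, long[] cols) {
        if (rows.length != cols.length) {
            throw new IllegalArgumentException("rows and cols must have the same length");
        }
        float[] flat = matrix.toFloatArray(); // one copy, row-major order (requires FLOAT32)
        long numCols = matrix.getShape().get(1);
        float[] out = new float[rows.length];
        for (int i = 0; i < rows.length; i++) {
            out[i] = flat[(int) (rows[i] * numCols + cols[i])];
        }
        return manager.create(out);
    }
}
```

For large tensors this copies all the data to the JVM once, so it suits small postprocessing steps rather than hot loops.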


Assigning Values

You can assign values to slices of an NDArray:

```java
NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray.set(new NDIndex("1:, :"), 1);
```

This is equivalent to Python’s ndArray[1:, :] = 1.
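Unlike add, sub, mul, and div, which return new arrays, set modifies the NDArray in place. As a sketch (SetDemo is a hypothetical helper), this zeroes the first column of a 2-D tensor, like NumPy's a[:, 0] = 0.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDIndex;

class SetDemo {
    // Equivalent of NumPy's matrix[:, 0] = 0; mutates `matrix` and returns it for chaining
    static NDArray zeroFirstColumn(NDArray matrix) {
        matrix.set(new NDIndex(":, 0"), 0);
        return matrix;
    }
}
```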


Flipping Tensors

In Python, you can reverse arrays using slicing (e.g., [..., ::-1]). While this syntax isn’t directly available in Java, you can achieve the same effect using the flip method:

```java
NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray = ndArray.flip(-1);
```
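A common use of [..., ::-1] is swapping an HWC image between RGB and BGR channel order (for example, for a model trained on OpenCV's BGR images). A minimal sketch, with FlipDemo as a hypothetical helper:

```java
import ai.djl.ndarray.NDArray;

class FlipDemo {
    // Python equivalent: hwc[..., ::-1] (reverse the last, i.e. channel, axis)
    static NDArray rgbToBgr(NDArray hwc) {
        return hwc.flip(-1);
    }
}
```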

2.2.4 Predictor

After creating the model, you need to instantiate a Predictor and use it to perform predictions. Here's how you can do it:

```java
Predictor<Input, Output> predictor = model.newPredictor();
Output output = predictor.predict(input);
```
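ZooModel and Predictor both hold native resources and implement AutoCloseable, so in longer-running code it is worth closing them. A minimal sketch with try-with-resources; ModelRunner, buildCriteria, and predictOnce are hypothetical names, and the String-to-String types match the image example in the next section.

```java
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import java.nio.file.Path;

class ModelRunner {
    // Describes which model to load; build() does not load anything yet
    static Criteria<String, String> buildCriteria(Path modelPath, Translator<String, String> translator) {
        return Criteria.builder()
                .setTypes(String.class, String.class)
                .optModelPath(modelPath)
                .optTranslator(translator)
                .build();
    }

    // Loads the model, runs one prediction, and releases both resources
    static String predictOnce(Criteria<String, String> criteria, String input) throws Exception {
        try (ZooModel<String, String> model = criteria.loadModel();
             Predictor<String, String> predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }
}
```

If you run many predictions, keep the model and predictor open and reuse them instead of reloading per call.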

With this, we've covered the commonly used DJL APIs. Next, let’s dive into a practical example.

3. Practical Example: Using DJL with a PyTorch Model for Image Classification

In this example, we'll use a PyTorch pre-trained ResNet-18 model to complete an image classification task.

(1) First, Add the Necessary Dependencies

```xml
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.17.0</version>
    <scope>runtime</scope>
</dependency>

<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <!-- Windows build; use linux-x86_64 or osx-x86_64 on those platforms -->
    <classifier>win-x86_64</classifier>
    <scope>runtime</scope>
    <version>1.11.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>1.11.0-0.17.0</version>
    <scope>runtime</scope>
</dependency>

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.17.0</version>
</dependency>

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.17.0</version>
</dependency>

<dependency>
    <groupId>ai.djl.opencv</groupId>
    <artifactId>opencv</artifactId>
    <version>0.17.0</version>
</dependency>
```

(2) Export the ResNet-18 Model from PyTorch

```python
import torch
import torchvision

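# An instance of your model.
# (In torchvision >= 0.13, pretrained=True is deprecated in favor of
#  weights=torchvision.models.ResNet18_Weights.DEFAULT.)
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval mode
model.eval()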

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
```

(3) Copy the Exported Model to Your Project’s model Directory

The Criteria defined later loads it from model/traced_resnet_model.pt.

(4) Create a Translator

We'll define the input as a String (the image path) and the output as another String (the predicted class). Before passing the image to ResNet-18, we need to apply a series of preprocessing steps. In Python, they are typically written as:

```python
...
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
...
```

Here is the corresponding Java implementation:

```java
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;

Translator<String/*filename*/, String/*class*/> translator = new Translator<String, String>() {

    @Override
    public NDList processInput(TranslatorContext ctx, String input) throws Exception {
        // Load the image from the given file path as an HWC (height, width, channel) tensor
        Image image = ImageFactory.getInstance().fromFile(Paths.get(input));
        NDArray ndArray = image.toNDArray(ctx.getNDManager());

        // Preprocess the image before passing it to ResNet.
        // DJL's Resize, ToTensor, and Normalize transforms handle most steps;
        // the center crop uses NDArray slicing to illustrate the operations above.
        // Note: torchvision's Resize(256) scales the shorter side to 256 and keeps the
        // aspect ratio, whereas Resize(256, 256) stretches the image to 256x256.
        Resize resize = new Resize(256, 256);
        ndArray = resize.transform(ndArray);

        // Python equivalent: transforms.CenterCrop(224)
        // (256 - 224) / 2 = 16, so rows and columns 16:240 keep the central 224x224 region
        ndArray = ndArray.get(new NDIndex("16:240, 16:240, :"));

        // ToTensor: converts the shape from (224, 224, 3) to (3, 224, 224)
        // and scales pixel values from 0-255 to 0-1
        ndArray = new ToTensor().transform(ndArray);

        // Python equivalent: transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        Normalize normalize = new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f});
        ndArray = normalize.transform(ndArray);

        // ResNet expects a single tensor as input; the default batchifier (STACK)
        // adds the batch dimension, giving (1, 3, 224, 224)
        return new NDList(ndArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
        // ResNet returns a single tensor of 1,000 class scores (batch dimension already removed)
        int index = list.get(0).argMax().toType(DataType.INT32, false).getInt();
        // Return only the class index; mapping it to a name is done separately
        return String.valueOf(index);
    }
};
```
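The hard-coded slice 16:240 comes from (256 - 224) / 2 = 16 and 16 + 224 = 240. If your resize target or crop size differs, a small helper can compute the slice string passed to NDIndex. CenterCropIndex.forHwc is a hypothetical helper; it assumes an HWC tensor, which is the layout before ToTensor runs, and it rounds down when the size difference is odd (torchvision may round differently in that case).

```java
class CenterCropIndex {
    // Builds an NDIndex string that keeps the central cropH x cropW region
    // of an HWC tensor of size height x width (all channels kept).
    static String forHwc(int height, int width, int cropH, int cropW) {
        if (cropH > height || cropW > width) {
            throw new IllegalArgumentException("crop is larger than the image");
        }
        int top = (height - cropH) / 2;  // integer division rounds down for odd differences
        int left = (width - cropW) / 2;
        return top + ":" + (top + cropH) + ", " + left + ":" + (left + cropW) + ", :";
    }
}
```

Usage inside processInput: ndArray.get(new NDIndex(CenterCropIndex.forHwc(256, 256, 224, 224))).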
(5) Define a Criteria, Instantiate the Model, and Create a Predictor
```java
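import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

Criteria<String, String> criteria = Criteria.builder()
        .setTypes(String.class, String.class)
        .optModelPath(Paths.get("model/traced_resnet_model.pt"))
        .optOption("mapLocation", "true") // PyTorch-specific loading option
        .optTranslator(translator)
        .build();

ZooModel<String, String> model = criteria.loadModel();
Predictor<String, String> predictor = model.newPredictor();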
```
(6) Place an Image in Your Project’s test Directory for Testing



(7) Use the Predictor to Perform the Prediction
```java
System.out.println(predictor.predict("test/test.jpg"));
```

Output:

```
258
```

ResNet-18 classifies images into 1,000 ImageNet categories. For simplicity, we output only the class index. To find the class name corresponding to the index, refer to the official class list.

Index 258 corresponds to the class Samoyed (a dog breed), which matches the test image, so the prediction is correct.
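To print a class name instead of an index, you can load the class list into a List<String>. A minimal sketch, assuming you have saved the 1,000 ImageNet class names to a text file with one name per line in index order; the file name (for example model/synset.txt) and the ImageNetLabels class are hypothetical.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

class ImageNetLabels {
    private final List<String> names;

    // Reads one class name per line; line i (0-based) is the name for index i
    ImageNetLabels(Path labelFile) throws IOException {
        this.names = Files.readAllLines(labelFile, StandardCharsets.UTF_8);
    }

    // Converts the index string returned by the translator into a class name
    String nameOf(String indexText) {
        int index = Integer.parseInt(indexText.trim());
        if (index < 0 || index >= names.size()) {
            throw new IllegalArgumentException("index out of range: " + index);
        }
        return names.get(index);
    }
}
```

Usage: new ImageNetLabels(Paths.get("model/synset.txt")).nameOf(predictor.predict("test/test.jpg")). Alternatively, the lookup could move into processOutput so the translator returns the name directly.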

For more DJL examples and use cases, refer to the official DJL demos.


References

Deep Java Library official documentation: https://docs.djl.ai/

Dive into Deep Learning (DJL edition), NDArray chapter: https://d2l.djl.ai/chapter_preliminaries/ndarray.html

djl-demo: https://github.com/deepjavalibrary/djl-demo
