A Quick Start with DJL (Running Deep Learning Models in Pure Java)
1. Introduction
Server-side applications are often developed in Java, while deep learning models are predominantly written in Python. As a result, Java applications frequently have to call Python code to perform tasks such as inference. This approach is inefficient and awkward, and it becomes a real problem if you want to run inference on Android devices, where Java is a primary programming language.
This article introduces a powerful tool: Deep Java Library (DJL), an open-source Java library for deep learning. With DJL, you can run model inference and even train models directly in Java. Although many tutorials and articles cover DJL, they often overlook a crucial aspect: inference is more than calling the model. Preprocessing and postprocessing typically involve complex tensor operations, yet most resources do not explain how to perform them.
To better meet practical needs, this article focuses solely on inference with DJL, without diving into model training. Specifically, the content is structured as follows:
- Overview of DJL’s core features
- Loading PyTorch models with DJL
- Tensor operations in DJL
- A practical example: using a PyTorch model in DJL for image classification
2. Core Features of DJL
2.1 What is DJL?
DJL (Deep Java Library) is an open-source deep learning framework designed for Java (and Android). It supports building and training deep learning models, performing tensor operations, and leveraging pre-trained models from popular frameworks like MXNet, PyTorch, and TensorFlow. DJL works with Java 1.8 or higher and provides GPU support.
2.2 Core API of DJL
Before diving into practical examples, let's first go over the core APIs of DJL. This will help you better understand the functionality of the code in the examples later.
2.2.1 Criteria
The Criteria class is used to define the configuration of a model, such as its file path, input/output types, and other properties.
Here is an example of initializing a model in DJL:
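```java
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

// Input, Output and InputOutputTranslator are your own classes
Criteria<Input, Output> criteria = Criteria.builder()
        .setTypes(Input.class, Output.class) // defines the input and output data types
        .optTranslator(new InputOutputTranslator()) // converts Input/Output to and from tensors
        .optModelPath(Paths.get("/var/models/my_resnet50")) // search for models in the specified path
        .optModelName("model/resnet50") // specify the model file prefix
        .build();

ZooModel<Input, Output> model = criteria.loadModel();
```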
In the above code, the `Criteria` class defines the configuration of the model. Its key components are:
- `Criteria<I, O>`: Specifies the input (`I`) and output (`O`) types of the model. These can be either custom classes or classes provided by DJL.
- `setTypes(I.class, O.class)`: A required call. Because Java erases generic type arguments at runtime, the `Class` objects for `I` and `O` cannot be inferred from the generic types, so you need to set them explicitly.
- `optTranslator`: The model's input and output are tensors. This method specifies how to convert between your `I` and `O` classes and tensors. The `Translator` is explained in detail later.
- `optModelPath`: The location in which DJL searches for the model.
- `optModelName`: Sets the name or prefix of the model file.
Once the model configuration is defined, you can call the `loadModel()` method to load the model, which returns a `ZooModel` instance.
`ZooModel` is a core class in DJL. It holds the loaded model together with its translator and provides functionality such as creating predictors and saving the model.
2.2.2 Translator
In the previous section, we discussed that the input and output classes of a model can be customized. However, PyTorch models accept only tensors; they cannot directly handle your custom-defined classes. This is where the `Translator` interface comes in: it defines how your custom input and output classes are converted to and from tensors.
Here is a skeleton implementation of a `Translator`:
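```java
import ai.djl.ndarray.NDList;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;

private Translator<Input, Output> translator = new Translator<Input, Output>() {

    @Override
    public NDList processInput(TranslatorContext ctx, Input input) throws Exception {
        // Convert your Input object into one or more tensors
        return null;
    }

    @Override
    public Output processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
        // Convert the model's output tensors back into your Output object
        return null;
    }
};
```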
The `Translator` interface contains two key methods:
- `processInput`: Converts an input object into tensors. Here, `Input` represents the custom input class, while `NDList` is a collection of tensors (since a model's forward function might take multiple tensor arguments). In DJL, a tensor is represented by the `NDArray` class (similar to `ndarray` in NumPy), which is explained in detail later.
- `processOutput`: Converts the model's output tensors back into your custom-defined class. Since a model can output multiple tensors, this method also receives an `NDList`.
Both methods take an important parameter, `TranslatorContext`, which stores the context for the translator. You can use it to access objects such as the `Model` (`getModel()`) or an `NDManager` (`getNDManager()`), or to store and retrieve custom data with the `setAttachment` and `getAttachment` methods.
In the official examples, the `Translator` is primarily used for image processing. However, its usage is not limited to images: the `Input` and `Output` types can be any Java classes.
2.2.3 NDArray
In Python, we have `numpy`; in Java, we have the `NDArray` class provided by DJL. With this class, we can perform nearly all the tensor operations available in NumPy. This section introduces some commonly used ones.
Before diving into examples, let's first look at a few key classes related to `NDArray`:
- `NDArray`: Similar to `numpy.ndarray`; you can retrieve its shape with the `getShape()` method.
- `NDManager`: The class that creates and manages `NDArray`s. Typically, you initialize a global `NDManager` instance to manage all `NDArray`s.
- `NDIndex`: Used for slicing tensors.
- `Shape`: When creating an `NDArray`, you specify its shape with this class. When querying the shape of an `NDArray`, the result is also an instance of this class.
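A small sketch tying these classes together (the class name `NDArrayBasics` is hypothetical): a `Shape` describes a new array, an `NDIndex` slices it, and `getShape()` reports the result. The `try`-with-resources block closes the `NDManager`, which frees the native memory of every array it created:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;

public class NDArrayBasics {

    // Creates a (2, 3, 4) array, slices it and returns the shape of the slice
    public static Shape slicedShape() {
        // The NDManager owns the native memory of every NDArray it creates;
        // closing it (here via try-with-resources) frees all of them
        try (NDManager manager = NDManager.newBaseManager()) {
            // Shape describes the dimensions of a new array
            NDArray x = manager.zeros(new Shape(2, 3, 4));
            // NDIndex describes a slice: x[1, :, :2] in NumPy notation
            NDArray y = x.get(new NDIndex("1, :, :2"));
            // Shape is a plain Java object, so it stays valid after the manager is closed
            return y.getShape();
        }
    }

    public static void main(String[] args) {
        Shape shape = slicedShape();
        System.out.println(shape.dimension() + " dimensions: " + shape.get(0) + " x " + shape.get(1));
    }
}
```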
Now, let's explore some common tensor operations with examples. Only a few are listed here; if you have questions about specific operations, feel free to ask in the comments, and I'll add more examples.
Creating an NDArray (Tensor)
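Here's how to create a tensor with the shape (1, 2, 3, 4):
```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

NDManager ndManager = NDManager.newBaseManager();
// create(Shape) allocates a FLOAT32 array without initializing its values;
// use ndManager.zeros(...) or ndManager.ones(...) if you need defined contents
NDArray ndArray = ndManager.create(new Shape(1, 2, 3, 4));
```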
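Typically, you create a single, long-lived `NDManager` instance. Keep in mind that the native memory behind an `NDArray` is released only when the array itself or the manager that owns it is closed.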
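For temporary arrays, a common pattern is a short-lived sub-manager created with `newSubManager()`. This is a minimal sketch under that assumption (the class and method names are made up): intermediate arrays are freed when the sub-manager closes, and only the result is moved to the long-lived manager with `attach`:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class SubManagerDemo {

    // Sums 0 + 1 + ... + (n - 1) while freeing the intermediate array early
    public static float sumOfRange(int n) {
        try (NDManager manager = NDManager.newBaseManager()) { // long-lived manager
            NDArray total;
            try (NDManager scratch = manager.newSubManager()) {
                NDArray values = scratch.arange((float) n); // temporary, owned by scratch
                total = values.sum();                        // result is also owned by scratch
                total.attach(manager);                       // move the result to the outer manager
            } // scratch is closed here, so "values" is freed
            return total.getFloat(); // still valid: total now belongs to manager
        }
    }

    public static void main(String[] args) {
        System.out.println(sumOfRange(10));
    }
}
```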
Creating a tensor with specified values:
```java
NDArray ints = ndManager.create(new int[]{1, 2, 3, 4}); // INT32 array
```
Changing Data Types:
Convert the tensor to the `float` type:
```java
import ai.djl.ndarray.types.DataType;

NDArray floats = ndManager.create(new int[]{1, 2, 3, 4}).toType(DataType.FLOAT32, false);
```
Convert to a `float` array:
```java
float[] values = ndManager.create(new int[]{1, 2, 3, 4})
        .toType(DataType.FLOAT32, false)
        .toFloatArray();
```
Important: Before calling `toFloatArray()` (or another typed method such as `toIntArray()`), make sure the `NDArray` has the matching data type. For example, a Java `float` is stored in 32 bits (4 bytes), so the `NDArray` must be of type `FLOAT32`, not `FLOAT64`. Otherwise, an exception is thrown.
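The type check above can be seen in a short sketch (the class name `DataTypeDemo` is hypothetical): a `double[]` produces a `FLOAT64` array, so it has to be converted to `FLOAT32` before `toFloatArray()` is called:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;

public class DataTypeDemo {

    // Converts Java doubles into a Java float[] by way of an NDArray
    public static float[] toFloats(double[] values) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray doubles = manager.create(values); // FLOAT64, because the input is double[]
            // Calling doubles.toFloatArray() here would throw: the data type is FLOAT64, not FLOAT32
            NDArray floats = doubles.toType(DataType.FLOAT32, false);
            return floats.toFloatArray(); // copies the data into a new Java float[]
        }
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(toFloats(new double[]{1.5, 2.5})));
    }
}
```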
Arithmetic Operations:
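You can perform addition, subtraction, multiplication, and division as follows. Each call returns a new `NDArray` and leaves `ndArray` itself unchanged:
```java
NDArray sum = ndArray.add(1);
NDArray difference = ndArray.sub(1);
NDArray product = ndArray.mul(1);
NDArray quotient = ndArray.div(1);
```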
Alternatively, you can use `NDArrays.add`, which is similar to `np.add()` in Python:
```java
import ai.djl.ndarray.NDArrays;

NDArray doubled = NDArrays.add(ndArray, ndArray);
```
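One detail worth knowing: `add`, `sub`, `mul` and `div` return a new `NDArray`, while the variants with an `i` suffix (`addi`, `subi`, `muli`, `divi`) modify the array in place. A minimal sketch contrasting the two (class and method names are hypothetical):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class InPlaceDemo {

    // add() returns a new array; if the result is not assigned, the original is unchanged
    public static int[] addWithoutAssigning(int[] data) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray a = manager.create(data);
            a.add(1); // result discarded
            return a.toIntArray();
        }
    }

    // addi() ("i" for in-place) updates the array itself
    public static int[] addInPlace(int[] data) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray a = manager.create(data);
            a.addi(1);
            return a.toIntArray();
        }
    }
}
```

The in-place variants avoid allocating a new array, which can matter inside loops over large tensors.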
Slicing:
Here's how to slice an `NDArray`:
```java
import ai.djl.ndarray.index.NDIndex;

NDArray ndArray = ndManager.arange(24).reshape(3, 8); // shape (3, 8)
ndArray = ndArray.get(new NDIndex("1:, :"));          // shape (2, 8)
```
This is equivalent to the Python slice `[1:, :]`.
Limitation: It seems that DJL (at least version 0.17.0, which this article uses) does not directly support indexing with integer arrays (e.g., `nums[[1, 2, 3], [2, 3, 4]]` in Python). I haven't found a way to do this in DJL yet, so I had to implement it manually with loops. If anyone knows how to achieve this, please share in the comments!
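Here is one way to write the loop-based workaround, as a sketch rather than a DJL feature: copy the 2-D `FLOAT32` array into a Java `float[]` once, pick the requested `(row, column)` pairs in row-major order, and wrap the picked values in a new `NDArray`. The helper name `gather2d` is hypothetical; copying once avoids one native call per element, but it still copies the whole array, so it is best suited to small tensors:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class FancyIndexDemo {

    // Hypothetical helper: emulates NumPy's src[rows, cols] for a 2-D FLOAT32 array.
    // Indices are assumed to be non-negative and within bounds.
    public static NDArray gather2d(NDArray src, long[] rows, long[] cols) {
        if (src.getShape().dimension() != 2 || rows.length != cols.length) {
            throw new IllegalArgumentException("Expected a 2-D array and equally long index arrays");
        }
        long numCols = src.getShape().get(1);
        float[] flat = src.toFloatArray(); // one copy of the data, in row-major order
        float[] picked = new float[rows.length];
        for (int k = 0; k < rows.length; k++) {
            picked[k] = flat[(int) (rows[k] * numCols + cols[k])];
        }
        return src.getManager().create(picked);
    }

    // Equivalent to np.arange(24, dtype=np.float32).reshape(3, 8)[[1, 2], [2, 3]]
    public static float[] demo() {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray src = manager.arange(24f).reshape(3, 8);
            return gather2d(src, new long[]{1, 2}, new long[]{2, 3}).toFloatArray();
        }
    }
}
```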
Assigning Values:
You can assign values to slices of an `NDArray`:
```java
NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray.set(new NDIndex("1:, :"), 1); // modifies ndArray in place
```
This is equivalent to Python's `ndArray[1:, :] = 1`.
Flipping Tensors:
In Python, you can reverse arrays using slicing (e.g., `[..., ::-1]`). This syntax isn't available in Java, but you can achieve the same effect with the `flip` method:
```java
NDArray ndArray = ndManager.arange(24).reshape(3, 8);
ndArray = ndArray.flip(-1); // reverse along the last axis, like [..., ::-1]
```
2.2.4 Predictor
After creating the model, you need to instantiate a `Predictor` and use it to perform predictions:
```java
import ai.djl.inference.Predictor;

Predictor<Input, Output> predictor = model.newPredictor();
Output output = predictor.predict(input);
```
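Both `ZooModel` and `Predictor` hold native resources and implement `AutoCloseable`, so it is good practice to close them. A minimal sketch using `try`-with-resources (the helper `predictOnce` is a hypothetical name; in a long-running service you would normally load the model once and reuse it instead of reloading it for every call):

```java
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.io.IOException;

public final class PredictorLifecycle {

    // Loads the model described by `criteria`, runs one prediction,
    // then closes the predictor and the model (in that order)
    public static <I, O> O predictOnce(Criteria<I, O> criteria, I input)
            throws IOException, ModelNotFoundException, MalformedModelException, TranslateException {
        try (ZooModel<I, O> model = criteria.loadModel();
             Predictor<I, O> predictor = model.newPredictor()) {
            return predictor.predict(input);
        }
    }
}
```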
With this, we've covered the commonly used DJL APIs. Next, let’s dive into a practical example.
3. Practical Example: Using DJL with a PyTorch Model for Image Classification
In this example, we'll use a PyTorch pre-trained ResNet-18 model to complete an image classification task.
(1) Add the Necessary Dependencies
```xml
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-engine</artifactId>
    <version>0.17.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-native-cpu</artifactId>
    <!-- win-x86_64 targets Windows; use linux-x86_64 or osx-x86_64 on Linux or macOS -->
    <classifier>win-x86_64</classifier>
    <scope>runtime</scope>
    <version>1.11.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.pytorch</groupId>
    <artifactId>pytorch-jni</artifactId>
    <version>1.11.0-0.17.0</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>0.17.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.opencv</groupId>
    <artifactId>opencv</artifactId>
    <version>0.17.0</version>
</dependency>
```
(2) Export the ResNet-18 Model from PyTorch
```python
import torch
import torchvision

# An instance of your model.
# (On torchvision 0.13 or later, pretrained=True is deprecated in favor of
#  weights=torchvision.models.ResNet18_Weights.DEFAULT.)
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval mode
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
```
(3) Copy the Exported Model to Your Project’s model Directory
(4) Create a Translator: We'll define the input as a `String` (representing the image path) and the output as another `String` (representing the predicted class). Before passing the image to the ResNet-18 model, we need to perform a series of preprocessing steps. Here's the implementation in Python:
```python
from torchvision import transforms
...
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
...
```
Here is the implementation with Java:
```java
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.nio.file.Paths;

Translator<String /* filename */, String /* class index */> translator = new Translator<String, String>() {

    @Override
    public NDList processInput(TranslatorContext ctx, String input) throws Exception {
        // Load the image from the specified file path
        Image image = ImageFactory.getInstance().fromFile(Paths.get(input));
        // Shape (H, W, 3), pixel values 0-255
        NDArray ndArray = image.toNDArray(ctx.getNDManager());

        // Before passing the image to ResNet, we need to preprocess it.
        // While the official examples chain transforms, we apply NDArray operations here
        // for better alignment with the explanations above.

        // Resize the image to 256x256 pixels.
        // Note: transforms.Resize(256) in Python scales the shorter side to 256 and keeps
        // the aspect ratio, whereas Resize(256, 256) stretches the image to a square,
        // so results can differ slightly for non-square images.
        ndArray = new Resize(256, 256).transform(ndArray);

        // Python equivalent: transforms.CenterCrop(224)
        // NDArray has no CenterCrop method, so we slice instead: (256 - 224) / 2 = 16
        ndArray = ndArray.get(new NDIndex("16:240, 16:240, :"));

        // ToTensor: converts the shape from (224, 224, 3) to (3, 224, 224)
        // and scales pixel values from 0-255 to 0-1
        ndArray = new ToTensor().transform(ndArray);

        // Python equivalent: transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        // Normalize the image using the specified mean and standard deviation
        Normalize normalize = new Normalize(
                new float[]{0.485f, 0.456f, 0.406f},
                new float[]{0.229f, 0.224f, 0.225f});
        ndArray = normalize.transform(ndArray);

        // ResNet expects a single tensor as input
        // (the Predictor adds the batch dimension for us)
        return new NDList(ndArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) throws Exception {
        // ResNet returns a single tensor of 1,000 scores; take the index of the largest one
        int index = list.get(0).argMax().toType(DataType.INT32, false).getInt();
        // ResNet can classify 1,000 categories; here, we return only the index
        return String.valueOf(index);
    }
};
```
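If you want to follow the Python preprocessing more closely, the resize and crop sizes can be computed from the image's actual height and width. This is a plain-Java sketch with hypothetical names: `resizeShorterSide` mimics the behavior of `transforms.Resize(256)` described in the comments above (shorter side scaled to the target, longer side scaled proportionally and truncated), and `centerCropIndex` builds the slice string for `NDIndex`. It uses integer division for the crop offset, so for odd margins it may be one pixel off compared with torchvision's rounding:

```java
public final class PreprocessGeometry {

    // Shorter side becomes `size`, longer side is scaled proportionally and truncated.
    // Returns {newHeight, newWidth}.
    public static int[] resizeShorterSide(int height, int width, int size) {
        if (height <= width) {
            return new int[]{size, (int) ((long) size * width / height)};
        }
        return new int[]{(int) ((long) size * height / width), size};
    }

    // Slice string for a centered size x size crop of an (H, W, C) image,
    // e.g. "16:240, 16:240, :" for a 256x256 image and size 224
    public static String centerCropIndex(int height, int width, int size) {
        int top = (height - size) / 2;
        int left = (width - size) / 2;
        return String.format("%d:%d, %d:%d, :", top, top + size, left, left + size);
    }

    public static void main(String[] args) {
        int[] resized = resizeShorterSide(480, 640, 256);
        System.out.println(resized[0] + " x " + resized[1]);
        System.out.println(centerCropIndex(resized[0], resized[1], 224));
    }
}
```

Inside `processInput`, you could read the height and width from `ndArray.getShape()` (an image `NDArray` is laid out as (H, W, C)), resize with `new Resize(newWidth, newHeight)` (DJL's `Resize` takes the width first), and then slice with `new NDIndex(centerCropIndex(newHeight, newWidth, 224))`.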
(5) Define a `Criteria`, instantiate the model, and create a `Predictor`:
```java
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.nio.file.Paths;

Criteria<String, String> criteria = Criteria.builder()
        .setTypes(String.class, String.class)
        .optModelPath(Paths.get("model/traced_resnet_model.pt"))
        .optOption("mapLocation", "true") // map the saved weights onto the device DJL is using
        .optTranslator(translator)
        .build();

ZooModel<String, String> model = criteria.loadModel();
Predictor<String, String> predictor = model.newPredictor();
```
(6) Place an image in your project's `test` directory for testing.
(7) Use the `Predictor` to perform the prediction:
```java
System.out.println(predictor.predict("test/test.jpg"));
```
Output:
```
258
```
ResNet-18 classifies images into 1,000 ImageNet categories. For simplicity, we output only the class index; to find the class name corresponding to an index, refer to the official class list. Index 258 corresponds to the class Samoyed (a dog breed), which confirms that the prediction is correct.
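To print a class name instead of an index, you can look the index up in a label file. This sketch assumes a hypothetical text file (for example `synset.txt`) with one ImageNet class name per line, ordered so that line 0 matches output index 0; you would need to obtain such a file yourself:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public final class LabelLookup {

    // Reads one label per line; line i is assumed to describe output index i
    public static List<String> loadLabels(Path labelsFile) throws IOException {
        return Files.readAllLines(labelsFile);
    }

    // Maps a predicted class index (as returned by the translator) to its label
    public static String labelFor(List<String> labels, int index) {
        if (index < 0 || index >= labels.size()) {
            throw new IllegalArgumentException("No label for index " + index);
        }
        return labels.get(index).trim();
    }
}
```

For example, `LabelLookup.labelFor(LabelLookup.loadLabels(Paths.get("model/synset.txt")), Integer.parseInt(predictor.predict("test/test.jpg")))` would return the class name, provided a label file exists at that (hypothetical) path.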
For more DJL examples and use cases, refer to the official DJL demos.
References
Deep Java Library official documentation: https://docs.djl.ai/
Dive into Deep Learning (DJL edition), NDArray chapter: https://d2l.djl.ai/chapter_preliminaries/ndarray.html
djl-demo: https://github.com/deepjavalibrary/djl-demo