Deep Learning#


Because of the static multiple dispatch paradigm laid out in Multiple Dispatch, we first need to include the primitive operations for the device(s) we intend to use, so that the algorithms (and data structures) we later include for deep learning can use them.

[1]:
#include <rl_tools/operations/cpu.h>
[2]:
#include <rl_tools/nn/layers/dense/operations_cpu.h>

We set up the environment as described in Containers:

[3]:
namespace rlt = rl_tools;
using DEVICE = rlt::devices::DefaultCPU;
using T = float;
using TI = typename DEVICE::index_t;
DEVICE device;
TI seed = 1;
auto rng = rlt::random::default_engine(DEVICE::SPEC::RANDOM(), seed);

As justified by our analysis of the reinforcement learning for continuous control landscape (in the paper), RLtools initially only supports fully connected neural networks. But we are planning to add more architectures (especially recurrent neural networks) in the future.

We can instantiate a simple layer by first defining its hyperparameters (which are compile-time constexpr values and types):

[4]:
constexpr TI INPUT_DIM = 5;
constexpr TI OUTPUT_DIM = 5;
constexpr auto ACTIVATION_FUNCTION = rlt::nn::activation_functions::RELU;

These hyperparameters and other options are combined into a specification type, which makes them easier to pass around and saves us from writing out all hyperparameters and options as template parameters whenever a function takes the data structure as an argument:

[5]:
using LAYER_SPEC = rlt::nn::layers::dense::Specification<T, TI, INPUT_DIM, OUTPUT_DIM, ACTIVATION_FUNCTION>;

The data structure of a layer depends not only on its (previously defined) structure but also on the required capabilities. E.g., if we want to do backward passes, the layer needs to store intermediate activations during the forward pass. Furthermore, the buffers for these intermediate values also depend on the batch size. We decouple the capabilities from the structure specification so that we can easily change the capability of a layer or model (e.g. for checkpointing, where we only want to save the parts required for inference, or for changing the batch size).

[6]:
using CAPABILITY = rlt::nn::layer_capability::Forward;

Using this specification we can declare an actual layer:

[7]:
rlt::nn::layers::dense::Layer<CAPABILITY, LAYER_SPEC> layer;

A fully connected neural network consists of layers, each implementing:

\[y = f(Wx + b)\]

where \(x\) is the input (external or from the previous layer), \(W\) and \(b\) are the weight matrix and biases respectively, and \(f\) is an element-wise non-linear function. Hence the data structure of a layer should contain at least \(W\) and \(b\). Because these parameters are containers they need to be allocated:

[8]:
rlt::malloc(device, layer);

Now that the memory is allocated we need to initialize it (because it may contain arbitrary values). We use the standard Kaiming initialization scheme:

[9]:
rlt::init_weights(device, layer, rng);

We can print \(W\) and \(b\):

[10]:
rlt::print(device, layer.weights.parameters)
   -0.211912     0.010027     0.373245    -0.388598    -0.055528
   -0.127251    -0.126478     0.330389     0.238816    -0.412481
   -0.385369    -0.351579    -0.394084    -0.141052    -0.433443
   -0.327643     0.299608    -0.113104    -0.288047     0.322775
   -0.399042    -0.282702    -0.171875     0.296949    -0.087313
[11]:
rlt::print(device, layer.biases.parameters)
   -0.447200    -0.363058     0.144835     0.238661     0.172029
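
As a plausibility check (an observation on this particular run, not a specification of the exact scheme): all printed values lie within \(\pm 1/\sqrt{\text{INPUT\_DIM}} = \pm 1/\sqrt{5} \approx \pm 0.447\), consistent with a Kaiming-style uniform initialization with bounds \(\pm 1/\sqrt{\text{fan\_in}}\).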

Now that the layer is initialized we can run inference using a random input. We first declare and allocate input and output matrices and then randomly initialize the input:

[12]:
constexpr TI BATCH_SIZE = 1;
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM>> input;
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM>> output;
rlt::malloc(device, input);
rlt::malloc(device, output);
rlt::randn(device, input, rng);
rlt::print(device, input);
    0.100807    -0.911862     2.108090     0.094763     0.537630

Now we can evaluate the output of the layer:

[13]:
decltype(layer)::Buffer<BATCH_SIZE> buffer;
rlt::evaluate(device, layer, input, output, buffer, rng);
rlt::print(device, output);
    0.242450     0.236803     0.000000     0.000000     0.008457
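
To connect this output back to the formula \(y = f(Wx + b)\), the following sketch (an illustration, assuming only the container accessors already used in this tutorial and a row-major \(W\) of shape OUTPUT_DIM x INPUT_DIM) recomputes the output by hand and should print the same five values:

#include <iostream>
// Recompute y = f(Wx + b) element-wise; f is RELU here.
for(TI output_i = 0; output_i < OUTPUT_DIM; output_i++){
    T pre_activation = rlt::get(layer.biases.parameters, 0, output_i);
    for(TI input_i = 0; input_i < INPUT_DIM; input_i++){
        pre_activation += rlt::get(layer.weights.parameters, output_i, input_i) * rlt::get(input, 0, input_i);
    }
    T activated = pre_activation > 0 ? pre_activation : 0; // RELU
    std::cout << activated << " ";
}
std::cout << std::endl;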

Now we revisit the capabilities mentioned earlier. For inference, storing \(W\) and \(b\) is sufficient, but for training we at least also need to store the gradient of the loss \(L\) wrt. \(W\) and \(b\): \(\frac{\mathrm{d}L}{\mathrm{d}W}\) and \(\frac{\mathrm{d}L}{\mathrm{d}b}\). Because, depending on the optimizer type, we might need to store more information per parameter (like the first and second-order moments in the case of Adam), we abstract the storage for the weights and biases using a parameter type (defined under the rl_tools::nn::parameters namespace) that can e.g. be Plain, Gradient, Adam or any other type added by the user. For this illustration we are using Gradient:

[14]:
using PARAMETER_TYPE = rlt::nn::parameters::Gradient;
using CAPABILITY_2 = rlt::nn::layer_capability::Gradient<PARAMETER_TYPE, BATCH_SIZE>;
using LAYER_2_SPEC = rlt::nn::layers::dense::Specification<T, TI, INPUT_DIM, OUTPUT_DIM, ACTIVATION_FUNCTION>;
rlt::nn::layers::dense::Layer<CAPABILITY_2, LAYER_2_SPEC> layer_2;
rlt::malloc(device, layer_2);
rlt::copy(device, device, layer, layer_2);
rlt::zero_gradient(device, layer_2);

Note that by using the rl_tools::nn::layer_capability::Gradient capability, the rl_tools::nn::layers::dense::Layer data structure contains the necessary buffers (e.g. for intermediate activations) to support the backpropagation algorithm. Additionally, similar to PyTorch, we set the gradient to zero because it is accumulated over subsequent backward passes.

Now we can backpropagate the derivative of the loss wrt. the output to calculate the derivative of the loss wrt. the input. Hence the derivative of the loss wrt. the output, d_output, is actually an input to the rl_tools::backward operator. The operator also accumulates the derivative of the loss wrt. the weights and biases in the layer. We first allocate containers for d_input and d_output and randomly set d_output (standing in for the gradient of the loss wrt. the input of a subsequent layer):

[15]:
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM>> d_output;
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM>> d_input;
rlt::malloc(device, d_input);
rlt::malloc(device, d_output);
rlt::randn(device, d_output, rng);

Now we execute the backpropagation and display the gradient of the loss wrt. the input:

[16]:
rlt::forward(device, layer_2, input, buffer, rng);
std::cout << "Output (should be identical to layer_1): " << std::endl;
rlt::print(device, layer_2.output);
rlt::backward_full(device, layer_2, input, d_output, d_input, buffer);
std::cout << "Derivative with respect to the input: " << std::endl;
rlt::print(device, d_input);
Output (should be identical to layer):
    0.242450     0.236803     0.000000     0.000000     0.008457
Derivative with respect to the input:
   -0.214641    -0.065061    -0.232326    -0.153192     0.257247

This also accumulates the gradient in the weights and biases:

[17]:
rlt::print(device, layer_2.weights.gradient);
    0.036480    -0.329980     0.762866     0.034292     0.194555
   -0.080598     0.729055    -1.685465    -0.075765    -0.429848
    0.000000     0.000000     0.000000     0.000000     0.000000
    0.000000     0.000000     0.000000     0.000000     0.000000
    0.060552    -0.547734     1.266280     0.056922     0.322942
[18]:
rlt::print(device, layer_2.biases.gradient);
    0.361875    -0.799522     0.000000     0.000000     0.600676
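
As a sanity check (an illustration, not part of the original example): with ReLU and a batch size of one, a unit that was inactive in the forward pass blocks the gradient, so \(\frac{\mathrm{d}L}{\mathrm{d}b_i} = \frac{\mathrm{d}L}{\mathrm{d}y_i}\) if \(y_i > 0\) and \(0\) otherwise. This explains the zero rows in the gradients above and can be verified against d_output:

// Compare the accumulated bias gradient against the ReLU-masked d_output.
for(TI i = 0; i < OUTPUT_DIM; i++){
    bool active = rlt::get(layer_2.output, 0, i) > 0;
    T expected = active ? rlt::get(d_output, 0, i) : 0;
    std::cout << "d_b[" << i << "] = " << rlt::get(layer_2.biases.gradient, 0, i) << " (expected: " << expected << ")" << std::endl;
}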
[19]:
rlt::free(device, layer);
rlt::free(device, layer_2);
rlt::free(device, input);
rlt::free(device, output);
rlt::free(device, d_input);
rlt::free(device, d_output);

Multilayer Perceptron (MLP)#

Until now we showed the behavior of a single fully-connected layer. RLtools contains a Multilayer Perceptron (MLP) that conveniently integrates an arbitrary number of layers into a single data structure with algorithms to perform forward passes and backpropagation across the whole model. The MLP is located under the namespace rl_tools::nn_models and we include it as well as the operations of the Adam optimizer:

[20]:
#include <rl_tools/nn/optimizers/adam/instance/operations_generic.h>
#include <rl_tools/nn_models/mlp/operations_generic.h>
#include <rl_tools/nn/optimizers/adam/operations_generic.h>

Note that the operations of the (Adam) optimizer are split into instance/operations_generic.h and operations_generic.h. The former contains operations that use and modify the values associated with a particular set of parameters (e.g. the weights or biases of a particular layer). An example is the rl_tools::update operation, which applies the optimizer and, in the case of Adam, updates the first and second-order moments based on the gradient and then applies the update rule to the parameters. Since these parameters can reside in an arbitrary structure (like an MLP or Sequential nn_model), the rl_tools::update function is called by an operation that knows about this structure (e.g. the rl_tools::update of the rl_tools::nn_models::mlp in turn calls the update operations of its layers). These instance-associated operations carry out the bulk of the gradient descent step but are necessarily myopic because they don't know about the higher-level structure. Because optimizers like Adam not only have instance-associated state (like the first and second-order moments of the gradient) but also global state (like the step counter), we also include the global .../adam/operations_generic.h.

The order of the includes is dictated by the underlying usage: we call rl_tools::update on the optimizer, providing the model. The optimizer then invokes the update of the model, which invokes the update of its submodels/layers, which in turn call the update on the parameter instances. For each step in this chain, the next operation should already be included, hence the order used in this example.
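
To make this dispatch chain concrete, here is a minimal standalone analog in plain C++ (not RLtools code; the names are made up for illustration): the structure-level update only knows how to forward to its parts, while the instance-level update applies the actual rule:

// Standalone analog of the layered update dispatch (not RLtools code).
struct ParameterInstanceAnalog { float value = 0, gradient = 0; };

// Instance-level: myopic, only sees one parameter set
// (cf. instance/operations_generic.h).
void update_analog(ParameterInstanceAnalog& p){
    p.value -= 0.001f * p.gradient; // stand-in for the Adam moment updates + step
}

struct TinyLayerAnalog { ParameterInstanceAnalog weights, biases; };

// Structure-level: knows the layout and forwards to its instances
// (cf. the MLP's update calling into its layers).
void update_analog(TinyLayerAnalog& layer){
    update_analog(layer.weights);
    update_analog(layer.biases);
}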

Next we define the hyperparameters:

[21]:
constexpr TI INPUT_DIM_MLP = 5;
constexpr TI OUTPUT_DIM_MLP = 1;
constexpr TI NUM_LAYERS = 3;
constexpr TI HIDDEN_DIM = 10;
constexpr auto ACTIVATION_FUNCTION_MLP = rlt::nn::activation_functions::RELU;
constexpr auto OUTPUT_ACTIVATION_FUNCTION_MLP = rlt::nn::activation_functions::IDENTITY;

Note that the MLP supports architectures with arbitrary depth, but each hidden layer has to have the same dimensionality. This is because the layers are stored in an array and hence all need to have the same type. If we allowed for different hidden dimensions, we would have to give up support for arbitrary depths.
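
The following standalone snippet (a hypothetical layer template purely for illustration; the actual internal layout of the RLtools MLP may differ) shows why: the dimensions are template parameters and hence part of the type, and a C++ array is homogeneous:

// Hypothetical layer template for illustration (not the RLtools definition).
template <typename T_, int IN, int OUT>
struct DenseLayerIllustration { T_ W[OUT][IN]; T_ b[OUT]; };

// All hidden layers must share one type and hence one HIDDEN_DIM:
DenseLayerIllustration<T, HIDDEN_DIM, HIDDEN_DIM> hidden_layers_illustration[NUM_LAYERS - 2];
// DenseLayerIllustration<T, 10, 20> and DenseLayerIllustration<T, 20, 30> are
// distinct types and could not be stored in the same array.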

We aggregate the hyperparameters into a specification again (first just for the structure, later for the full network, incorporating the structure):

[22]:
using MODEL_SPEC = rlt::nn_models::mlp::Specification<T, DEVICE::index_t, INPUT_DIM_MLP, OUTPUT_DIM_MLP, NUM_LAYERS, HIDDEN_DIM, ACTIVATION_FUNCTION_MLP, OUTPUT_ACTIVATION_FUNCTION_MLP>;

We use the default Adam parameters (taken from TensorFlow) and set up the optimizer type using these parameters. Moreover, we create a full network specification for a network that can be trained with Adam which takes the structure specification as an input. Finally we define the full network type:

[23]:
using PARAMETER_TYPE = rlt::nn::parameters::Adam;
using CAPABILITY = rlt::nn::layer_capability::Gradient<PARAMETER_TYPE, BATCH_SIZE>;
using OPTIMIZER_SPEC = rlt::nn::optimizers::adam::Specification<T, TI>;
using OPTIMIZER = rlt::nn::optimizers::Adam<OPTIMIZER_SPEC>;
using MODEL_TYPE = rlt::nn_models::mlp::NeuralNetwork<CAPABILITY, MODEL_SPEC>;

Using these type definitions we can now declare the optimizer and the model. All the optimizer state is contained in the PARAMETER_TYPE of the model (plus an additional age integer in the model in the case of Adam). In comparison to PyTorch, which stores the optimizer state in the optimizer, we prefer to store the first and second-order moments next to the parameters, as is already the case for the gradient (in PyTorch as well). Hence the optimizer is stateless in this case (it does not need to be for user-defined optimizers) and we only need to allocate the model.
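
For orientation, here is a hedged sketch of the per-parameter state this implies (the moment member names below are assumptions for illustration, mirroring the parameters/gradient members used earlier):

// Per-parameter state for PARAMETER_TYPE = rlt::nn::parameters::Adam
// (the moment member names are assumptions for illustration):
// model.output_layer.weights.parameters                   // W
// model.output_layer.weights.gradient                     // dL/dW
// model.output_layer.weights.gradient_first_order_moment  // Adam m
// model.output_layer.weights.gradient_second_order_moment // Adam v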

The backpropagation algorithm needs to store the intermediate gradients. To save memory we do not add a d_input or d_output to each layer but rather use a double buffer with the maximum size of the hidden representation needed.

[24]:
OPTIMIZER optimizer;
MODEL_TYPE model;
typename MODEL_TYPE::Buffer<BATCH_SIZE> buffer;

We allocate the model and initialize its weights randomly as in the case of the single layer. We again zero the gradient of all parameters of all layers, and reset the optimizer state of all parameters of all layers (e.g., in the case of Adam, the first and second-order moments are set to zero). Finally, we also allocate the buffers:

[25]:
rlt::malloc(device, model);
rlt::init_weights(device, model, rng); // recursively initializes all layers using kaiming initialization
rlt::zero_gradient(device, model); // recursively zeros all gradients in the layers
rlt::reset_optimizer_state(device, optimizer, model);
rlt::malloc(device, buffer);

In this example we showcase an MLP with a five-dimensional input and a one-dimensional output (remember the OUTPUT_ACTIVATION_FUNCTION_MLP is IDENTITY so it can also output negative values). For these new shapes we declare and allocate the input container as well as the gradient containers d_input_mlp and d_output_mlp (the output itself is stored in the model, as shown below):

[26]:
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM_MLP>> input_mlp, d_input_mlp;
rlt::MatrixDynamic<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM_MLP>> d_output_mlp;
rlt::malloc(device, input_mlp);
rlt::malloc(device, d_input_mlp);
rlt::malloc(device, d_output_mlp);

Now, as in the case of the single layer, we can run a forward pass on the input. Because the model is an Adam model (which is a subclass of rlt::nn_models::mlp::NeuralNetworkBackwardGradient), it stores the intermediate (and final) outputs.

[27]:
rlt::randn(device, input_mlp, rng);
rlt::forward(device, model, input_mlp, buffer, rng);
T output_value = get(model.output_layer.output, 0, 0);
output_value
[27]:
0.506566f

Now imagine we want the output of the model (for this input) to be \(1\). We calculate the error and feed it back through the model using backpropagation. d_output_mlp should be the derivative of the loss function, hence it gives the direction of the output that would increase the loss. Our error is the opposite: if we moved the output in the direction of the error, we would come closer to our target value and hence decrease the loss. Because of this, we feed back -error. This procedure also corresponds to using a squared loss, because -error is (up to a constant) the derivative of the squared loss wrt. the output.
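
Concretely, using a squared loss with the conventional \(\frac{1}{2}\) factor (added here only so the derivative comes out clean):

\[L = \frac{1}{2}\left(y - y_{\text{target}}\right)^2 \quad \Rightarrow \quad \frac{\mathrm{d}L}{\mathrm{d}y} = y - y_{\text{target}} = -\text{error}\]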

[28]:
T target_output_value = 1;
T error = target_output_value - output_value;
rlt::set(d_output_mlp, 0, 0, -error);
rlt::backward(device, model, input_mlp, d_output_mlp, buffer);

The backward pass populates the gradient in all parameters of the model. Using this gradient we can apply the rlt::step operator, which updates the first and second-order moments of the gradient of all parameters and then applies the Adam update rule to update the parameters:

[29]:
rlt::step(device, optimizer, model);

Now the next forward pass should be closer to the target value:

[30]:
rlt::forward(device, model, input_mlp, buffer, rng);
get(model.output_layer.output, 0, 0)
[30]:
0.518496f

Next we will train the network to actually perform a function (not just outputting a constant value as before). With the following training loop we train it to behave like the rlt::max operator, which outputs the maximum of the five inputs. We run the forward and backward pass for \(32\) iterations while accumulating the gradient, which effectively yields a batch size of \(32\):

[31]:
for(TI i=0; i < 10000; i++){
    rlt::zero_gradient(device, model);
    T mse = 0;
    for(TI batch_i=0; batch_i < 32; batch_i++){
        rlt::randn(device, input_mlp, rng);
        rlt::forward(device, model, input_mlp, buffer, rng);
        T output_value = get(model.output_layer.output, 0, 0);
        T target_output_value = rlt::max(device, input_mlp);
        T error = target_output_value - output_value;
        rlt::set(d_output_mlp, 0, 0, -error);
        rlt::backward(device, model, input_mlp, d_output_mlp, buffer);
        mse += error * error;
    }
    rlt::step(device, optimizer, model);
    if(i % 1000 == 0){
        std::cout << "Squared error: " << mse/32 << std::endl;
    }
}
Squared error: 0.643161
Squared error: 0.055282
Squared error: 0.025145
Squared error: 0.016128
Squared error: 0.016535
Squared error: 0.017536
Squared error: 0.011195
Squared error: 0.009795
Squared error: 0.008141
Squared error: 0.013169

Now we can test the model using some arbitrary input (which should be in the distribution of input values) and the model should output a value close to the maximum of the five input values:

[32]:
set(input_mlp, 0, 0, +0.0);
set(input_mlp, 0, 1, -0.1);
set(input_mlp, 0, 2, +0.5);
set(input_mlp, 0, 3, -0.4);
set(input_mlp, 0, 4, +0.1);

rlt::forward(device, model, input_mlp, buffer, rng);
rlt::get(model.output_layer.output, 0, 0)
[32]:
0.465697f

We can also automatically test it with \(10\) random inputs:

[33]:
for(TI i=0; i < 10; i++){
    rlt::randn(device, input_mlp, rng);
    rlt::forward(device, model, input_mlp, buffer, rng);
    std::cout << "max: " << rlt::max(device, input_mlp) << " output: " << rlt::get(model.output_layer.output, 0, 0) << std::endl;
}
max: 0.539628 output: 0.555159
max: 1.348390 output: 1.313255
max: 1.660528 output: 1.620614
max: 1.779285 output: 1.739960
max: 1.311534 output: 1.279970
max: 0.965693 output: 0.929909
max: 2.799156 output: 2.870854
max: 1.195009 output: 1.313729
max: 0.797983 output: 0.711116
max: 0.419951 output: 0.451161

If the values are not close, the model might need some more training iterations.

Sequential#

#todo