Deep Learning
Because of the static multiple dispatch paradigm laid out in Multiple Dispatch, we first need to include the primitive operations for the device(s) we intend to use, so that the algorithms (and data structures) we include later for deep learning can use them.
[1]:
#include <rl_tools/operations/cpu.h>
[2]:
#include <rl_tools/nn/layers/dense/operations_cpu.h>
We set up the environment as described in Containers:
[3]:
namespace rlt = rl_tools;
using DEVICE = rlt::devices::DefaultCPU;
using T = float;
using TI = typename DEVICE::index_t;
DEVICE device;
TI seed = 1;
auto rng = rlt::random::default_engine(DEVICE::SPEC::RANDOM(), seed);
As justified by our analysis of the reinforcement learning for continuous control landscape (in the paper), RLtools initially only supports fully connected neural networks, but we are planning to add more architectures (especially recurrent neural networks) in the future.
We can instantiate a simple layer by first defining its hyperparameters (which are compile-time constexpr values and types):
[4]:
constexpr TI BATCH_SIZE = 1;
constexpr TI INPUT_DIM = 5;
constexpr TI OUTPUT_DIM = 5;
constexpr auto ACTIVATION_FUNCTION = rlt::nn::activation_functions::RELU;
These hyperparameters and other options are combined into a specification type, so that it is easier to pass around and so that we don't need to write out all hyperparameters and options as template parameters whenever a function takes the data structure as an argument:
[5]:
using LAYER_CONFIG = rlt::nn::layers::dense::Configuration<T, TI, OUTPUT_DIM, ACTIVATION_FUNCTION>;
The data structure of a layer depends not only on its (previously defined) structure but also on the required capabilities. E.g. if we want to do backward passes, the layer needs to store intermediate activations during the forward pass. Furthermore, the buffers for these intermediate values also depend on the batch size. We decouple the capabilities from the structure specification so that we can easily change the capability of a layer or model (e.g. for checkpointing, where we only want to save the parts required for inference, or for changing the batch size).
[6]:
using CAPABILITY = rlt::nn::capability::Forward<>;
You might have noticed that the LAYER_CONFIG does not specify an input dimensionality. This is because input shapes automatically cascade through models in RLtools: the input shape of a subsequent layer is determined by the output shape of the preceding layer. This will make more sense once we introduce the Sequential model, which combines multiple layers, later in this document. For this standalone layer we specify the input shape (batch size and input dimension) separately:
[7]:
using INPUT_SHAPE = rlt::tensor::Shape<TI, BATCH_SIZE, INPUT_DIM>;
Using these specifications (configuration, capability, and input shape), we can declare an actual layer:
[8]:
rlt::nn::layers::dense::Layer<LAYER_CONFIG, CAPABILITY, INPUT_SHAPE> layer;
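The cascading of input shapes mentioned above can be pictured as a compile-time recursion in which each layer only knows its output dimension and receives its input dimension from its predecessor. The following standalone sketch uses hypothetical Dense and Chain templates (not RLtools types) and only tracks the feature dimension; it also counts the weights and biases of each layer, whose sizes are only known once the input dimension has been propagated.

```cpp
#include <cstddef>

// Hypothetical layer description: like the dense configuration above, it only fixes the output dimension.
template <std::size_t OUTPUT_DIM>
struct Dense {
    static constexpr std::size_t output_dim = OUTPUT_DIM;
};

// Chain<INPUT_DIM, LAYERS...>: the input dimension of each layer is the output dimension
// of its predecessor, resolved entirely at compile time.
template <std::size_t INPUT_DIM, typename... LAYERS>
struct Chain;

// Base case: no layers left, the output is whatever came in.
template <std::size_t INPUT_DIM>
struct Chain<INPUT_DIM> {
    static constexpr std::size_t output_dim = INPUT_DIM;
    static constexpr std::size_t parameter_count = 0;
};

// Recursive case: FIRST consumes INPUT_DIM features, the rest of the chain consumes FIRST's output.
template <std::size_t INPUT_DIM, typename FIRST, typename... REST>
struct Chain<INPUT_DIM, FIRST, REST...> {
    using NEXT = Chain<FIRST::output_dim, REST...>;
    static constexpr std::size_t output_dim = NEXT::output_dim;
    // weights (output_dim x INPUT_DIM) plus biases (output_dim) of FIRST, plus everything after it
    static constexpr std::size_t parameter_count =
        FIRST::output_dim * INPUT_DIM + FIRST::output_dim + NEXT::parameter_count;
};
```

For example, Chain<5, Dense<5>> describes the single 5-to-5 layer declared above: its output dimension is 5 and it holds 5 * 5 + 5 = 30 parameters.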
A fully connected neural network consists of layers, each implementing
\[y = f(Wx + b),\]
where \(x\) is the input (external or from the previous layer), \(y\) is the output, \(W\) and \(b\) are the weight matrix and biases respectively, and \(f\) is an element-wise non-linear function. Hence the data structure of a layer should contain at least \(W\) and \(b\). Because these parameters are containers, they need to be allocated:
[9]:
rlt::malloc(device, layer);
Now that the memory is allocated, we need to initialize it (because it may contain arbitrary values). We use the standard Kaiming initialization scheme:
[10]:
rlt::init_weights(device, layer, rng);
We can print \(W\) and \(b\):
[11]:
rlt::print(device, layer.weights.parameters);
-0.211912 0.010027 0.373245 -0.388598 -0.055528
-0.127251 -0.126478 0.330389 0.238816 -0.412481
-0.385369 -0.351579 -0.394084 -0.141052 -0.433443
-0.327643 0.299608 -0.113104 -0.288047 0.322775
-0.399042 -0.282702 -0.171875 0.296949 -0.087313
[12]:
rlt::print(device, layer.biases.parameters)
-0.447200 -0.363058 0.144835 0.238661 0.172029
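All printed weights and biases lie within \(\pm 1/\sqrt{5} \approx 0.447\), where 5 is the fan-in (INPUT_DIM). This is consistent with the uniform Kaiming variant that PyTorch uses by default for linear layers, which samples from \(\mathcal{U}(-1/\sqrt{\text{fan\_in}}, 1/\sqrt{\text{fan\_in}})\); we have not verified that rlt::init_weights uses exactly this bound. The following is a standalone sketch of that scheme using only the C++ standard library; init_dense_uniform is a hypothetical name, not an RLtools function.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical helper (not part of RLtools): fills a row-major OUTPUT x INPUT weight matrix
// and the bias vector with samples from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
void init_dense_uniform(std::vector<float>& weights, std::vector<float>& biases,
                        std::size_t input_dim, std::mt19937& rng) {
    assert(input_dim > 0 && weights.size() == biases.size() * input_dim);
    const float bound = 1.0f / std::sqrt(static_cast<float>(input_dim));  // fan_in = input_dim
    std::uniform_real_distribution<float> dist(-bound, bound);
    for (float& w : weights) w = dist(rng);
    for (float& b : biases) b = dist(rng);
}
```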
Now that the layer is initialized we can run inference using a random input. We first declare and allocate input and output matrices and then randomly initialize the input:
[13]:
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM>> input;
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM>> output;
rlt::malloc(device, input);
rlt::malloc(device, output);
rlt::randn(device, input, rng);
rlt::print(device, input);
0.100807 -0.911862 2.108090 0.094763 0.537630
Now we can evaluate the output of the layer:
[14]:
decltype(layer)::Buffer<BATCH_SIZE> buffer;
rlt::evaluate(device, layer, input, output, buffer, rng);
rlt::print(device, output);
0.242450 0.236803 0.000000 0.000000 0.008457
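The evaluate call above computes \(y = f(Wx + b)\) with \(f\) being the ReLU. As a cross-check, the following standalone sketch (plain std::array, not the RLtools implementation; dense_relu_forward is a hypothetical name) stores \(W\) row-major with one row per output unit, like the printed matrix. Plugging in the printed \(W\), \(b\), and input reproduces the printed output up to rounding: the third and fourth units have negative pre-activations and are clipped to zero by the ReLU.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>

// Standalone sketch of a dense layer forward pass y = relu(W x + b) for one input vector.
// W has one row per output unit (OUT x IN), matching the matrix printed above.
template <std::size_t IN, std::size_t OUT>
std::array<float, OUT> dense_relu_forward(const std::array<std::array<float, IN>, OUT>& W,
                                          const std::array<float, OUT>& b,
                                          const std::array<float, IN>& x) {
    std::array<float, OUT> y{};
    for (std::size_t i = 0; i < OUT; ++i) {
        float z = b[i];  // pre-activation starts at the bias
        for (std::size_t j = 0; j < IN; ++j) z += W[i][j] * x[j];
        y[i] = std::max(z, 0.0f);  // element-wise ReLU
    }
    return y;
}
```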
Now we revisit the capabilities mentioned earlier. For inference, storing \(W\) and \(b\) is sufficient, but for training we also need to store at least the gradient of the loss \(L\) wrt. \(W\) and \(b\): \(\frac{\mathrm{d}L}{\mathrm{d}W}\) and \(\frac{\mathrm{d}L}{\mathrm{d}b}\). Because, depending on the optimizer type, we might need to store more information per parameter (like the first and second-order moments in the case of Adam), we abstract the storage for the weights and biases using a parameter type (defined under the rl_tools::nn::parameters namespace) that can, e.g., be Plain, Gradient, Adam, or any other type defined by the user. For this illustration we are using Gradient:
[15]:
using PARAMETER_TYPE = rlt::nn::parameters::Gradient;
using CAPABILITY_2 = rlt::nn::capability::Gradient<PARAMETER_TYPE, BATCH_SIZE>;
using LAYER_2_CONFIG = rlt::nn::layers::dense::Configuration<T, TI, OUTPUT_DIM, ACTIVATION_FUNCTION>;
rlt::nn::layers::dense::Layer<LAYER_2_CONFIG, CAPABILITY_2, INPUT_SHAPE> layer_2;
rlt::malloc(device, layer_2);
rlt::copy(device, device, layer, layer_2);
rlt::zero_gradient(device, layer_2);
Note that by using the rl_tools::nn::capability::Gradient capability, the rl_tools::nn::layers::dense::Layer data structure contains the necessary buffers (e.g. for intermediate activations) to support the backpropagation algorithm. Additionally, similar to PyTorch, we set the gradient to zero because it is accumulated over subsequent backward passes.
Now we can backpropagate the derivative of the loss wrt. the output to calculate the derivative of the loss wrt. the input. Hence the derivative of the loss wrt. the output, d_output, is actually an input to the rl_tools::backward_full operator. The operator also accumulates the derivative of the loss wrt. the weights and biases in the layer. We first allocate containers for d_input and d_output and randomly set d_output (a hypothetical gradient of the loss wrt. this layer's output, as it would be propagated back from subsequent layers):
[16]:
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM>> d_output;
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM>> d_input;
rlt::malloc(device, d_input);
rlt::malloc(device, d_output);
rlt::randn(device, d_output, rng);
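Now we run a forward pass through layer_2 (which records the intermediate activations needed for backpropagation), execute the backpropagation, and display the gradient of the loss wrt. the input: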
[17]:
rlt::forward(device, layer_2, input, buffer, rng);
std::cout << "Output (should be identical to layer_1): " << std::endl;
rlt::print(device, layer_2.output);
rlt::backward_full(device, layer_2, input, d_output, d_input, buffer);
std::cout << "Derivative with respect to the input: " << std::endl;
rlt::print(device, d_input);
Output (should be identical to layer_1):
0.242450 0.236803 0.000000 0.000000 0.008457
Derivative with respect to the input:
-0.214641 -0.065061 -0.232326 -0.153192 0.257247
This also accumulates the gradient in the weights and biases:
[18]:
rlt::print(device, layer_2.weights.gradient);
0.036480 -0.329980 0.762866 0.034292 0.194555
-0.080598 0.729055 -1.685465 -0.075765 -0.429848
0.000000 0.000000 0.000000 0.000000 0.000000
0.000000 0.000000 0.000000 0.000000 0.000000
0.060552 -0.547734 1.266280 0.056922 0.322942
[19]:
rlt::print(device, layer_2.biases.gradient);
0.361875 -0.799522 0.000000 0.000000 0.600676
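The printed gradients follow from the chain rule for \(y = f(Wx + b)\): with \(\delta = \frac{\mathrm{d}L}{\mathrm{d}y} \odot f'(Wx + b)\), a batch of one sample, and freshly zeroed gradients, we get \(\frac{\mathrm{d}L}{\mathrm{d}b} = \delta\), \(\frac{\mathrm{d}L}{\mathrm{d}W} = \delta x^\top\), and \(\frac{\mathrm{d}L}{\mathrm{d}x} = W^\top \delta\). This is why each row of the weight gradient is a multiple of the input and the rows of the two inactive ReLU units are zero. The standalone sketch below (plain std::array; DenseGrad and dense_relu_backward are hypothetical names, not RLtools API) implements these formulas with accumulation. Using the printed bias gradient as \(\delta\) together with the printed \(W\) reproduces the printed derivative wrt. the input up to rounding.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Accumulated parameter gradients of one dense layer (OUT x IN weights, OUT biases).
template <std::size_t IN, std::size_t OUT>
struct DenseGrad {
    std::array<std::array<float, IN>, OUT> dW{};  // dL/dW, one row per output unit
    std::array<float, OUT> db{};                  // dL/db
};

// Backward pass of y = relu(W x + b) for one input vector. Parameter gradients are
// accumulated (+=); the returned vector is dL/dx.
template <std::size_t IN, std::size_t OUT>
std::array<float, IN> dense_relu_backward(const std::array<std::array<float, IN>, OUT>& W,
                                          const std::array<float, IN>& x,
                                          const std::array<float, OUT>& y,         // forward output
                                          const std::array<float, OUT>& d_output,  // dL/dy
                                          DenseGrad<IN, OUT>& grad) {
    std::array<float, IN> d_input{};
    for (std::size_t i = 0; i < OUT; ++i) {
        // ReLU derivative: the gradient only flows through units with a positive output
        const float delta = y[i] > 0.0f ? d_output[i] : 0.0f;
        grad.db[i] += delta;
        for (std::size_t j = 0; j < IN; ++j) {
            grad.dW[i][j] += delta * x[j];  // outer product delta * x^T
            d_input[j] += W[i][j] * delta;  // W^T delta
        }
    }
    return d_input;
}
```

The accumulation (+=) is why the gradient has to be zeroed before a fresh backward pass.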
[20]:
rlt::free(device, layer);
rlt::free(device, layer_2);
rlt::free(device, input);
rlt::free(device, output);
rlt::free(device, d_input);
rlt::free(device, d_output);
Multilayer Perceptron (MLP)
Until now we showed the behavior of a single, fully connected layer. RLtools contains a Multilayer Perceptron (MLP) that conveniently integrates an arbitrary number of layers into a single data structure, with algorithms to perform forward passes and backpropagation across the whole model. The MLP is located under the namespace rl_tools::nn_models, and we include it as well as the operations of the Adam optimizer:
[21]:
#include <rl_tools/nn/optimizers/adam/instance/operations_generic.h>
#include <rl_tools/nn_models/mlp/operations_generic.h>
#include <rl_tools/nn/optimizers/adam/operations_generic.h>
Note that the operations of the (Adam) optimizer are split into instance/operations_generic.h and operations_generic.h. The former contains operations that use and modify the values associated with a particular set of parameters (e.g. the weights or biases of a particular layer). An example is the rl_tools::update operation, which applies the optimizer: in the case of Adam, it updates the first and second-order moments based on the gradient and then applies the update rule to the parameters. Since these parameters can reside in an arbitrary structure (like an MLP or a Sequential nn_model), the rl_tools::update function is called by an operation that knows about this structure (e.g. the rl_tools::update of rl_tools::nn_models::mlp in turn calls the update operations of its layers). These instance-associated operations carry out the bulk of the gradient descent step but are necessarily myopic, because they don't know about the higher-level structure.
Because optimizers like Adam have not only instance-associated state (like the first and second-order moments of the gradient) but also global state like the step counter, we also include the global .../adam/operations_generic.h.
The order of the includes is dictated by the underlying usage: we call rl_tools::step on the optimizer, providing the model. The optimizer then invokes the update of the model, which invokes the update of its submodels/layers, which in turn call the update on the parameter instances. Each operation in this chain calls the next one, which therefore has to be included before it; this is how we arrive at the order used in this example.
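The call chain described above can be illustrated with plain function overloading. In the following sketch, all types and functions are hypothetical stand-ins (not RLtools definitions), and plain SGD replaces Adam to keep it short. Each update overload is defined after the overload it calls, mirroring the include order (instance operations first, then the model, then the global optimizer step); the per-parameter state (here only the gradient) sits next to the value, and the global step counter lives in the model as age.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

struct SGD {
    float learning_rate = 0.1f;
};

struct Parameter {
    float value = 0.0f;
    float gradient = 0.0f;  // per-parameter state stored next to the value
};

struct Layer {
    Parameter weight;
    Parameter bias;
};

template <std::size_t N>
struct Model {
    std::array<Layer, N> layers{};
    std::size_t age = 0;  // global optimizer state (step counter)
};

// 1) instance level: only knows about a single parameter
void update(Parameter& parameter, const SGD& optimizer) {
    parameter.value -= optimizer.learning_rate * parameter.gradient;
}

// 2) layer level: forwards to its parameters
void update(Layer& layer, const SGD& optimizer) {
    update(layer.weight, optimizer);
    update(layer.bias, optimizer);
}

// 3) model level: forwards to its layers
template <std::size_t N>
void update(Model<N>& model, const SGD& optimizer) {
    for (Layer& layer : model.layers) update(layer, optimizer);
}

// 4) top level: the optimizer step updates the model and the global bookkeeping
template <std::size_t N>
void step(const SGD& optimizer, Model<N>& model) {
    update(model, optimizer);
    ++model.age;
}
```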
Next we define the hyperparameters:
[22]:
constexpr TI INPUT_DIM_MLP = 5;
constexpr TI OUTPUT_DIM_MLP = 1;
constexpr TI NUM_LAYERS = 3;
constexpr TI HIDDEN_DIM = 10;
constexpr auto ACTIVATION_FUNCTION_MLP = rlt::nn::activation_functions::RELU;
constexpr auto OUTPUT_ACTIVATION_FUNCTION_MLP = rlt::nn::activation_functions::IDENTITY;
Note that the MLP supports architectures with an arbitrary depth, but all hidden layers have to have the same dimensionality. This is because the layers are stored in an array and hence all need to have the same type. If we allowed different hidden dimensions, we would have to give up on arbitrary depth.
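The type-level reason can be illustrated with a small sketch using hypothetical Dense, UniformMLP, and FunnelMLP types (not the RLtools definitions). When all hidden layers share the same dimensions, they share one type, so the number of hidden layers can be a single template parameter of an std::array. Layers with different dimensions have different types and need a heterogeneous container such as std::tuple, whose length is fixed by the listed types. As an assumption of this sketch, NUM_LAYERS counts the input and output layers, so NUM_LAYERS = 3 gives one hidden layer.

```cpp
#include <array>
#include <cstddef>
#include <tuple>

// Hypothetical dense layer whose dimensions are part of its type.
template <std::size_t IN, std::size_t OUT>
struct Dense {
    std::array<std::array<float, IN>, OUT> weights{};
    std::array<float, OUT> biases{};
};

// Equal hidden dimensions: all hidden layers have the same type, so the depth is a single
// template parameter and the hidden layers fit into an std::array.
template <std::size_t IN, std::size_t HIDDEN, std::size_t OUT, std::size_t NUM_LAYERS>
struct UniformMLP {
    static_assert(NUM_LAYERS >= 2, "needs at least an input and an output layer");
    Dense<IN, HIDDEN> input_layer;
    std::array<Dense<HIDDEN, HIDDEN>, NUM_LAYERS - 2> hidden_layers;
    Dense<HIDDEN, OUT> output_layer;
};

// A "funnel" (5 -> 32 -> 16 -> 4) has a different type per layer, so it needs a heterogeneous
// container such as std::tuple, and the depth is fixed by the listed types.
using FunnelMLP = std::tuple<Dense<5, 32>, Dense<32, 16>, Dense<16, 4>>;
```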
We aggregate the hyperparameters into a specification again (first just for the structure, later for the full network, incorporating the structure):
[23]:
using MODEL_CONFIG = rlt::nn_models::mlp::Configuration<T, TI, OUTPUT_DIM_MLP, NUM_LAYERS, HIDDEN_DIM, ACTIVATION_FUNCTION_MLP, OUTPUT_ACTIVATION_FUNCTION_MLP>;
We use the default Adam parameters (taken from TensorFlow) and set up the optimizer type using these parameters. Moreover, we define a capability for a network that can be trained with Adam and combine it with the structure configuration (and the input shape) into the full network type:
[24]:
using PARAMETER_TYPE = rlt::nn::parameters::Adam;
using CAPABILITY = rlt::nn::capability::Gradient<PARAMETER_TYPE>;
using OPTIMIZER_SPEC = rlt::nn::optimizers::adam::Specification<T, TI>;
using OPTIMIZER = rlt::nn::optimizers::Adam<OPTIMIZER_SPEC>;
using INPUT_SHAPE_MLP = rlt::tensor::Shape<TI, BATCH_SIZE, INPUT_DIM_MLP>; // input shape of the MLP (derived from INPUT_DIM_MLP)
using MODEL_TYPE = rlt::nn_models::mlp::NeuralNetwork<MODEL_CONFIG, CAPABILITY, INPUT_SHAPE_MLP>;
Using these type definitions we can now declare the optimizer and the model. All the optimizer state is contained in the PARAMETER_TYPE of the model (plus an additional age integer in the model in the case of Adam). In contrast to PyTorch, which stores the optimizer state in the optimizer, we prefer to store the first and second-order moments next to the parameters, as is the case for the gradient anyway (in PyTorch as well). Hence the optimizer is stateless in this case (this does not need to be the case for user-defined optimizers), and we only need to allocate the model.
The backpropagation algorithm needs to store the intermediate gradients. To save memory, we do not add a d_input or d_output to each layer but rather use a double buffer with the maximum size of the hidden representation needed.
[25]:
OPTIMIZER optimizer;
MODEL_TYPE model;
typename MODEL_TYPE::Buffer<> buffer;
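The double-buffer idea can be sketched as follows: while backpropagating through a chain of layers, only the gradient wrt. the current layer's output and the gradient wrt. its input are needed at any time, so two scratch buffers sized for the widest layer suffice and swap roles after each layer. This is a hypothetical standalone sketch, not the RLtools implementation: Linear, MAX_DIM, and backward_input are illustrative names, activations are omitted, and parameter gradients are not computed.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical linear layer y = W x (no bias, no activation, to keep the sketch short).
struct Linear {
    std::size_t in, out;
    std::vector<float> W;  // row-major, out x in
};

constexpr std::size_t MAX_DIM = 16;  // assumed upper bound on every layer's width

// Returns dL/dx for a chain of linear layers (layers[0] is applied first), given dL/dy.
std::vector<float> backward_input(const std::vector<Linear>& layers, const std::vector<float>& d_output) {
    assert(!layers.empty() && d_output.size() == layers.back().out && d_output.size() <= MAX_DIM);
    std::array<std::array<float, MAX_DIM>, 2> buffers{};  // the only gradient storage needed
    std::size_t current = 0;                               // buffer holding dL/d(output of layer l)
    std::copy(d_output.begin(), d_output.end(), buffers[current].begin());
    for (std::size_t l = layers.size(); l-- > 0;) {
        const Linear& layer = layers[l];
        assert(layer.in <= MAX_DIM && layer.out <= MAX_DIM && layer.W.size() == layer.in * layer.out);
        const std::size_t next = 1 - current;
        for (std::size_t j = 0; j < layer.in; ++j) {
            float sum = 0.0f;
            for (std::size_t i = 0; i < layer.out; ++i) sum += layer.W[i * layer.in + j] * buffers[current][i];
            buffers[next][j] = sum;  // dL/d(input of layer l) = W_l^T * dL/d(output of layer l)
        }
        current = next;  // swap roles: the buffer just written feeds the next (earlier) layer
    }
    return std::vector<float>(buffers[current].begin(), buffers[current].begin() + layers.front().in);
}
```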
We allocate the model and initialize its weights randomly, as in the case of the single layer. We again zero the gradient of all parameters of all layers and reset the optimizer state of all parameters of all layers (e.g. in the case of Adam, the first and second-order moments are set to zero). Finally, we also allocate the buffer:
[26]:
rlt::malloc(device, model);
rlt::init_weights(device, model, rng); // recursively initializes all layers using kaiming initialization
rlt::zero_gradient(device, model); // recursively zeros all gradients in the layers
rlt::reset_optimizer_state(device, optimizer, model);
rlt::malloc(device, buffer);
In this example we showcase an MLP with a five-dimensional input and a one-dimensional output (remember that OUTPUT_ACTIVATION_FUNCTION_MLP is IDENTITY, so it can also output negative values). For these new shapes we declare and allocate the input and the gradient containers (the output itself is stored in the model):
[27]:
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, INPUT_DIM_MLP>> input_mlp, d_input_mlp;
rlt::Matrix<rlt::matrix::Specification<T, TI, BATCH_SIZE, OUTPUT_DIM_MLP>> d_output_mlp;
rlt::malloc(device, input_mlp);
rlt::malloc(device, d_input_mlp);
rlt::malloc(device, d_output_mlp);
Now, like in the case of the single layer, we can run a forward pass using the input. Because the model was declared with the Gradient capability (here with Adam parameters), it stores the intermediate (and final) outputs.
[28]:
rlt::randn(device, input_mlp, rng);
rlt::forward(device, model, input_mlp, buffer, rng);
T output_value = get(model.output_layer.output, 0, 0);
output_value
[28]:
0.506566f
Now imagine we want the output of the model (for this input) to be \(1\). We calculate the error and feed it back through the model using backpropagation. d_output_mlp should be the derivative of the loss function wrt. the output, hence it points in the direction of the output that would increase the loss. Our error points in the opposite direction: if we moved the output in the direction of the error, we would come closer to our target value and hence decrease the loss. Because of this, we feed back -error. This procedure also corresponds to using a squared loss, because -error is (up to a constant factor) the derivative of the squared loss wrt. the output.
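To make the constant explicit: for a single output \(y\) with target \(y^*\) and error \(e = y^* - y\), the (halved) squared loss and its derivative wrt. the output are
\[
L = \tfrac{1}{2}\,(y^* - y)^2, \qquad \frac{\mathrm{d}L}{\mathrm{d}y} = -(y^* - y) = -e.
\]
Without the factor \(\tfrac{1}{2}\), the derivative is \(-2e\). With the values above, \(e = 1 - 0.506566 = 0.493434\), so d_output_mlp is set to \(-0.493434\).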
[29]:
T target_output_value = 1;
T error = target_output_value - output_value;
rlt::set(d_output_mlp, 0, 0, -error);
rlt::backward(device, model, input_mlp, d_output_mlp, buffer);
The backward pass populates the gradient in all parameters of the model. Using this gradient we can apply the rlt::step operator, which updates the first and second-order moments of the gradient of all parameters and afterwards applies the Adam update rule to update the parameters:
[30]:
rlt::step(device, optimizer, model);
Now the next forward pass should be closer to the target value:
[31]:
rlt::forward(device, model, input_mlp, buffer, rng);
get(model.output_layer.output, 0, 0)
[31]:
0.518496f
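The step moved the output from 0.506566 to 0.518496, towards the target. For reference, the following is a sketch of the textbook Adam update rule for a single scalar parameter, storing the moments next to the parameter and using a global step counter (age) for bias correction, as described above. The defaults are TensorFlow's documented defaults (learning rate 0.001, \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-7}\)); RLtools' exact defaults and its placement of \(\epsilon\) may differ. AdamParameter, AdamHyperparameters, and adam_update are hypothetical names, not RLtools API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Per-parameter state: the moments live next to the value and the gradient.
struct AdamParameter {
    double value = 0.0;
    double gradient = 0.0;
    double first_moment = 0.0;   // m
    double second_moment = 0.0;  // v
};

struct AdamHyperparameters {
    double alpha = 0.001;  // learning rate
    double beta_1 = 0.9;
    double beta_2 = 0.999;
    double epsilon = 1e-7;
};

// age is the 1-based index of the current step (global state), used for bias correction.
void adam_update(AdamParameter& p, const AdamHyperparameters& h, std::size_t age) {
    assert(age >= 1);
    p.first_moment = h.beta_1 * p.first_moment + (1 - h.beta_1) * p.gradient;
    p.second_moment = h.beta_2 * p.second_moment + (1 - h.beta_2) * p.gradient * p.gradient;
    const double m_hat = p.first_moment / (1 - std::pow(h.beta_1, static_cast<double>(age)));
    const double v_hat = p.second_moment / (1 - std::pow(h.beta_2, static_cast<double>(age)));
    p.value -= h.alpha * m_hat / (std::sqrt(v_hat) + h.epsilon);
}
```

On the first step, \(\hat m / \sqrt{\hat v} = g / |g|\), so each parameter moves by roughly the learning rate, independent of the gradient's scale.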
Next we will train the network to actually perform a function (not only to output a constant value as before). With the following training loop we train it to behave like the rlt::max operator, which outputs the maximum of the five inputs. In each of the 10000 training steps, we run the forward and backward pass for \(32\) samples while accumulating the gradient, which effectively leads to a batch size of \(32\):
[32]:
for(TI i=0; i < 10000; i++){
rlt::zero_gradient(device, model);
T mse = 0;
for(TI batch_i=0; batch_i < 32; batch_i++){
rlt::randn(device, input_mlp, rng);
rlt::forward(device, model, input_mlp, buffer, rng);
T output_value = get(model.output_layer.output, 0, 0);
T target_output_value = rlt::max(device, input_mlp);
T error = target_output_value - output_value;
rlt::set(d_output_mlp, 0, 0, -error);
rlt::backward(device, model, input_mlp, d_output_mlp, buffer);
mse += error * error;
}
rlt::step(device, optimizer, model);
if(i % 1000 == 0)
std::cout << "Squared error: " << mse/32 << std::endl;
}
Squared error: 0.643161
Squared error: 0.055282
Squared error: 0.025145
Squared error: 0.016128
Squared error: 0.016535
Squared error: 0.017536
Squared error: 0.011195
Squared error: 0.009795
Squared error: 0.008141
Squared error: 0.013169
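Accumulating the gradient of 32 single-sample backward passes after zeroing it yields the gradient of the summed loss over those 32 samples. The following toy check uses a hypothetical 1-D linear model \(y = wx\) with per-sample loss \(\frac{1}{2}(t - y)^2\) (not RLtools code) and compares the accumulated per-sample gradients with the closed-form gradient of the summed loss. Note that this is a sum rather than a mean; with Adam this matters little, because the update is largely invariant to a constant scaling of the gradient (up to the effect of \(\epsilon\)).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Gradient of 0.5 * (t - w * x)^2 wrt. w for a single sample: -error * x.
double sample_gradient(double w, double x, double t) {
    const double error = t - w * x;
    return -error * x;
}

// Zero the gradient, then accumulate single-sample gradients (like the training loop above).
double accumulated_gradient(double w, const std::vector<double>& xs, const std::vector<double>& ts) {
    assert(xs.size() == ts.size());
    double gradient = 0.0;  // corresponds to zero_gradient
    for (std::size_t i = 0; i < xs.size(); ++i) gradient += sample_gradient(w, xs[i], ts[i]);
    return gradient;
}

// Closed-form gradient of the summed loss sum_i 0.5 * (t_i - w * x_i)^2: w * sum(x^2) - sum(t * x).
double batch_gradient(double w, const std::vector<double>& xs, const std::vector<double>& ts) {
    assert(xs.size() == ts.size());
    double sum_xx = 0.0, sum_tx = 0.0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        sum_xx += xs[i] * xs[i];
        sum_tx += ts[i] * xs[i];
    }
    return w * sum_xx - sum_tx;
}
```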
Now we can test the model using some arbitrary input (which should lie within the distribution of the training inputs), and the model should output a value close to the maximum of the five input values (here \(0.5\)):
[33]:
set(input_mlp, 0, 0, +0.0);
set(input_mlp, 0, 1, -0.1);
set(input_mlp, 0, 2, +0.5);
set(input_mlp, 0, 3, -0.4);
set(input_mlp, 0, 4, +0.1);
rlt::forward(device, model, input_mlp, buffer, rng);
rlt::get(model.output_layer.output, 0, 0)
[33]:
0.465697f
We can also automatically test it with \(10\) random inputs:
[34]:
for(TI i=0; i < 10; i++){
rlt::randn(device, input_mlp, rng);
rlt::forward(device, model, input_mlp, buffer, rng);
std::cout << "max: " << rlt::max(device, input_mlp) << " output: " << rlt::get(model.output_layer.output, 0, 0) << std::endl;
}
max: 0.539628 output: 0.555159
max: 1.348390 output: 1.313255
max: 1.660528 output: 1.620614
max: 1.779285 output: 1.739960
max: 1.311534 output: 1.279970
max: 0.965693 output: 0.929909
max: 2.799156 output: 2.870854
max: 1.195009 output: 1.313729
max: 0.797983 output: 0.711116
max: 0.419951 output: 0.451161
If the values are not close, the model might need more training iterations.
Sequential
The great advantage of the previously introduced MLP module is that the number of layers is a parameter, and hence the architecture can be scaled in width and depth through just two parameters (HIDDEN_DIM and NUM_LAYERS) without defining additional types. In many cases, though, more flexibility is required, which is why we introduced the Sequential model.
The Sequential model follows the torch.nn.Sequential and tensorflow.keras.Sequential philosophy. Initially, the Sequential model was created to introduce automatic differentiation to RLtools, since the MLP has a hard-coded backward pass. With the Sequential model, the user just specifies a sequence of layers, and the forward and backward passes are inferred automatically by the compiler at compile time. With the addition of the rlt::Tensor interface for arbitrary-dimensional containers, the role of the Sequential model grew, as we wanted to move away from hard-coded assumptions about dimensions (like the batch or sequence dimension) towards layers that adapt to the input shape. This "adaptation" of course happens entirely at compile time, to maintain the main philosophy of RLtools that the sizes of all data structures and loops are known at compile time.
With the move to rlt::Tensor, the Sequential interface adopted semantics that are more similar to tensorflow.keras.Sequential, where the intermediate shapes are also inferred from the input shape. In practice this means that, e.g., dense layers or MLPs (yes, MLPs can be "layers" inside a Sequential model) broadcast over all dimensions except the last one, while recurrent layers like the GRU consume (and output) the last three dimensions (SEQUENCE_LENGTH x BATCH_SIZE x FEATURE_DIM).
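The broadcasting behavior of dense layers can be sketched by flattening all leading dimensions into rows: a SEQUENCE_LENGTH x BATCH_SIZE x FEATURE_DIM input is then just SEQUENCE_LENGTH * BATCH_SIZE independent feature vectors, and only the last dimension is transformed. The function below is a hypothetical standalone sketch (dense_broadcast, no activation), not the RLtools implementation; since SEQ and BATCH cannot be deduced from the flattened sizes, all dimensions are passed explicitly as template arguments. A recurrent layer cannot be applied this way, because each time step depends on the previous one.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Applies y = W x + b to every feature vector of a flattened [SEQ x BATCH x IN] input.
template <std::size_t SEQ, std::size_t BATCH, std::size_t IN, std::size_t OUT>
std::array<float, SEQ * BATCH * OUT> dense_broadcast(const std::array<float, OUT * IN>& W,  // row-major OUT x IN
                                                     const std::array<float, OUT>& b,
                                                     const std::array<float, SEQ * BATCH * IN>& input) {
    std::array<float, SEQ * BATCH * OUT> output{};
    for (std::size_t row = 0; row < SEQ * BATCH; ++row) {  // all leading dimensions flattened
        for (std::size_t o = 0; o < OUT; ++o) {
            float z = b[o];
            for (std::size_t i = 0; i < IN; ++i) z += W[o * IN + i] * input[row * IN + i];
            output[row * OUT + o] = z;
        }
    }
    return output;
}
```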
In the following, we build a three-layer, "funnel"-type MLP that we could not build using the MLP model, because the latter requires all hidden layers to have the same dimensionality for the reasons described before. The following specifies the [32, 16, 4] MLP with input dimension 5:
[35]:
#include <rl_tools/nn_models/sequential/operations_generic.h>
using namespace rlt::nn_models::sequential;
using LAYER_1_CONFIG = rlt::nn::layers::dense::Configuration<T, TI, 32, rlt::nn::activation_functions::ActivationFunction::RELU>;
using LAYER_1 = rlt::nn::layers::dense::BindConfiguration<LAYER_1_CONFIG>;
using LAYER_2_CONFIG = rlt::nn::layers::dense::Configuration<T, TI, 16, rlt::nn::activation_functions::ActivationFunction::RELU>;
using LAYER_2 = rlt::nn::layers::dense::BindConfiguration<LAYER_2_CONFIG>;
using LAYER_3_CONFIG = rlt::nn::layers::dense::Configuration<T, TI, 4, rlt::nn::activation_functions::ActivationFunction::IDENTITY>;
using LAYER_3 = rlt::nn::layers::dense::BindConfiguration<LAYER_3_CONFIG>;
using MODULE_CHAIN = Module<LAYER_1, Module<LAYER_2, Module<LAYER_3>>>;
using CAPABILITY = rlt::nn::capability::Forward<>;
constexpr TI SEQUENCE_LENGTH = 1;
using INPUT_SHAPE = rlt::tensor::Shape<TI, SEQUENCE_LENGTH, BATCH_SIZE, INPUT_DIM_MLP>;
using SEQUENTIAL = Build<CAPABILITY, MODULE_CHAIN, INPUT_SHAPE>;
For each layer, we specify a configuration and create a wrapper that binds this configuration. The semantics of this wrapper is that it represents a layer with the given configuration but for any capability or input shape. Technically, the wrapper contains a template that the Sequential model uses to actually instantiate the layer with the appropriate capability and input shape.
Capability: The capability specifies the model's overarching properties. The main aspect is whether the module (and its constituent layers) supports only the forward pass, the backward pass wrt. the input, or the backward pass including the gradient wrt. the parameters. Additionally, the capability specifies whether the model should be statically or dynamically allocated. The reason we do not specify these attributes in the configuration is that we might want to use a particular model in different ways: e.g., in TD3/SAC the critics need to support backward-gradient, while the target critics only need forward inference. Furthermore, when checkpointing we probably only want to save the parameters and not the gradient and optimizer state. We will showcase this later.
Input Shape: Similar to the capability, we might want to use the model for different input shapes. Changing the FEATURE_DIM of the input shape is rarely useful, but it is often desirable to change the BATCH_SIZE, e.g. between RL training and data collection/exploration.
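Binding a configuration while choosing the capability and input shape later can be mimicked with a member alias template. The following sketch uses hypothetical, heavily simplified types whose names mirror the ones used in this document (BindConfiguration, CHANGE_CAPABILITY, CHANGE_BATCH_SIZE) but which are not the RLtools definitions; for instance, the real CHANGE_BATCH_SIZE also takes the index type.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical capability tags.
struct Forward {};
struct Gradient {};

template <std::size_t BATCH_SIZE, std::size_t FEATURE_DIM>
struct Shape {
    static constexpr std::size_t batch_size = BATCH_SIZE;
    static constexpr std::size_t feature_dim = FEATURE_DIM;
};

// The configuration only fixes structural hyperparameters (here: the output dimension).
template <std::size_t OUTPUT_DIM>
struct DenseConfiguration {
    static constexpr std::size_t output_dim = OUTPUT_DIM;
};

// A fully instantiated layer: configuration, capability, and input shape are all known.
template <typename CONFIG, typename CAPABILITY, typename INPUT_SHAPE>
struct DenseLayer {
    using OUTPUT_SHAPE = Shape<INPUT_SHAPE::batch_size, CONFIG::output_dim>;
    static constexpr bool stores_gradient = std::is_same<CAPABILITY, Gradient>::value;
    float weights[CONFIG::output_dim][INPUT_SHAPE::feature_dim];
};

// The wrapper binds only the configuration and leaves a template open for the rest.
template <typename CONFIG>
struct BindConfiguration {
    template <typename CAPABILITY, typename INPUT_SHAPE>
    using Layer = DenseLayer<CONFIG, CAPABILITY, INPUT_SHAPE>;
};

// A one-layer "model" that instantiates the bound layer and can be re-specified in hindsight.
template <typename BOUND, typename CAPABILITY, typename INPUT_SHAPE>
struct Model {
    using LAYER = typename BOUND::template Layer<CAPABILITY, INPUT_SHAPE>;
    template <typename NEW_CAPABILITY>
    using CHANGE_CAPABILITY = Model<BOUND, NEW_CAPABILITY, INPUT_SHAPE>;
    template <std::size_t NEW_BATCH_SIZE>
    using CHANGE_BATCH_SIZE = Model<BOUND, CAPABILITY, Shape<NEW_BATCH_SIZE, INPUT_SHAPE::feature_dim>>;
    LAYER layer;
};
```

For example, Model<BindConfiguration<DenseConfiguration<32>>, Forward, Shape<1, 5>>::CHANGE_CAPABILITY<Gradient>::CHANGE_BATCH_SIZE<32> reuses the same configuration with a gradient-capable layer and a batch size of 32.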
The API is the same as for other models and layers:
[36]:
SEQUENTIAL sequential;
SEQUENTIAL::Buffer<> sequential_buffer;
rlt::malloc(device, sequential);
rlt::malloc(device, sequential_buffer);
[37]:
rlt::init_weights(device, sequential, rng);
[38]:
rlt::Tensor<rlt::tensor::Specification<T, TI, typename SEQUENTIAL::INPUT_SHAPE, false>> input; // note the "false" in the specification making it static/stack-allocated, hence no "malloc" is required
rlt::Tensor<rlt::tensor::Specification<T, TI, typename SEQUENTIAL::OUTPUT_SHAPE, false>> output;
rlt::randn(device, input, rng);
[39]:
rlt::evaluate(device, sequential, input, output, sequential_buffer, rng);
rlt::print(device, output);
dim[0] = 0:
-2.716676e-01 6.897645e-02 -2.784356e-02 -1.252805e-01
Now, if we want to train the model, we should have defined the CAPABILITY to support backward-gradient. But thanks to the decoupling of the capability from the configuration, we can change this in hindsight:
[40]:
using PARAMETER_TYPE = rlt::nn::parameters::Adam;
using NEW_CAPABILITY = rlt::nn::capability::Gradient<PARAMETER_TYPE, false>; // note the "false" specifies that it is statically allocated, so no "malloc" required
using NEW_SEQUENTIAL_CAPABILITY = SEQUENTIAL::CHANGE_CAPABILITY<NEW_CAPABILITY>;
Let's also say we want to increase the BATCH_SIZE for training:
[41]:
constexpr TI NEW_BATCH_SIZE = 32;
using NEW_SEQUENTIAL = NEW_SEQUENTIAL_CAPABILITY::CHANGE_BATCH_SIZE<TI, NEW_BATCH_SIZE>;
NEW_SEQUENTIAL new_sequential;
NEW_SEQUENTIAL::Buffer<false> new_sequential_buffer; // "false" again specifies a static buffer, no "malloc" required
rlt::copy(device, device, sequential, new_sequential);
Models and layers accept inputs with a smaller batch size than the one they were specified with (as long as the buffer matches the batch size), hence we can evaluate the original input (batch size 1) with the new model and check that it still gives the same result:
[42]:
rlt::evaluate(device, new_sequential, input, output, new_sequential_buffer, rng);
rlt::print(device, output);
dim[0] = 0:
-2.716676e-01 6.897645e-02 -2.784356e-02 -1.252805e-01
Now, for the backward pass we need to record intermediate values, which are discarded in the evaluate call for efficiency. Hence we call the forward function:
[43]:
rlt::Tensor<rlt::tensor::Specification<T, TI, typename NEW_SEQUENTIAL::INPUT_SHAPE, false>> new_input;
rlt::randn(device, new_input, rng);
rlt::forward(device, new_sequential, new_input, new_sequential_buffer, rng);
After the forward pass, we can simulate a loss function by generating a random \(\frac{\mathrm{d}L}{\mathrm{d}\;\text{output}}\) and calculating the gradient wrt. the parameters:
[44]:
rlt::Tensor<rlt::tensor::Specification<T, TI, typename NEW_SEQUENTIAL::OUTPUT_SHAPE, false>> new_d_output;
rlt::randn(device, new_d_output, rng);
rlt::zero_gradient(device, new_sequential);
rlt::backward(device, new_sequential, new_input, new_d_output, new_sequential_buffer);
[45]:
rlt::print(device, new_sequential.content.weights.gradient);
0.320461 0.172568 -0.038686 -0.269831 0.269185
0.067568 -0.187339 0.089244 -0.029956 -0.493866
-0.077152 0.195330 -0.070027 0.051061 0.254358
0.322985 -0.273944 -0.381387 -0.221107 0.520769
-0.298484 0.026869 0.017577 0.050301 0.419451
-0.378259 -0.467855 0.218961 0.359731 -0.004112
0.103813 -0.196705 -0.069182 -0.002469 0.115669
-0.324696 0.523877 0.420453 0.019141 -0.220434
0.146015 -0.047098 -0.078760 -0.201766 -0.056156
-0.031791 0.392544 -0.073967 0.002385 0.406896
-0.576604 0.626401 -0.016995 0.302599 -0.052693
0.325223 -0.084683 0.175826 0.310664 -0.385406
-0.042280 -0.267090 -0.178485 0.044999 -0.004504
-0.284561 -0.243193 0.094030 0.169179 0.079656
0.173386 -0.351642 0.204653 0.047558 0.153697
0.154294 -0.035325 -0.305142 -0.308024 -0.828210
-0.600413 0.165149 0.052111 0.247424 0.352850
0.438994 0.303696 -0.196991 -0.032276 0.020843
-0.413299 -0.486740 0.398008 -0.079801 -0.224559
0.009283 -0.161274 -0.037893 0.253224 -0.235326
-0.021573 0.113396 0.017783 -0.083122 0.064917
-0.168050 -0.075272 0.149018 0.147866 -0.182045
0.703156 0.525455 0.042927 0.439643 0.074443
0.562241 -0.243317 -0.412354 0.271648 0.734457
-0.186803 0.137267 -0.138127 0.228056 0.099720
-0.150418 -0.587200 0.224454 0.163250 -0.756426
0.179686 0.405630 0.109221 -0.147826 0.124997
0.350725 -0.368138 -0.383135 -0.435791 0.528348
-0.046798 0.538072 -0.145205 -0.234868 0.320617
0.425575 -0.551638 0.328734 -0.025153 -0.846081
0.363814 -0.414661 0.238195 0.115213 -0.487551
-0.371508 0.177195 0.037000 0.367933 0.195225