Abstract

Today we’ll talk about the popular Neural ODE paper, which won a best paper award at NeurIPS 2018. The paper develops a foundational framework for deep neural networks with, in effect, infinitely many layers, and the framework allows us to take advantage of the extensive research on ODE solvers.


Background : Ordinary Differential Equations

In physics, ordinary differential equations (ODEs) are often used to describe the dynamics of a system; the function on the right-hand side is referred to as the vector field of the underlying physical system.

A system of ODEs can be represented as,
\frac{dy}{dt} = f(t, y(t))

Neural ODEs aim to replace an explicitly specified ODE by an ODE with learnable parameters. Extensive research has been done on explicit and implicit ODE solvers, which numerically integrate an ODE (the forward pass). The simplest ODE solver, the forward Euler method, works by repeatedly stepping along the vector field from a starting point:
y_{n+1} = y_n + \delta f(t_n, y_n)
Sophisticated higher-order explicit ODE solvers, such as Runge-Kutta methods, have much smaller error.
An example of an implicit ODE solver is the backward Euler method:
y_{n+1} = y_n + \delta f(t_{n+1}, y_{n+1})
Implicit solvers are computationally more expensive, since each step requires solving an equation for y_{n+1}, but they often have better stability and approximation guarantees. Various adaptive step-size solvers have also been developed, which adjust the step size to keep the estimated local error within a tolerance.
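To make the forward Euler update concrete, here is a minimal sketch in plain NumPy; the test function and the step count are illustrative choices of mine, not taken from the paper:

```python
import numpy as np

def forward_euler(f, y0, t0, t1, n_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with fixed-size explicit Euler steps."""
    y = np.asarray(y0, dtype=float)
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        y = y + h * f(t, y)  # step along the vector field
        t += h
    return y

# Example: dy/dt = -y with y(0) = 1 has exact solution y(t) = exp(-t).
approx = forward_euler(lambda t, y: -y, y0=1.0, t0=0.0, t1=1.0, n_steps=100)
print(approx, np.exp(-1.0))  # the two values should be close
```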

Problem Setup: Supervised learning

Traditional machine learning solves y = F(x), where y is the label and x is the input feature vector. A Neural ODE views the same problem as an initial value problem,
\frac{dy}{dt} = f(y(t), t), \qquad y(0) = x
where the value at the initial time point is the input x and the label y is the value at the final time point. For now, let’s assume that x and y have the same dimensionality for simplicity. A Neural ODE aims to learn an invertible transformation between x and y.

The idea of viewing supervised learning as an ODE system comes from observing the updates in a ResNet. In a ResNet, the update producing the (t+1)-th layer is,
\mathbf{h}_{t+1} = \mathbf{h}_t + f(\mathbf{h}_t, \theta_t)
We can view the above update as the Euler discretization of a continuous dynamical system,
\frac{d\mathbf{h}(t)}{dt} = f(\mathbf{h}(t), t, \theta)
We can therefore view each layer as one Euler step of an ODE solver with dynamics f.
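To see this correspondence in code, here is a minimal sketch of a residual block written as an explicit Euler step with step size 1; the small MLP and layer sizes are illustrative choices, not the paper’s architecture:

```python
import torch
import torch.nn as nn

class ResidualEulerBlock(nn.Module):
    """One ResNet-style update h_{t+1} = h_t + f(h_t, theta_t)."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        # f is a small MLP here; any differentiable map preserving the dimension works
        self.f = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, h):
        return h + self.f(h)  # identical in form to one explicit Euler step with step size 1
```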

Infinite layers: ODE Solver as forward pass

Did you ever think that a DNN could have infinitely many layers? If we try to implement this naively by giving each layer its own parameters, we face two prominent issues:
i) Overfitting.
ii) Memory issues, because each layer with learnable parameters needs to store its input until the backward pass. So it is not practically feasible in the traditional setting.

A natural question arises: why can’t we take as many updates as we like in an ODE solver, instead of a fixed number l of updates, where l denotes the number of Euler steps (layers)? We just need to use state-of-the-art adaptive ODE solvers in a memory-efficient way (more on this later).

Neural Ordinary Differential Equations

Here we summarize the underlying idea of a Neural ODE. Instead of trying to solve for y = F(x) directly, a Neural ODE solves y = z(t_1), given the initial condition z(t_0) = x. The dynamics are parametrized as,
\frac{dz(t)}{dt} = f(z(t), t, \theta)
Existing ODE solvers are used for the forward pass.
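As a concrete sketch of the forward pass, assuming the authors’ torchdiffeq package and a toy MLP for f (the dimensions and architecture here are illustrative, not the paper’s):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint  # black-box ODE solvers released alongside the paper

class ODEFunc(nn.Module):
    """Learnable dynamics f(z(t), t, theta)."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, t, z):      # the solver calls the dynamics as f(t, z)
        return self.net(z)

func = ODEFunc(dim=2)
x = torch.randn(16, 2)            # batch of inputs, used as initial states z(t_0)
t = torch.tensor([0.0, 1.0])      # integrate from t_0 = 0 to t_1 = 1
y_hat = odeint(func, x, t)[-1]    # forward pass: the prediction is z(t_1)
```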

Backpropagation through Neural ODE

Ultimately we want to optimize some loss with respect to the parameters \theta (and possibly t_0, t_1),
L(z(t_1)) = L\Big(z(t_0) + \int_{t_0}^{t_1} f(z(t), t, \theta)\, dt\Big) = L\big(\mathrm{ODESolve}(z(t_0), t_0, t_1, \theta)\big)
If the loss is mean squared error, we can write L(z(t_1)) = E[(y - z(t_1))^2]. Without loss of generality, we’ll primarily focus on computing \frac{dL}{d\theta}. If we know the internals of the ODE solver, we can backpropagate through the solver using automatic differentiation. However, there are two issues with this approach,
i) Backpropagation becomes dependent on the particular ODE solver, which is not desirable. Ideally, we would like to treat the ODE solver as a black box.
ii) If we use “implicit” solvers, which often perform an inner optimization, backpropagation with automatic differentiation is memory intensive.

Adjoint sensitivity analysis: Reverse-mode Autodiff

This paper borrows the age-old idea of adjoint methods from the ODE literature to perform backpropagation with respect to the parameters \theta in constant memory and without knowledge of the ODE solver’s internals. To compute \frac{dL}{d\theta} we first need \frac{dL}{d\mathbf{z}(t)}, so let’s define the adjoint state \mathbf{a}(t) = \frac{dL}{d\mathbf{z}(t)}. In the case of a ResNet, the adjoint state has the following update during backpropagation,

a(t) = a(t+h) + h\, a(t+h) \frac{\partial f(\mathbf{z}(t))}{\partial \mathbf{z}(t)}

You might have observed that this update of the adjoint a(t) is an Euler step in the backward direction with known dynamics. You guessed right: in the continuous setting the adjoint state follows the dynamics below, integrated backwards in time,

a(t) = a(t+1) + \int_{t+1}^{t} a(s)\, \frac{\partial f(\mathbf{z}(s), s, \theta)}{\partial \mathbf{z}}\, ds

\frac{d\mathbf{a}(t)}{dt} = - \mathbf{a}(t)^{T} \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}}

For a ResNet, the corresponding update of the loss gradient with respect to the parameters is,

\frac{dL}{d\theta} = h\, a(t+h)\, \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta}

Similarly, in the continuous setting the parameter gradient follows its own dynamics in the backward direction,

\frac{dL}{d\theta} = -\int_{t_1}^{t_0} \mathbf{a}(t)^{T}\, \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta}\, dt


Thus, during the backward pass we solve for z(t) backwards in time and, alongside it, solve the backward dynamics for \mathbf{a}(t) and \frac{dL}{d\theta}. The vector-Jacobian products \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \mathbf{z}} and \mathbf{a}(t)^T \frac{\partial f(\mathbf{z}(t), t, \theta)}{\partial \theta} can be computed using automatic differentiation at a time cost similar to evaluating f itself. The full algorithm for all three backward dynamics is sketched below,
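Here is a minimal sketch of that augmented backward solve, written with fixed-step Euler updates for clarity; the paper instead feeds the augmented state to a black-box adaptive solver, and the names `func`, `params`, and the step count are illustrative choices of mine:

```python
import torch

def adjoint_backward_euler(func, z1, t0, t1, dLdz1, params, n_steps=100):
    """Integrate the three backward dynamics (z, a = dL/dz, and dL/dtheta) from t1 back to t0."""
    h = (t1 - t0) / n_steps
    z, a = z1.clone(), dLdz1.clone()
    dLdtheta = [torch.zeros_like(p) for p in params]
    t = t1
    for _ in range(n_steps):
        with torch.enable_grad():
            z_req = z.detach().requires_grad_(True)
            f = func(t, z_req)
            # one autodiff call gives both vector-Jacobian products: a^T df/dz and a^T df/dtheta
            vjps = torch.autograd.grad(f, (z_req, *params), grad_outputs=a, allow_unused=True)
        a_dfdz, a_dfdp = vjps[0], vjps[1:]
        z = z - h * f.detach()            # reconstruct z(t) backwards in time
        a = a + h * a_dfdz                # da/dt = -a^T df/dz, stepped backwards
        dLdtheta = [g + h * (v if v is not None else torch.zeros_like(g))
                    for g, v in zip(dLdtheta, a_dfdp)]  # accumulate dL/dtheta
        t = t - h
    return a, dLdtheta  # a is now dL/dz(t0); dLdtheta approximates the parameter gradient

# e.g. a0, grads = adjoint_backward_euler(func, z1, 0.0, 1.0, dLdz1, list(func.parameters()))
```

In practice, the authors’ torchdiffeq package exposes this through odeint_adjoint, which runs an adaptive solver over the augmented state instead of the fixed Euler steps above.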

Results

Experiment 1 : Supervised Learning
The paper performs supervised learning on the MNIST dataset. It uses a ResNet and a multi-layer perceptron as baselines. The results are shown below,

RK-Net uses a Runge-Kutta solver for the forward pass and automatic differentiation for the backward pass. We can see that ODE-Net achieves similar accuracy with fewer parameters and constant memory cost.

Experiment 2 : Normalizing flows
The paper proposes a continuous version of normalizing flows. Traditionally, normalizing flows are enabled by the change of variables formula. This paper derives an instantaneous change of variables formula using the ODE dynamics.
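For reference, the instantaneous change of variables result from the paper states that if z(t) follows the dynamics f, the log-density evolves as,
\frac{\partial \log p(\mathbf{z}(t))}{\partial t} = -\mathrm{tr}\left(\frac{\partial f}{\partial \mathbf{z}(t)}\right)
so the log-density is obtained by integrating a trace over time instead of computing the Jacobian log-determinant required by discrete normalizing flows.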

The results are shown below,

Experiment 3 : Time series with a latent ODE
Traditional RNNs, in their vanilla form, can only consume signals sampled at regular time intervals. A Neural ODE lets us model the latent state as a continuous-time dynamical system and sample it at arbitrary time points, as in the sketch below,
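A minimal sketch of the generative half of such a latent ODE model, assuming torchdiffeq’s odeint and toy dynamics/decoder modules invented here for illustration (the paper additionally uses an RNN encoder to infer the initial latent state from observed data):

```python
import torch
import torch.nn as nn
from torchdiffeq import odeint

latent_dim, obs_dim = 4, 2
dynamics = nn.Sequential(nn.Linear(latent_dim, 32), nn.Tanh(), nn.Linear(32, latent_dim))
decoder = nn.Linear(latent_dim, obs_dim)

z0 = torch.randn(1, latent_dim)                    # sample an initial latent state z(t_0)
t = torch.tensor([0.0, 0.3, 0.7, 1.4, 2.0])        # arbitrary, irregularly spaced query times
z_traj = odeint(lambda t, z: dynamics(z), z0, t)   # latent trajectory at exactly those times
x_traj = decoder(z_traj)                           # decode each latent state into an observation
```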

The results for both the RNN and the Neural ODE on a synthetic spiral time series are shown below. It can be seen that the RNN learns very stiff dynamics and suffers from exploding gradients, while the ODE trajectories are guaranteed to be smooth.

Conclusion

Personally, I believe this is a phenomenal paper: it enables networks with effectively infinite layers, continuous-time series modeling, and normalizing flows in constant memory by leveraging state-of-the-art ODE solvers. Its impact on learning the dynamics of physical and biophysical systems should be immense in the future.

References

Chen, R. T., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural ordinary differential equations. arXiv preprint arXiv:1806.07366.