2 Automatic differentiation
This section first gives some general background on reverse-mode automatic differentiation, and then describes the interface provided by this library. The implementation is largely based on Pearlmutter and Siskind (2008).
2.1 Background
2.1.1 Derivatives
This section contains a brief recap of differentiation.
Suppose A and B are finite-dimensional vector spaces, and that we have a function
f : A \to B.
The derivative of f (if it exists) is a map
df : A \to A \to B
associating a linear map from A to B with each element of A. We say that df(x) is the derivative of f at the point x.
You may be used to thinking of a derivative as a number, perhaps written \frac{df}{dx} or f^\prime(x), or as a matrix J_{ij} = \frac{\partial f_i}{\partial x_j}. This presentation is equivalent, but will have several advantages for our purposes.
For example if f(x) = 5x, the derivative (as defined above) of f is the linear map
df(x) = \Delta x \mapsto 5 \Delta x,
and we understand notation such as \frac{df}{dx} = 5 to indicate the coefficients of this map.
Sometimes it is convenient to choose bases for A and B, and represent this linear map as a matrix (which is then known as the Jacobian), but sometimes another representation is preferable.
We will keep the representation of linear maps as functions. Eventually, just as we think of a (mathematical) function f as being implemented (or approximated) by a Racket function, its derivative will be a Racket function too.
It is often necessary, eventually, to turn this function into a numerical representation. Evaluating the linear map gives a directional derivative (in terms of Jacobians, this would be a Jacobian-vector product). Evaluating it for each element of a basis of A allows us to reconstruct the whole Jacobian. Notice that to reconstruct the Jacobian of f at x we would need to make \dim{A} evaluations of df(x), regardless of the dimensionality of B.
Example
f : \mathbf{R}^3 \to \mathbf{R}^2 \\ f(x,y,z) = (z + 1, xy)
then
\begin{split} d&f(x,y,z) =\\ &(\Delta x, \Delta y, \Delta z) \mapsto (\Delta z, y \Delta x + x \Delta y) \end{split}
The directional derivative, in the direction (1,0,0), is
df(x,y,z)(1,0,0) = (0,y)
Evaluating the map at the standard basis vectors (1,0,0), (0,1,0) and (0,0,1) gives the Jacobian matrix:
\begin{pmatrix} 0 & 0 & 1 \\ y & x & 0 \end{pmatrix}
The adjoint map (defined in general below) is
\begin{split} D&f(x,y,z) =\\ &(\Delta u, \Delta v) \mapsto (y\Delta v, x\Delta v, \Delta u) \end{split}
where \Delta u and \Delta v are sometimes known as sensitivity variables. The function is said to map output or result sensitivities to argument sensitivities.
Notice that it takes just two evaluations of the adjoint map, at (1,0) and (0,1), to obtain the same Jacobian as above, at a cost of two multiplications per evaluation in each case.
A case that is often useful in practice is when A is very high dimensional, and B is \mathbf{R}. Loss functions in optimization problems have this form, for example.
Handling this case more efficiently is the motivation for reverse-mode AD, which is based on the following idea.
If we further insist that A and B are both equipped with an inner product, we can obtain the adjoint of a linear map L : A \to B, which is another linear map L^* : B \to A. This allows us to define
Df : A \to B \to A\\ Df(x) = df(x)^*.
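Concretely, the adjoint is characterized by the property
\langle df(x)(\Delta x), \Delta w \rangle_B = \langle \Delta x, Df(x)(\Delta w) \rangle_A
for all \Delta x \in A and \Delta w \in B.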
If df(x) can be represented by the Jacobian matrix J, then the matrix representation of Df(x) is the transpose of the Jacobian, J^T.
Particularly when referring to its implementation in code, we call Df(x) the backpropagator of f at x.
Returning to the case we considered above, of f : A \to \mathbf{R}, it would be possible to reconstruct the Jacobian from a single evaluation of the linear map Df(x) : \mathbf{R} \to A.
\nabla f : A \to A\\ \nabla f(x) = Df(x)(1)
assuming the usual inner product on \mathbf{R}.
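For example, if f(x, y) = xy, then Df(x, y) = \Delta w \mapsto (y \Delta w, x \Delta w), and so
\nabla f(x, y) = Df(x, y)(1) = (y, x),
the familiar gradient of f.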
2.1.2 Composition and the chain rule
Our goal is to be able to differentiate as many (Racket) functions as possible. In some cases, we will be content with explicitly providing a function Df that computes the derivative of another function f, and associating them somehow. It would be unsatisfactory if we had to do this for every f, though, so we seek a way of determining the derivative of a function from its definition in terms of other functions. The ability to do this is the main selling point of automatic differentiation. The primary way (and in some sense, the only way) that this is achieved is via the chain rule.
The chain rule relates the derivative of a composition of functions to the composition of their derivatives. It can be expressed in terms of either d or D:
d(g \circ f)(x) = dg(f(x)) \circ df(x)
D(g \circ f)(x) = Df(x) \circ Dg(f(x)).
Notice the ‘reverse’ order of composition in the right hand side of the equation immediately above.
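For example, with f(x) = x^2 and g(y) = \sin y,
D(g \circ f)(x) = Df(x) \circ Dg(x^2) = \Delta w \mapsto 2x \cos(x^2) \Delta w,
recovering the familiar derivative of \sin(x^2).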
We will focus on D for the rest of the section, but similar considerations would apply to d. Notice too that for both rules, we need to know f(x) to express the derivative of the composition (not merely Df). There is often some shared work involved in computing Df(x) and f(x), but this is not apparent from the usual chain rule, and an interface based on this would not let us take advantage of it.
Instead, define
D^+f(x) = (f(x), Df(x))
and now
D^+(g\circ f)(x) = \big(g(f(x)), Df(x) \circ Dg(f(x))\big).
Notice that D^+(g\circ f) can now be expressed in terms of D^+g and D^+f.
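To make this concrete, here is a minimal sketch of the composition rule for D^+ in Racket, using plain pairs rather than the proc-result struct used by this library, and scalar sensitivities only:

(define (D+sqr x)                 ; D^+ for squaring
  (cons (* x x) (λ (Aw) (* 2 x Aw))))

(define (D+sin x)                 ; D^+ for sin
  (cons (sin x) (λ (Aw) (* (cos x) Aw))))

;; D^+(g ∘ f): evaluate forwards, then compose the
;; backpropagators in the opposite order
(define ((compose-D+ D+g D+f) x)
  (let* ([rf (D+f x)]             ; (f(x) . Df(x))
         [rg (D+g (car rf))])     ; (g(f(x)) . Dg(f(x)))
    (cons (car rg)
          (λ (Aw) ((cdr rf) ((cdr rg) Aw))))))

;; the derivative of sin(x^2) at 2.0:
;; ((cdr ((compose-D+ D+sin D+sqr) 2.0)) 1.0)  ; = 4 cos(4) ≈ -2.615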
2.1.3 The reverse transform
The mapping
f \mapsto D^+f
as implemented in code is the central operation of reverse-mode AD.
Why reverse transform?
Notice the composition rule above: Roughly, whereas data flows ‘forwards’ through the composition g \circ f, the derivatives of f and g are composed in the opposite order, and so data flows ‘in reverse’ through them.
TODO a picture would help here!

Since the output of the reverse transform combines the function value and its derivative, data must in fact flow both ways. The idea is to store each function evaluation on the way ‘forward’, to be consumed by the appropriate derivative computation on the way back again.
This description is far from complete. The handling of variable assignment (and of repeated use of a variable), as well as of mutable state, has been omitted, as have many technical details needed for a practical implementation.
2.2 Reverse transform API
The previous section defined the reverse transformation as a mapping f \mapsto D^+ f. This section describes how it applies to Racket code. The macros D+ and lift/D+ perform transformations similar to this one; D and grad are provided as a convenience. Of these, lift/D+ is fundamental.
When differentiating an expression, each procedure that is encountered is replaced with one that computes both the primal value and its backpropagator.
In the example below, the reverse transformation of * is obtained with lift/D+. The primal and backpropagator are returned in a proc-result struct.
> (define result ((lift/D+ *) 4.0 2.5))
> (primal result)
10.0
> (backprop result)
#<procedure:...ator/primitives.rkt:87:2>
Procedures whose definitions occur within the expression being differentiated can be transformed automatically by the library. Any procedure that is used but not defined within the expression must also be replaced with its reverse transform. Such procedures are known as primitives, and include, for example, arithmetic operations. They must have backpropagators that are known in advance.
In this library, the backpropagator of a function takes two arguments: the result sensitivity, which should conform to the value returned by the function, and the box sensitivities, which are explained below. The result of evaluating a backpropagator is a list containing:
- a list of the sensitivities of the variables closed over by the function (in an unspecified order);
- the argument sensitivity for each argument passed to the function.
The box sensitivity argument to a backpropagator is the way sensitivities of mutable data structures are handled. It is a mutable hash table (satisfying (and/c hash? hash-eq? (not/c immutable?))) mapping each mutable data structure to its corresponding sensitivity value. The entry keyed by a given mutable data structure can be updated by the backpropagator of any function that refers to an element of that structure.
> ((backprop result) 1.0 (make-hasheq))
'(() 2.5 4.0)

Notice the empty hash table passed as the second argument, and that the first element of the resulting list is the list of closed-over variable sensitivities (here empty, since * closes over nothing).
Alternatively, use D+ to avoid the empty hash table argument, and to drop the closure sensitivities:
> (define result ((D+ *) 4.0 2.5))
> (primal result)
10.0
> ((backprop result) 1.0)
'(2.5 4.0)
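For intuition only, a wrapper with this behavior can be sketched as follows. This is not the library's definition (D+ is a syntactic form, not a procedure), and the helper name here is hypothetical:

(require racket/match)

;; Hypothetical helper: wrap a procedure of the kind produced by
;; lift/D+ so that its backpropagator takes a single argument and
;; returns only the argument sensitivities.
(define (drop-closure-sensitivities lifted)
  (λ args
    (match-define (proc-result p b) (apply lifted args))
    (define box-sensitivities (make-hasheq))     ; fresh, empty table
    (proc-result
     p
     (λ (Aw) (cdr (b Aw box-sensitivities))))))  ; drop closure sensitivities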
2.2.1 Specifying reverse transformations
When specifying a reverse transform, it should have the form described above (as returned by lift/D+). Here is the reverse transformation of two-argument multiplication:
(λ (x y)
  (proc-result (* x y)
               (λ (Aw Abox)
                 (list '() (scale Aw y) (scale Aw x)))))
and of exp:
(λ (x)
  (let ([exp-x (exp x)])
    (proc-result exp-x
                 (λ (Aw Abox)
                   (list '() (scale Aw exp-x))))))
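As a further example (illustrative, not one of the library's built-in transforms), here is a reverse transform for squaring, where both uses of the argument contribute to its sensitivity:

(λ (x)
  (proc-result (* x x)
               (λ (Aw Abox)
                 ;; d(x·x)/dx = 2x: each use of x contributes (scale Aw x)
                 (list '() (scale Aw (* 2 x))))))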
Backpropagator definitions should allow for the fact that the result sensitivity argument may be a value of gen-zero (hence the use of scale above). See Linear generic interface.
The reverse transform of a binding must be provided when registering it as a new primitive with register-primitive!, or via the require/primal+backprop mechanism. It can subsequently be used in functions defined with define/D+.
2.2.2 Interface
syntax
(grad expr)
The result of evaluating expr must be a procedure.
This form evaluates to a function of the same arity that, when called, returns the gradient (represented as described above) with respect to the supplied arguments.
The first form is equivalent to the second with 1.0 passed as the value of result-sensitivity.
syntax
(∇ expr)
syntax
(grad1 expr)
syntax
(D expr)
> (define/D+ (f x y)
    (vector->immutable-vector (vector (* x x) y)))
> (((D f) 2.0 3.0) #(1.0 0.0))
'(4.0 0.0)

; An error: the sensitivity does not conform with the result:
> (((D f) 2.0 3.0) '(1.0 0.0))
raise-argument-error: contract violation
  expected: exact-nonnegative-integer?
  given: '(1.0 0.0)
TODO Fix unhelpful error message
syntax
(D+ expr)
The backpropagator in the result is partially applied in its second argument to a new, empty hash table for holding the sensitivities of mutable values. The result of evaluating the backpropagator contains only the argument sensitivities (and not the sensitivities of any closed-over variables).
syntax
(lift/D+ expr)
The backpropagator is the two-argument form described above: the first argument is the result sensitivity, and the second must be a mutable hash table (satisfying (and/c hash? hash-eq? (not/c immutable?))).
The resulting function is of the correct form to pass to derivatives of higher-order functions.
Using it directly:
> (define D+f (lift/D+ (lambda (x) (set! x (* x x)) x)))
> (define result (D+f 2.0))
> (primal result)
4.0
> ((backprop result) 1.0 (make-hasheq))
'(() 4.0)
syntax
(define/D+ id expr)
(define/D+ (id args ...) body ...)

Similarly to define, bind id to the result of evaluating expr in the first case, or to a procedure in the second case. In addition, the reverse transform of expr or body is determined, and id is registered as a primitive. Recursive definitions are allowed.
2.2.3 Reverse-transformed procedure results
struct
(struct proc-result (primal backprop)
  #:transparent)
  primal : any/c
  backprop : procedure?
procedure
(primal r) → any/c
  r : proc-result?

procedure
(backprop r) → procedure?
  r : proc-result?
proc-result-primal and proc-result-backprop are also provided under the shorter aliases primal and backprop.
2.2.4 Handling functions unknown as primitives
During reverse transformation, an identifier may be encountered that is not registered as a primitive. In this case, it is transformed to the result of calling the procedure stored as the value of the parameter current-unknown-transform. In general, the job of this procedure is to raise an error, but a few other cases where the result is known may also be handled.
The default value of current-unknown-transform is error-non-zero-sensitivity-transform, which raises an error only if an attempt is made to call the unknown backpropagator with a non-zero sensitivity argument. This permits code paths that do not contribute to the derivative (e.g. error handling, tests in conditionals) without having to register (perhaps meaningless) derivative information for every function that is called.
error-unknown-transform and error-unknown-proc-transform can be very useful for debugging.
parameter
(current-unknown-transform proc) → void?
  proc : (-> any/c any/c procedure?)
= error-non-zero-sensitivity-transform
> (define/D+ (car-or-void v) (when (pair? v) (car v)))
> ((grad car-or-void) '(1.0 2.0 3.0))
'((1.0 0.0 0.0))
> ((grad car-or-void) 123.0)
'(0.0)
> (parameterize ([current-unknown-transform error-unknown-transform])
    ((grad car-or-void) '(1.0 2.0 3.0)))
lift/D+: Backpropagator unknown
  op: 'pair?
procedure
(error-unknown-transform op op-name) → any
  op : any/c
  op-name : any/c
The resulting procedure will be used as the reverse transform of op. It will unconditionally raise an error when called.
procedure
(error-unknown-proc-transform op op-name) → procedure?
  op : any/c
  op-name : any/c
The resulting procedure will be used as the reverse transform of op. It raises an error when op is a procedure; otherwise, op is returned unchanged.
procedure
(error-non-zero-sensitivity-transform op op-name) → procedure?
  op : any/c
  op-name : any/c
The resulting procedure will be used as the reverse transform of op.
When op is a non-procedure value, op is returned.

When op is a procedure, an attempt is made to construct a reverse transform for it, whose primal is the result of evaluating the procedure, and whose backpropagator raises an error when called, unless (gen-zero) is passed (in which case the result of the backpropagator is also (gen-zero)).