今天介绍的这篇文章「Automatic Differentiation for Tensor Algebras」虽然引用不多,但比较清楚地解释了在多维Tensor的场景下,怎样产生symbolic的求导结果。并且详细说明了怎样利用Jacobian矩阵的稀疏进行优化。


Reverse Mode Automatic Differentiation

假设我们有一个函数\(f(x) = g(h(x))\), 那么\(f\)对\(x\)的导数可以通过链式法则得到,

\[\frac{df}{dx} = \frac{\partial{g}}{\partial{h}} \frac{\partial{h}}{\partial{x}}\]

那么对于有多个输入\(x_1, x_2, ...\)的情况,更一般地可以写成,

\[\frac{\partial{f}}{\partial{x_i}} = \sum_{d\in csmr(x_j)} \frac{\partial{f}}{\partial{f_d}} \frac{\partial{f_d}}{\partial{x_j}}\]

其中\(d\in csmr(x_j)\)表示\(f_d\)是\(x_j\)的consumer,即\(f_d\)的计算需要用到\(x_j\),这里面隐含了一张计算图,顺着producer-consumer的依赖关系,我们就能得到相应的求导顺序。

另一个重要的概念是\(\frac{\partial{f}}{\partial{x_i}}\)和\(\frac{\partial{f}}{\partial{f_d}}\)被称作adjoint,分别写作\(\bar{x_i}\)和\(\bar{f_d}\),adjoints是从consumer反向传给producer的。


向量求导和Jacobian

先看一维向量的情况,\(\mathbf{f}(\mathbf{x}) = \mathbf{g}(\mathbf{h}(\mathbf{x}))\),其中

\[\mathbf{h}(\mathbf{x}) = \left[ h_1(\mathbf{x}), h_2(\mathbf{x}), \dots, h_K(\mathbf{x}) \right]^T\]

因此向量对向量求导,得到的是一个矩阵。我们定义Jacobian为,

\[\left( \frac{d\mathbf{f}}{d\mathbf{x}} \right)_{ij} \triangleq \frac{\partial f_i}{\partial x_j} \\ \frac{\partial f_i}{\partial x_j} = \sum_{k=1}^K \frac{\partial g_i}{\partial h_k} \frac{\partial h_k}{\partial x_j}\]

稀疏的Jacobian

一般情况下,我们会假设\(f_i(\mathbf{x})\)用到了向量\(\mathbf{x}\)中的所有元素,但实际应用中,可能\(\mathbf{f}\)和\(\mathbf{x}\)的维度是相同的,并且\(f_i(\mathbf{x}) = f_i(x_i)\),这样的函数我们称作「elementwise-defined function」。这类函数的Jacobian有什么特点呢?显然,它是一个对角矩阵 - 只有对角线上有元素,其他位置都是零。

\[\frac{\partial f_i}{\partial x_{i'}} = 0 \quad \text{for} \quad i \neq i'\]

这有什么好处呢,举例来说,\(f_i(\mathbf{x}) = \sin x_i\),我们想计算adjoint \(\bar{x_{i'}}\),

\[\bar{x_{i'}} = \sum_i \frac{\partial g}{\partial f_i} \frac{\partial f_i}{\partial x_{i'}} = \sum_i \bar{f_i} \delta_{i, i'} \cos x_i = \bar{f_{i'}} \cos x_{i'}\]

可以看到求和被cancel掉了,这是一个非常重要的优化,意味着生成求导程序的时候,一个for循环能够直接被省略。

这个技巧甚至可以用来优化reduction运算,因为被reduce的axis不会出现在output tensor中,举例来说,

\[f_{ij} (\mathbf{x}, \mathbf{y}) = \sum_k x_{ik} y_{kj} \\ \frac{\partial f_{ij}}{\partial x_{i'k'}} = \delta_{i,i'} \sum_k \delta_{k,k'} y_{kj} = \delta_{i,i'} y_{k'j} \\ \frac{\partial f_{ij}}{\partial y_{k'j'}} = \delta_{j,j'} \sum_k \delta_{k,k'} x_{ik} = \delta_{j,j'} x_{ik'}\]

如果带上\(f_{ij}\)的consumer的话,和上面一样,可以将这些consumer带来的求和cancel掉了,

\[\bar{x_{i'k'}} = \sum_i \sum_j \bar{f_{ij}} \delta_{i,i'} y_{k'j} = \sum_j \bar{f_{i'j}} y_{k'j} \\ \bar{y_{k'j'}} = \sum_i \sum_j \bar{f_{ij}} \delta_{j,j'} x_{ik'} = \sum_i \bar{f_{ij'}} x_{ik'}\]

最后看一个例子,这里面输入下标是输出下标的线性组合,

\[f_{ij} (\mathbf{x}) = \exp x_{mi + j} \\ \frac{\partial f_{ij}}{\partial x_{i'}} = \delta_{mi+j,i'} \exp x_{mi+j} \\ \bar{x_{i'}} = \sum_i \sum_j \bar{f_{ij}} \delta_{mi+j,i'} \exp x_{mi+j}\]

通过求解\(mi + j = i'\),我们得到\(j = i'-mi\),从而把下标\(j\)对应的求和cancel掉,

\[\bar{x_{i'}} = \sum_i \overline{f_{i, i'-mi}} \exp x_{i'}\]

更一般的情况

如果更一般地扩展到多维,多个输入Tensor的情况,以下我们用大写字母表示多维Tensor(比如上面例子里\(x_{ij}, y_{ij}\)作为元素组成的Tensor),\(\vec{\alpha} = (\alpha_1, \alpha_2, \dots, \alpha_{D_f})\)表示下标(比如上面例子中的\(i, j, k\)),

\[F_{\vec{\alpha}} (X^1, X^2, \dots, X^P) = f(X_{A^1\vec{\alpha}}^1, X_{A^2\vec{\alpha}}^2, \dots, X_{A^P\vec{\alpha}}^P)\]

其中\(A^P\vec{\alpha}\)表示\(X\)的下标是从输出\(F\)的下标线性变换过来的;$F$是输出,它是一个\(N_1^f \times N_2^f \times \dots \times N_{D_f}^f\)的Tensor;\(f\)是一个函数,比如上面例子里的\(\sin\),\(x \cdot y\)之类;记\(\vec{N^f} \triangleq \left[ N_1^f, N_2^f, \dots, N_{D_f}^f \right]^T\),

现在我们求\(X_{\vec{\beta^p}}^p\)的导数,其中\(\vec{\beta^p} = (\beta_1^p, \beta_2^p, \dots, \beta_{D_p}^p)\)表示\(X^p\)的下标,

\[\begin{equation} \label{eq:gradient} \overline{X_{\vec{\beta^p}}^p} = \frac{\partial l}{\partial X_{\vec{\beta^p}}^p} = \sum_{\substack{\vec{1} \leq \vec{\alpha} \leq \vec{N^f} \\ A^p \vec{\alpha} = \vec{\beta^p} }} \frac{\partial l}{\partial F_{\vec{\alpha}}} \frac{\partial F_{\vec{\alpha}}}{\partial X_{\vec{\beta^p}}^p} = \sum_{\substack{\vec{1} \leq \vec{\alpha} \leq \vec{N^f} \\ A^p \vec{\alpha} = \vec{\beta^p} }} \overline{ F_{\vec{\alpha}} } \cdot \frac{\partial f}{\partial X_{\vec{\beta^p}}^p} \end{equation}\]

注意上面这个求和,已经运用了对\(\delta_{A^p \vec{\alpha}, \vec{\beta^p}}\)的cancel,其中\(A^p \vec{\alpha}\)是将输出\(F\)的下标\(\vec{\alpha}\)变换到对应它来自输入\(X^p\)的哪一个下标,然后令它和adjoint的下标\(\vec{\beta^p}\)相等就是Jacobian中非零的部分。

(\(\delta_{A^p \vec{\alpha}, \vec{\beta^p}}\)的下标是向量,因此实际是\(D_f\)个\(\delta\);类似的,上面的求和\(\sum\)也是\(D_f\)个求和。)


求解求和下标

接下来我们要求解\(A^p \vec{\alpha} = \vec{\beta^p}\),最显然的解是\(\vec{\alpha} = (A^p)^{-1} \vec{\beta^p}\),然而事实上,一个线性方程很少有情况那么凑巧能得到唯一解(也就是说,\(A^p\)不可逆)。上面的第一、第二个例子里,\(i = i'\),\(j = j'\)是显然的唯一解,\(\sum_{i=i'}\)退化成了单个表达式,因此看起来被「cancel」了;而第三个例子里,\(mi + j = i'\)是没有唯一解的,我们最多可以做到\(\sum_i \sum_{i'-mi}\),其中\(\sum_{i'-mi}\)在\(i\)给定的情况下退化了。

因此一般情况下,我们要得到\((\alpha_1, \alpha_2, \dots, \alpha_{D_f})\)的解,其中一些\(\alpha_i\)是固定值,或者可以用别的\(\alpha_{j\neq i}\)和\(\vec{\beta^p}\)来表示。

简洁起见,以下我们用\(A\)代替\(A^p\),\(\vec{\beta}\)代替\(\vec{\beta^p}\)。

首先对\(A\)做SVD分解,\(U^T S V^T = A\)其中\(U\)和\(V\)是正交矩阵,\(S = \rm{diag} (s_1, s_2, \dots, s_R, 0, \dots, 0)\)是对角矩阵,对角线的前\(R\)个为非零元素。

从SVD分解我们可以引出伪逆\(A^{\dagger} \triangleq VS^{\dagger}U\),其中\(S^{\dagger} = \rm{diag}(1/s_1, 1/s_2, \dots, 1/s_R, 0, \dots, 0)\),伪逆满足\(AA^{\dagger}A = A\)

通过

\[U^T S V^T \vec{\alpha} = \vec{\beta}\]

得到

\[\begin{equation} \label{eq:svdshift} S V^T \vec{\alpha} = U \vec{\beta} \end{equation}\]

因为\(S\)的对角线后半部分(也就是\(R+1\)到\(D_f\)的部分)为零,所以等式右边相应的部分也应该为零。我们定义\(C\)矩阵是\(U\)的下半部分,即\(C_{ij} = U_{R+i,j}\),可以得到,

\[C\vec{\beta} = \vec{0}\]

于是,从等式(\(\ref{eq:svdshift}\))我们可以看出一个解,

\[\begin{equation} \label{eq:alph} \begin{split} \vec{\alpha} &= VS^{\dagger} U\vec{\beta} \\ &= A^{\dagger} \vec{\beta} \end{split} \end{equation}\]

简单展开一下,把\(\vec{\alpha}\)的这个解代入式(\(\ref{eq:svdshift}\)),我们得到等式左边是\(S V^T VS^{\dagger} U\vec{\beta} = S S^{\dagger} U\vec{\beta} = \rm{diag}(\overbrace{1,\cdots, 1}^R, \overbrace{0, \cdots, 0}^{(D_f-R)}) U\vec{\beta}\),等式右边我们刚才已经说明了最后有\((D_f-R)\)个零,因此两边是相等的。

除此之外,还有没有别的解?

由SVD分解\(AV = U^T S\),等式右边的矩阵,靠右的\((D_f-R)\)列为零。因此相应的,等式左边的\(AV\)也有\((D_f-R)\)列为零。我们把\(V\)的右边\((D_f-R)\)列记为\(K\),\(K_{ij} = V_{i, R+j}\),则\(AK = \mathbf{0}\),从而,

\[AA^{\dagger}\vec{\beta} + AK\vec{z} = \vec{\beta} \\ \Rightarrow A (A^{\dagger}\vec{\beta} + K\vec{z}) = \vec{\beta}\]

其中\(\vec{z}\)是任意的整数向量。从而我们得到\(\vec{\alpha}\)有任意多个整数解,

\[\vec{\alpha} = A^{\dagger}\vec{\beta} + K\vec{z}\]

现在,我们求解的问题变成了求解\(\vec{z}\)的range,也就是(\(\ref{eq:gradient}\))中求和的界限,

\[\begin{equation} \label{set:beta} \Sigma(\vec{\beta}) = \{z \in \mathbb{Z}^{D_f-R} \mid \vec{1} \leq A^{\dagger}\vec{\beta} + K\vec{z} \leq \vec{N^f} \} \end{equation}\]

这里值得注意的几点,

  • \(\vec{z}\)的维数是\(D_f-R\),这一点是从\(K\)的大小确定的。这意味着(\(\ref{eq:gradient}\))中求和从\(D_f\)个减少到了\(D_f-R\)个。特殊情况下,矩阵满秩,\(\vec{\alpha}\)有唯一解,求和减少到\(D_f-D_f=0\)个;
  • 举例来说,我们有标量\(\beta = \alpha_1-2\alpha_2-2\alpha_3\),矩阵\(A=(1, -2, -2)\)的秩为\(1\),需要求和\(\sum_{\alpha_1}\)和\(\sum_{\alpha_2}\);
  • 由于\(\vec{z}\)是在\(\vec{\alpha}\)变换后的空间内,原本\(\vec{\alpha}\)的range是由\(\vec{N^f}\)定义的,到了\(\vec{z}\)这边,就会出现\(z_i\)的range依赖于\(z_{i+1\dots D_f-R}\)的情况。
  • 上面的集合(\ref{set:beta})其实我们跳过了两个条件:\(C\vec{\beta}=\vec{0}\)和\(A^{\dagger}\vec{\beta} \in \mathbb{Z}^{D_f}\),这两个条件不依赖于\(\vec{z}\),因此可以运行一次检查,一旦这两个条件不满足,我们就能直接判断\(\overline{X_{\vec{\beta^p}}^p} = 0\)了。

求解不等式

下面我们简单介绍一下怎么求解集合(\ref{set:beta})中的不等式,中心思想是迭代消除其中的变量。假设我们有如下的不等式要求解,

\[\begin{gather*} A_{11} x_1 + A_{12} x_2 + \cdots + A_{1M} x_M \geq b_1 \\ A_{21} x_1 + A_{22} x_2 + \cdots + A_{2M} x_M \geq b_2 \\ \vdots \\ A_{N1} x_1 + A_{N2} x_2 + \cdots + A_{NM} x_M \geq b_N \end{gather*}\]

每一行都除以\(\vert A_{i1}\vert\),然后重新排列一下,我们一定可以规约得到三个不等式,一个包含\(x_1\),一个包含\(-x_1\),最后一个完全不包含\(x_1\):

\[\begin{gather*} x_1 + \sum_{j=2}^M D_{hj}x_j \geq d_h, \quad h \in {1,\dots,H} \\ -x_1 + \sum_{j=2}^M E_{kj}x_j \geq e_k, \quad k \in {1,\dots,K} \\ \sum_{j=2}^M F_{lj} x_j \geq f_l, \quad l \in {1,\dots,L} \end{gather*}\]

其中\(H + K + L = N\),这样再把前两个式子加一下,就消掉了\(x_1\),这个过程一直进行下去,最终就可以得到每个\(x\)的范围,

\[\begin{gather*} \max(L^M \vec{b}) \leq x_M \leq \min(H^M\vec{b}) \\ \max(L^{M-1}\vec{b} + \hat{L}^{M-1}x_M) \leq x_{M-1} \leq \min(H^{M-1}\vec{b}+\hat{H}^{M-1}x_M) \\ \max(L^{M-2}\vec{b} + \hat{L}^{M-2}\vec{x}_{M-1\dots M} \leq x_{M-2} \leq \min(H^{M-2}\vec{b} + \hat{H}^{M-2}\vec{x}_{M-1\dots M}) \\ \vdots \\ \max(L^2\vec{b} + \hat{L}^2\vec{x}_{3\dots M}) \leq x_2 \leq \min(H^2\vec{b} + \hat{H}^2\vec{x}_{3\dots M}) \\ \max(L^1\vec{b} + \hat{L}^1\vec{x}_{2\dots M}) \leq x_1 \leq \min(H^1\vec{b} + \hat{H}^1\vec{x}_{2\dots M}) \end{gather*}\]

\(\vec{x}_{2\dots M}\)表示\(x_2, x_3, \dots, x_M\)组成的向量。

集合(\ref{set:beta})中的不等式自然也可以用这个方式来迭代求解。