Einsum

文章目录
  1. 1. 爱因斯坦求和约定
  2. 2. 约定
  3. 3. 例子
  4. 4. 缺点

爱因斯坦求和约定

矩阵运算中,Einstein summation convention 可以让表达式非常简洁,并且对于我们理解矩阵运算非常有用。一个例子如下

1
2
3
4
5
6
7
A = np.arange(3)
B = np.arange(12).reshape(3, 4)
# Goal: multiply A & B element wise, then sum along axis 1.
# 【 一般矩阵运算如下 】
(A[:, None] * B).sum(axis=1)
# 【 Or using Einstein Summation convention 】
np.einsum('i,ij->i', A, B)

约定

爱因斯坦求和约定如下

  1. 输入数组间重复的字母对应的轴,我们会沿着这个轴进行逐元素(element wise)的相乘
  2. ‘->’符号右侧中所忽略掉的字母对应的轴,我们会沿着这个轴进行求和
  3. ‘->’符号右侧输出的数组的shape可以是任意的

当忽略掉 ‘->’符号及它右侧的东西的时候,会将左侧只出现一次的字母按照字母序作为结果的shape

例子

Call signature NumPy equivalent Description
('i', A) A returns a view of A
('i->', A) sum(A) sums the values of A
('i,i->i', A, B) A * B element-wise multiplication of A and B
('i,i', A, B) inner(A, B) inner product of A and B
('i,j->ij', A, B) outer(A, B) outer product of A and B

Now let A and B be two 2D arrays with compatible shapes:

Call signature NumPy equivalent Description
('ij', A) A returns a view of A
('ji', A) A.T view transpose of A
('ii->i', A) diag(A) view main diagonal of A
('ii', A) trace(A) sums main diagonal of A
('ij->', A) sum(A) sums the values of A
('ij->j', A) sum(A, axis=0) sum down the columns of A (across rows)
('ij->i', A) sum(A, axis=1) sum horizontally along the rows of A
('ij,ij->ij', A, B) A * B element-wise multiplication of A and B
('ij,ji->ij', A, B) A * B.T element-wise multiplication of A and B.T
('ij,jk', A, B) dot(A, B) matrix multiplication of A and B
('ij,kj->ik', A, B) inner(A, B) inner product of A and B
('ij,kj->ikj', A, B) A[:, None] * B each row of A multiplied by B
('ij,kl->ijkl', A, B) A[:, :, None, None] * B each value of A multiplied by B

When working with larger numbers of dimensions, keep in mind that einsum allows the ellipses syntax '...'. This provides a convenient way to label the axes we’re not particularly interested in, e.g. np.einsum('...ij,ji->...', a, b) would multiply just the last two axes of a with the 2D array b. There are more examples in the documentation.

缺点

  1. 速度

    einsum虽然表达方式简洁,但是速度可能不令人满意,尤其是当矩阵的数量较多时。

    Numpy提供了一些方式进行加速,可以参考官网。

Reference:

  1. https://ajcr.net/Basic-guide-to-einsum/
  2. numpy.einsum docs