Hiểu về einsum của NumPy


190

Tôi đang đấu tranh để hiểu chính xác làm thế nào einsumhoạt động. Tôi đã xem tài liệu và một vài ví dụ, nhưng nó dường như không dính.

Đây là một ví dụ chúng tôi đã đi qua trong lớp:

C = np.einsum("ij,jk->ki", A, B)

cho hai mảng AB

Tôi nghĩ rằng điều này sẽ mất A^T * B, nhưng tôi không chắc chắn (nó đang chuyển vị của một trong số họ phải không?). Bất cứ ai cũng có thể hướng dẫn tôi chính xác những gì đang xảy ra ở đây (và nói chung khi sử dụng einsum)?


7
Trên thực tế nó sẽ (A * B)^T, hoặc tương đương B^T * A^T.
Tigran Saluev

20
Tôi đã viết một bài blog ngắn về những điều cơ bản einsum ở đây . (Tôi rất vui khi ghép các bit có liên quan nhất vào câu trả lời trên Stack Overflow nếu hữu ích).
Alex Riley

1
@ajcr - Liên kết đẹp. Cảm ơn. Các numpytài liệu không đầy đủ khi giải thích các chi tiết.
rayryeng

Cảm ơn bạn đã bỏ phiếu tín nhiệm! Được biết, tôi đã đóng góp một câu trả lời dưới đây .
Alex Riley

Lưu ý rằng trong Python *không phải là phép nhân ma trận mà là phép nhân theo phần tử. Coi chừng!
ComputerSellectist

Câu trả lời:


368

(Lưu ý: câu trả lời này dựa trên một bài đăng blog ngắn về einsumtôi đã viết cách đây một thời gian.)

Không gì einsumlàm gì?

Hãy tưởng tượng rằng chúng ta có hai mảng đa chiều, AB. Bây giờ hãy giả sử rằng chúng ta muốn ...

  • nhân A với Bmột cách cụ thể để tạo ra các sản phẩm mới; và sau đó có thể
  • tổng hợp mảng mới này dọc theo các trục cụ thể; và sau đó có thể
  • hoán vị các trục của mảng mới theo thứ tự cụ thể.

Có một cơ hội tốt einsumsẽ giúp chúng ta thực hiện việc này nhanh hơn và hiệu quả hơn về bộ nhớ mà sự kết hợp của các hàm NumPy như multiply,sumtransposesẽ cho phép.

Làm thế nào einsum làm việc?

Đây là một ví dụ đơn giản (nhưng không hoàn toàn tầm thường). Lấy hai mảng sau:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Chúng tôi sẽ nhân ABphần tử khôn ngoan và sau đó tổng hợp dọc theo các hàng của mảng mới. Trong NumPy "bình thường" chúng tôi sẽ viết:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Vì vậy, ở đây, hoạt động lập chỉ mục trên Acác đường thẳng lên các trục đầu tiên của hai mảng để phép nhân có thể được phát. Các hàng của mảng sản phẩm sau đó được tổng hợp để trả về câu trả lời.

Bây giờ nếu chúng ta muốn sử dụng einsumthay thế, chúng ta có thể viết:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

Các chữ ký chuỗi 'i,ij->i'là chìa khóa ở đây và cần một chút giải thích. Bạn có thể nghĩ về nó trong hai nửa. Ở phía bên trái (bên trái của ->) chúng tôi đã gắn nhãn hai mảng đầu vào. Ở bên phải ->, chúng tôi đã gắn nhãn mảng mà chúng tôi muốn kết thúc.

Đây là những gì xảy ra tiếp theo:

  • Acó một trục; chúng tôi đã dán nhãn nó i. Và Bcó hai trục; chúng tôi đã gắn nhãn trục 0 là ivà trục 1 là j.

  • Bằng cách lặp lại nhãn itrong cả hai mảng đầu vào, chúng tôi đang nói einsumrằng hai trục này sẽ được nhân với nhau. Nói cách khác, chúng ta đang nhân mảng Avới từng cột của mảng B, giống như A[:, np.newaxis] * Bvậy.

  • Lưu ý rằng jkhông xuất hiện dưới dạng nhãn trong đầu ra mong muốn của chúng tôi; chúng tôi vừa mới sử dụng i(chúng tôi muốn kết thúc với mảng 1D). Bằng cách bỏ nhãn, chúng ta đang nói einsumđến tổng hợp dọc theo trục này. Nói cách khác, chúng tôi đang tổng hợp các hàng của sản phẩm, giống như .sum(axis=1)vậy.

Về cơ bản đó là tất cả những gì bạn cần biết để sử dụng einsum. Nó giúp chơi về một chút; nếu chúng ta để cả hai nhãn ở đầu ra 'i,ij->ij', chúng ta sẽ lấy lại một mảng sản phẩm 2D (giống như A[:, np.newaxis] * B). Nếu chúng tôi nói không có nhãn đầu ra 'i,ij->, chúng tôi sẽ lấy lại một số duy nhất (giống như làm (A[:, np.newaxis] * B).sum()).

einsumTuy nhiên, điều tuyệt vời là không xây dựng một mảng sản phẩm tạm thời trước; Nó chỉ tổng hợp các sản phẩm khi nó đi. Điều này có thể dẫn đến tiết kiệm lớn trong việc sử dụng bộ nhớ.

Một ví dụ lớn hơn một chút

Để giải thích sản phẩm chấm, đây là hai mảng mới:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Chúng tôi sẽ tính toán sản phẩm chấm bằng cách sử dụng np.einsum('ij,jk->ik', A, B). Dưới đây là hình ảnh hiển thị nhãn của ABvà mảng đầu ra mà chúng ta nhận được từ hàm:

nhập mô tả hình ảnh ở đây

Bạn có thể thấy nhãn đó jđược lặp lại - điều này có nghĩa là chúng tôi nhân các hàng Avới các cột của B. Hơn nữa, nhãnj không được bao gồm trong đầu ra - chúng tôi đang tổng hợp các sản phẩm này. Nhãn ik được giữ cho đầu ra, vì vậy chúng tôi lấy lại một mảng 2D.

Nó có thể là thậm chí rõ ràng hơn để so sánh kết quả này với các mảng nơi nhãn jkhông tóm gọn. Bên dưới, bên trái, bạn có thể thấy mảng 3D kết quả từ việc viết np.einsum('ij,jk->ijk', A, B)(tức là chúng tôi đã giữ nhãn j):

nhập mô tả hình ảnh ở đây

Trục tổng hợp j cho sản phẩm chấm mong đợi, hiển thị bên phải.

Một số bài tập

Để có thêm cảm giác einsum, có thể hữu ích khi triển khai các hoạt động mảng NumPy quen thuộc bằng cách sử dụng ký hiệu đăng ký. Bất cứ điều gì liên quan đến sự kết hợp của các trục nhân và tổng có thể được viết bằng cách sử dụng einsum .

Đặt A và B là hai mảng 1D có cùng độ dài. Ví dụ, A = np.arange(10)B = np.arange(5, 15).

  • Tổng của Acó thể được viết:

    np.einsum('i->', A)
  • Phép nhân phần tử A * B, có thể được viết:

    np.einsum('i,i->i', A, B)
  • Sản phẩm bên trong hoặc sản phẩm chấm, np.inner(A, B)hoặc np.dot(A, B), có thể được viết:

    np.einsum('i,i->', A, B) # or just use 'i,i'
  • Các sản phẩm bên ngoài np.outer(A, B), có thể được viết:

    np.einsum('i,j->ij', A, B)

Đối với mảng 2D CD, với điều kiện là các trục có độ dài tương thích (cả hai chiều dài giống nhau hoặc một trong số chúng có độ dài 1), đây là một vài ví dụ:

  • Dấu vết của C(tổng đường chéo chính) np.trace(C), có thể được viết:

    np.einsum('ii', C)
  • Nhân tố khôn ngoan của Cvà transpose của D, C * D.T, có thể được viết:

    np.einsum('ij,ji->ij', C, D)
  • Nhân từng phần tử của Cmảng D(để tạo mảng 4D) C[:, :, None, None] * D, có thể được viết:

    np.einsum('ij,kl->ijkl', C, D)  

1
Giải thích rất hay, cảm ơn. "Lưu ý rằng tôi không xuất hiện dưới dạng nhãn trong đầu ra mong muốn của chúng tôi" - không phải sao?
Ian Hincks

Cảm ơn @IanHincks! Điều đó trông giống như một lỗi đánh máy; Tôi đã sửa nó ngay bây giờ.
Alex Riley

1
Câu trả lời rất hay. Cũng đáng lưu ý rằng nó ij,jkcó thể tự hoạt động (không có mũi tên) để tạo thành phép nhân ma trận. Nhưng có vẻ như để rõ ràng, tốt nhất là đặt mũi tên và sau đó là kích thước đầu ra. Đó là trong bài viết trên blog.
ComputerSellectist

1
@Peaceful: đây là một trong những dịp khó chọn từ đúng! Tôi cảm thấy "cột" phù hợp hơn một chút ở đây vì Acó độ dài 3, giống như chiều dài của các cột trong B(trong khi các hàng Bcó độ dài 4 và không thể nhân số phần tử theo A).
Alex Riley

1
Lưu ý rằng việc bỏ qua ->ảnh hưởng đến ngữ nghĩa: "Trong chế độ ẩn, các chỉ số được chọn rất quan trọng vì các trục của đầu ra được sắp xếp lại theo thứ tự bảng chữ cái. Điều này có nghĩa là np.einsum('ij', a)không ảnh hưởng đến mảng 2D, trong khi np.einsum('ji', a)chuyển vị của nó."
BallpointBen

40

Nắm bắt ý tưởng numpy.einsum()là rất dễ dàng nếu bạn hiểu nó bằng trực giác. Ví dụ, hãy bắt đầu với một mô tả đơn giản liên quan đến nhân ma trận .


Để sử dụng numpy.einsum(), tất cả những gì bạn phải làm là chuyển chuỗi được gọi là chuỗi đăng ký dưới dạng đối số, theo sau là mảng đầu vào của bạn .

Hãy nói rằng bạn có hai 2D mảng, AB, và bạn muốn làm nhân ma trận. Bạn cũng vậy:

np.einsum("ij, jk -> ik", A, B)

Ở đây, chuỗi ký tự ij tương ứng với mảng Atrong khi chuỗi ký tự jk tương ứng với mảng B. Ngoài ra, điều quan trọng nhất cần lưu ý ở đây là số lượng ký tự trong mỗi chuỗi ký tự phải phù hợp với kích thước của mảng. (tức là hai ký tự cho mảng 2D, ba ký tự cho mảng 3D, v.v.) Và nếu bạn lặp lại ký tự giữa các chuỗi ký tự ( jtrong trường hợp của chúng tôi), thì điều đó có nghĩa là bạn muốn eintổng số xảy ra dọc theo các kích thước đó. Vì vậy, họ sẽ được giảm tổng. (tức là kích thước đó sẽ biến mất )

Chuỗi đăng ký sau này ->, sẽ là mảng kết quả của chúng tôi. Nếu bạn để trống, mọi thứ sẽ được tính tổng và kết quả là giá trị vô hướng được trả về. Khác mảng kết quả sẽ có kích thước theo chuỗi con . Trong ví dụ của chúng tôi, nó sẽ là ik. Điều này là trực quan bởi vì chúng ta biết rằng để nhân ma trận, số lượng cột trong mảng Aphải khớp với số lượng hàng trong mảng Bđang xảy ra ở đây (tức là chúng ta mã hóa kiến ​​thức này bằng cách lặp lại char jtrong chuỗi đăng ký )


Dưới đây là một số ví dụ minh họa việc sử dụng / sức mạnh của np.einsum()việc thực hiện một số hoạt động tenor hoặc nd-mảng phổ biến , ngắn gọn.

Đầu vào

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Phép nhân ma trận (tương tự np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Trích xuất các phần tử dọc theo đường chéo chính (tương tự np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) Sản phẩm Hadamard (nghĩa là sản phẩm có yếu tố thông minh gồm hai mảng) (tương tự arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Bình phương phần tử khôn ngoan (tương tự np.square(arr)hoặc arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Dấu vết (tức là tổng các phần tử đường chéo chính) (tương tự np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Ma trận chuyển vị (tương tự np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Sản phẩm bên ngoài (của vectơ) (tương tự np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Sản phẩm bên trong (của vectơ) (tương tự np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Tổng dọc trục 0 (tương tự np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Tổng dọc trục 1 (tương tự np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Phép nhân ma trận hàng loạt

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Tổng dọc trục 2 (tương tự np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Tính tổng tất cả các phần tử trong mảng (tương tự np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Tính tổng trên nhiều trục (nghĩa là cận biên)
(tương tự np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Sản phẩm Double Dot (tương tự np.sum (sản phẩm hadamard) xem 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) Phép nhân mảng 2D và 3D

Phép nhân như vậy có thể rất hữu ích khi giải hệ phương trình tuyến tính ( Ax = b ) trong đó bạn muốn xác minh kết quả.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Ngược lại, nếu người ta phải sử dụng np.matmul()để xác minh này, chúng tôi phải thực hiện một vài reshapethao tác để đạt được kết quả tương tự như:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Phần thưởng : Đọc thêm toán ở đây: Einstein-Summation và chắc chắn ở đây: Tenor-Notation


7

Cho phép tạo 2 mảng, với các kích thước khác nhau nhưng tương thích để làm nổi bật khả năng tương tác của chúng

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Tính toán của bạn, lấy một 'chấm' (tổng sản phẩm) của một (2,3) với (3,4) để tạo ra một mảng (4.2). ilà mờ thứ 1 của A, cuối cùng của C; kcuối cùng của B, 1 của C. jđược "tiêu thụ" bởi tổng kết.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

Điều này giống như np.dot(A,B).T- đó là đầu ra cuối cùng được chuyển đổi.

Để xem thêm về những gì xảy ra j, thay đổi các Cmục con thành ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Điều này cũng có thể được sản xuất với:

A[:,:,None]*B[None,:,:]

Nghĩa là, thêm một thứ knguyên vào cuối Avà một iphía trước B, dẫn đến một mảng (2,3,4).

0 + 4 + 16 = 20,, 9 + 28 + 55 = 92vv; Tổng hợp jvà chuyển đổi để có được kết quả trước đó:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]

6

Tôi tìm thấy NumPy: Các thủ thuật của giao dịch (Phần II) mang tính hướng dẫn

Chúng tôi sử dụng -> để chỉ ra thứ tự của mảng đầu ra. Vì vậy, hãy nghĩ về 'ij, i-> j' là có bên tay trái (LHS) và bên tay phải (RHS). Bất kỳ sự lặp lại của các nhãn trên LHS đều tính toán yếu tố sản phẩm một cách khôn ngoan và sau đó tính tổng. Bằng cách thay đổi nhãn ở phía RHS (đầu ra), chúng ta có thể xác định trục mà chúng ta muốn tiến hành đối với mảng đầu vào, tức là tính tổng theo trục 0, 1, v.v.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Lưu ý rằng có ba trục, i, j, k và j được lặp lại (ở phía bên trái). i,jđại diện cho các hàng và cột cho a. j,kchob .

Để tính toán sản phẩm và căn chỉnh jtrục chúng ta cần thêm một trục vào a. ( bsẽ được phát dọc theo (?) trục đầu tiên)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jkhông có ở phía bên tay phải nên chúng tôi tổng hợp jđó là trục thứ hai của mảng 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Cuối cùng, các chỉ số được (theo thứ tự chữ cái) đảo ngược ở phía bên tay phải để chúng ta hoán vị.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>

NumPy: Các mánh khóe trong giao dịch (Phần II) dường như cần có lời mời từ chủ sở hữu trang web cũng như tài khoản Wordpress
Tejas Shetty

... liên kết cập nhật, may mắn thay tôi đã tìm thấy nó với một tìm kiếm. - Thnx.
wwii

@TejasShetty Rất nhiều câu trả lời tốt hơn ở đây bây giờ - có lẽ tôi nên xóa câu trả lời này.
wwii

2
Xin đừng xóa câu trả lời của bạn.
Tejas Shetty

4

Khi đọc các phương trình einsum, tôi thấy thật hữu ích khi chỉ có thể làm sôi chúng xuống các phiên bản bắt buộc của chúng.

Hãy bắt đầu với câu lệnh (áp đặt) sau:

C = np.einsum('bhwi,bhwj->bij', A, B)

Làm việc thông qua dấu câu đầu tiên, chúng ta thấy rằng chúng ta có hai đốm màu được phân tách bằng dấu phẩy - bhwibhwj, trước mũi tên và một đốm 3 chữ cái duy nhấtbij sau nó. Do đó, phương trình tạo ra kết quả tenxơ bậc 3 từ hai đầu vào tenxơ bậc 4.

Bây giờ, hãy để mỗi chữ cái trong mỗi blob là tên của một biến phạm vi. Vị trí mà chữ cái xuất hiện trong blob là chỉ số của trục mà nó nằm trong phạm vi đó. Do đó, tổng cộng bắt buộc tạo ra từng phần tử của C, phải bắt đầu bằng ba vòng lặp lồng nhau, một cho mỗi chỉ số của C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Vì vậy, về cơ bản, bạn có một forvòng lặp cho mọi chỉ số đầu ra của C. Bây giờ chúng ta sẽ để các phạm vi không xác định.

Tiếp theo chúng ta nhìn vào mặt trái tay - được có bất kỳ biến phạm vi đó mà không xuất hiện trên cánh tay phải bên? Trong trường hợp của chúng tôi - có, hw. Thêm một forvòng lặp lồng bên trong cho mỗi biến như vậy:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

Bên trong vòng lặp trong cùng, bây giờ chúng ta có tất cả các chỉ mục được xác định, vì vậy chúng ta có thể viết tổng kết thực tế và bản dịch hoàn tất:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Nếu bạn đã có thể theo dõi mã cho đến nay, xin chúc mừng! Đây là tất cả những gì bạn cần để có thể đọc phương trình einsum. Đặc biệt lưu ý cách công thức einsum ban đầu ánh xạ tới câu lệnh tổng hợp cuối cùng trong đoạn trích ở trên. Các vòng lặp for và giới hạn phạm vi chỉ là lông tơ và tuyên bố cuối cùng đó là tất cả những gì bạn thực sự cần để hiểu những gì đang diễn ra.

Để hoàn thiện, hãy xem cách xác định phạm vi cho từng biến phạm vi. Chà, phạm vi của mỗi biến chỉ đơn giản là chiều dài của (các) thứ mà nó lập chỉ mục. Rõ ràng, nếu một biến chỉ mục nhiều hơn một chiều trong một hoặc nhiều thang đo, thì độ dài của mỗi kích thước đó phải bằng nhau. Đây là đoạn mã trên với phạm vi đầy đủ:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.