Nói một cách đơn giản, torch.Tensor.view()
được lấy cảm hứng từ numpy.ndarray.reshape()
hoặc numpy.reshape()
, tạo ra một cái nhìn mới về tenxơ, miễn là hình dạng mới tương thích với hình dạng của tenxơ ban đầu.
Hãy hiểu chi tiết điều này bằng một ví dụ cụ thể.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Với tensor này t
hình dạng (18,)
, mới xem có thể chỉ được tạo ra cho các hình dạng sau:
(1, 18)
hoặc tương đương (1, -1)
hoặc hoặc tương đương hoặc hoặc tương đương hoặc hoặc tương đương hoặc hoặc tương đương hoặc hoặc tương đương hoặc(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Như chúng ta đã có thể quan sát từ các bộ hình dạng ở trên, phép nhân của các phần tử của bộ hình dạng (ví dụ 2*9
, 3*6
v.v.) phải luôn bằng tổng số phần tử trong thang đo ban đầu ( 18
trong ví dụ của chúng tôi).
Một điều khác để quan sát là chúng tôi đã sử dụng một -1
trong những vị trí trong mỗi bộ hình dạng. Bằng cách sử dụng a -1
, chúng ta sẽ lười biếng trong việc tự tính toán và giao nhiệm vụ cho PyTorch để tính toán giá trị đó cho hình dạng khi nó tạo ra chế độ xem mới . Một điều quan trọng cần lưu ý là chúng ta chỉ có thể sử dụng một -1
trong hình dạng tuple. Các giá trị còn lại phải được cung cấp rõ ràng bởi chúng tôi. Else PyTorch sẽ khiếu nại bằng cách ném RuntimeError
:
RuntimeError: chỉ có thể suy ra một chiều
Vì vậy, với tất cả các hình dạng được đề cập ở trên, PyTorch sẽ luôn trả lại một cái nhìn mới về tenxơ ban đầu t
. Điều này về cơ bản có nghĩa là nó chỉ thay đổi thông tin sải chân của tenor cho mỗi khung nhìn mới được yêu cầu.
Dưới đây là một số ví dụ minh họa cách các bước của thang đo được thay đổi với mỗi chế độ xem mới .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Bây giờ, chúng ta sẽ thấy những bước tiến cho các quan điểm mới :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Vì vậy, đó là sự kỳ diệu của view()
chức năng. Nó chỉ thay đổi các bước của tenxơ (ban đầu) cho mỗi chế độ xem mới , miễn là hình dạng của chế độ xem mới tương thích với hình dạng ban đầu.
Một điều thú vị khác mà người ta có thể quan sát từ các bộ sải chân là giá trị của phần tử ở vị trí thứ 0 bằng với giá trị của phần tử ở vị trí thứ 1 của hình dạng.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Điều này là do:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
sải chân (6, 1)
nói rằng để đi từ một yếu tố này sang yếu tố tiếp theo dọc theo chiều thứ 0 , chúng ta phải nhảy hoặc thực hiện 6 bước. (tức là để đi từ 0
đến 6
, người ta phải thực hiện 6 bước.) Nhưng để đi từ yếu tố này sang yếu tố tiếp theo trong chiều thứ 1 , chúng ta chỉ cần một bước (ví dụ: đi từ 2
đến 3
).
Do đó, thông tin bước tiến là cốt lõi của cách các phần tử được truy cập từ bộ nhớ để thực hiện tính toán.
Hàm này sẽ trả về một khung nhìn và hoàn toàn giống như sử dụng torch.Tensor.view()
miễn là hình dạng mới tương thích với hình dạng của tenxơ ban đầu. Nếu không, nó sẽ trả về một bản sao.
Tuy nhiên, các ghi chú torch.reshape()
cảnh báo rằng:
đầu vào liền kề và đầu vào với các bước tương thích có thể được định hình lại mà không cần sao chép, nhưng người ta không nên phụ thuộc vào hành vi sao chép so với hành vi xem.
reshape
trong PyTorch?!