NumPy chọn chỉ mục cột cụ thể cho mỗi hàng bằng cách sử dụng danh sách chỉ mục


90

Tôi đang đấu tranh để chọn các cột cụ thể trên mỗi hàng của NumPyma trận.

Giả sử tôi có ma trận sau đây mà tôi sẽ gọi X:

[1, 2, 3]
[4, 5, 6]
[7, 8, 9]

Tôi cũng có một listsố chỉ mục cột trên mỗi hàng mà tôi sẽ gọi Y:

[1, 0, 2]

Tôi cần lấy các giá trị:

[2]
[4]
[9]

Thay vì a listvới các chỉ mục Y, tôi cũng có thể tạo một ma trận có hình dạng giống như Xtrong đó mọi cột là a bool/ inttrong phạm vi giá trị 0-1, cho biết đây có phải là cột bắt buộc hay không.

[0, 1, 0]
[1, 0, 0]
[0, 0, 1]

Tôi biết điều này có thể được thực hiện bằng cách lặp lại mảng và chọn các giá trị cột mà tôi cần. Tuy nhiên, điều này sẽ được thực thi thường xuyên trên các mảng dữ liệu lớn và đó là lý do tại sao nó phải chạy nhanh nhất có thể.

Vì vậy, tôi đã tự hỏi liệu có một giải pháp tốt hơn không?

Cảm ơn bạn.


Câu trả lời là tốt hơn cho bạn? stackoverflow.com/a/17081678/5046896
GoingMyWay

Câu trả lời:


102

Nếu bạn có một mảng boolean, bạn có thể thực hiện lựa chọn trực tiếp dựa trên như vậy:

>>> a = np.array([True, True, True, False, False])
>>> b = np.array([1,2,3,4,5])
>>> b[a]
array([1, 2, 3])

Để đi cùng với ví dụ ban đầu của bạn, bạn có thể làm như sau:

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> b = np.array([[False,True,False],[True,False,False],[False,False,True]])
>>> a[b]
array([2, 4, 9])

Bạn cũng có thể thêm arangevà thực hiện lựa chọn trực tiếp trên đó, mặc dù tùy thuộc vào cách bạn đang tạo mảng boolean của mình và mã của bạn trông giống như YMMV.

>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> a[np.arange(len(a)), [1,0,2]]
array([2, 4, 9])

Hy vọng điều đó sẽ hữu ích, hãy cho tôi biết nếu bạn có thêm bất kỳ câu hỏi nào.


11
+1 cho ví dụ bằng cách sử dụng arange. Điều này đặc biệt hữu ích với tôi để lấy khối khác nhau từ nhiều ma trận (vì vậy về cơ bản là trường hợp 3D của ví dụ này)
Griddo

1
Xin chào, bạn có thể giải thích lý do tại sao chúng tôi phải sử dụng arangethay vì :? Tôi biết cách của bạn hiệu quả và cách của tôi thì không, nhưng tôi muốn hiểu tại sao.
marcotama

@tamzord bởi vì nó là một mảng numpy và không phải là một danh sách vanilla python, vì vậy :cú pháp không hoạt động theo cùng một cách.
Slater Victoroff

1
@SlaterTyranus, cảm ơn bạn đã phản hồi. Sự hiểu biết của tôi, sau một số bài đọc, là kết hợp :với lập chỉ mục nâng cao có nghĩa là: "đối với mọi không gian con cùng :, hãy áp dụng lập chỉ mục nâng cao đã cho". Tôi hiểu có đúng không?
marcotama

@tamzord giải thích những gì bạn có ý nghĩa bởi "tiểu vũ trụ"
Slater Victoroff

35

Bạn có thể làm điều gì đó như sau:

In [7]: a = np.array([[1, 2, 3],
   ...: [4, 5, 6],
   ...: [7, 8, 9]])

In [8]: lst = [1, 0, 2]

In [9]: a[np.arange(len(a)), lst]
Out[9]: array([2, 4, 9])

Thông tin thêm về lập chỉ mục mảng đa chiều: http://docs.scipy.org/doc/numpy/user/basics.indexing.html#indexing-multi-dimensional-arrays


1
đấu tranh để hiểu tại sao lại cần arange thay vì đơn giản là ':' hoặc range.
MadmanLee

@MadmanLee Xin chào, việc sử dụng :sẽ xuất ra nhiều len(a)lần kết quả, thay vào đó, chỉ số của mỗi hàng sẽ in ra kết quả dự đoán.
GoingMyWay

1
Tôi nghĩ đây là cách chính xác và thanh lịch để giải quyết vấn đề này.
GoingMyWay

6

Một cách đơn giản có thể giống như sau:

In [1]: a = np.array([[1, 2, 3],
   ...: [4, 5, 6],
   ...: [7, 8, 9]])

In [2]: y = [1, 0, 2]  #list of indices we want to select from matrix 'a'

range(a.shape[0]) sẽ trở lại array([0, 1, 2])

In [3]: a[range(a.shape[0]), y] #we're selecting y indices from every row
Out[3]: array([2, 4, 9])

1
Vui lòng xem xét thêm các giải thích.
souki

@souki Tôi đã thêm giải thích ngay bây giờ. Cảm ơn
Dhaval Mayatra

6

Các numpyphiên bản gần đây đã thêm một take_along_axis(và put_along_axis) thực hiện việc lập chỉ mục này một cách rõ ràng.

In [101]: a = np.arange(1,10).reshape(3,3)                                                             
In [102]: b = np.array([1,0,2])                                                                        
In [103]: np.take_along_axis(a, b[:,None], axis=1)                                                     
Out[103]: 
array([[2],
       [4],
       [9]])

Nó hoạt động theo cách tương tự như:

In [104]: a[np.arange(3), b]                                                                           
Out[104]: array([2, 4, 9])

nhưng với cách xử lý trục khác nhau. Nó đặc biệt nhằm vào việc áp dụng các kết quả của argsortargmax.


3

Bạn có thể làm điều đó bằng cách sử dụng trình lặp. Như thế này:

np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)

Thời gian:

N = 1000
X = np.zeros(shape=(N, N))
Y = np.arange(N)

#@Aशwini चhaudhary
%timeit X[np.arange(len(X)), Y]
10000 loops, best of 3: 30.7 us per loop

#mine
%timeit np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)
1000 loops, best of 3: 1.15 ms per loop

#mine
%timeit np.diag(X.T[Y])
10 loops, best of 3: 20.8 ms per loop

1
OP đã đề cập rằng nó sẽ chạy nhanh trên các mảng lớn , vì vậy điểm chuẩn của bạn không đại diện cho lắm. Tôi tò mò muốn biết phương thức cuối cùng của bạn hoạt động như thế nào đối với (nhiều) mảng lớn hơn!

@moarningsun: Đã cập nhật. np.diag(X.T[Y])quá chậm ... Nhưng np.diag(X.T)quá nhanh (10us). Tôi không biết tại sao.
Kei Minagawa

0

Một cách thông minh khác là chuyển đổi mảng đầu tiên và lập chỉ mục nó sau đó. Cuối cùng, lấy đường chéo, nó luôn là câu trả lời đúng.

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
Y = np.array([1, 0, 2, 2])

np.diag(X.T[Y])

Từng bước một:

Mảng ban đầu:

>>> X
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])

>>> Y
array([1, 0, 2, 2])

Transpose để có thể lập chỉ mục đúng.

>>> X.T
array([[ 1,  4,  7, 10],
       [ 2,  5,  8, 11],
       [ 3,  6,  9, 12]])

Nhận các hàng theo thứ tự Y.

>>> X.T[Y]
array([[ 2,  5,  8, 11],
       [ 1,  4,  7, 10],
       [ 3,  6,  9, 12],
       [ 3,  6,  9, 12]])

Đường chéo bây giờ sẽ trở nên rõ ràng.

>>> np.diag(X.T[Y])
array([ 2,  4,  9, 12]

1
Kỹ thuật này hoạt động và trông rất thanh lịch. Tuy nhiên, tôi thấy rằng cách tiếp cận này hoàn toàn bùng nổ khi bạn đang xử lý các mảng lớn. Trong trường hợp của tôi, NumPy đã nuốt 30GB swap và làm đầy ổ SSD của tôi. Tôi khuyên bạn nên sử dụng phương pháp lập chỉ mục nâng cao để thay thế.
5nefarious
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.