Đây là cách tiếp cận O (max (x) + len (x)) bằng cách sử dụng scipy.sparse
:
import numpy as np
from scipy import sparse
x = np.array("1 2 2 0 0 1 3 5".split(),int)
x
# array([1, 2, 2, 0, 0, 1, 3, 5])
M,N = x.max()+1,x.size
sparse.csc_matrix((x,x,np.arange(N+1)),(M,N)).tolil().rows.tolist()
# [[3, 4], [0, 5], [1, 2], [6], [], [7]]
Điều này hoạt động bằng cách tạo một ma trận thưa thớt với các mục tại các vị trí (x [0], 0), (x [1], 1), ... Sử dụng CSC
định dạng (cột thưa nén) điều này khá đơn giản. Ma trận sau đó được chuyển đổi sang LIL
định dạng (danh sách liên kết). Định dạng này lưu trữ các chỉ mục cột cho mỗi hàng dưới dạng một danh sách trong rows
thuộc tính của nó , vì vậy tất cả những gì chúng ta cần làm là lấy đó và chuyển đổi nó thành danh sách.
Lưu ý rằng đối với các argsort
giải pháp dựa trên mảng nhỏ có thể nhanh hơn nhưng ở một số kích thước không lớn, điều này sẽ vượt qua.
BIÊN TẬP:
argsort
numpy
giải pháp dựa trên cơ sở :
np.split(x.argsort(kind="stable"),np.bincount(x)[:-1].cumsum())
# [array([3, 4]), array([0, 5]), array([1, 2]), array([6]), array([], dtype=int64), array([7])]
Nếu thứ tự các chỉ số trong các nhóm không thành vấn đề, bạn cũng có thể thử argpartition
(điều này xảy ra không có sự khác biệt trong ví dụ nhỏ này nhưng nói chung điều này không được đảm bảo):
bb = np.bincount(x)[:-1].cumsum()
np.split(x.argpartition(bb),bb)
# [array([3, 4]), array([0, 5]), array([1, 2]), array([6]), array([], dtype=int64), array([7])]
BIÊN TẬP:
@Divakar khuyến cáo không nên sử dụng np.split
. Thay vào đó, một vòng lặp có thể nhanh hơn:
A = x.argsort(kind="stable")
B = np.bincount(x+1).cumsum()
[A[B[i-1]:B[i]] for i in range(1,len(B))]
Hoặc bạn có thể sử dụng toán tử walrus hoàn toàn mới (Python3.8 +):
A = x.argsort(kind="stable")
B = np.bincount(x)
L = 0
[A[L:(L:=L+b)] for b in B.tolist()]
EDIT (EDITED):
(Không thuần túy numpy): Thay thế cho numba (xem bài đăng của @ senderle) chúng ta cũng có thể sử dụng pythran.
Biên dịch với pythran -O3 <filename.py>
import numpy as np
#pythran export sort_to_bins(int[:],int)
def sort_to_bins(idx, mx):
if mx==-1:
mx = idx.max() + 1
cnts = np.zeros(mx + 2, int)
for i in range(idx.size):
cnts[idx[i] + 2] += 1
for i in range(3, cnts.size):
cnts[i] += cnts[i-1]
res = np.empty_like(idx)
for i in range(idx.size):
res[cnts[idx[i]+1]] = i
cnts[idx[i]+1] += 1
return [res[cnts[i]:cnts[i+1]] for i in range(mx)]
Ở đây numba
chiến thắng bởi một hiệu suất khôn ngoan:
repeat(lambda:enum_bins_numba_buffer(x),number=10)
# [0.6235917090671137, 0.6071486569708213, 0.6096088469494134]
repeat(lambda:sort_to_bins(x,-1),number=10)
# [0.6235359431011602, 0.6264424560358748, 0.6217901279451326]
Những thứ cũ hơn:
import numpy as np
#pythran export bincollect(int[:])
def bincollect(a):
o = [[] for _ in range(a.max()+1)]
for i,j in enumerate(a):
o[j].append(i)
return o
Thời gian so với numba (cũ)
timeit(lambda:bincollect(x),number=10)
# 3.5732191529823467
timeit(lambda:enumerate_bins(x),number=10)
# 6.7462647299980745
np.argsort([1, 2, 2, 0, 0, 1, 3, 5])
choarray([3, 4, 0, 5, 1, 2, 6, 7], dtype=int64)
. sau đó bạn chỉ có thể so sánh các yếu tố tiếp theo.