Thay thế nhanh chóng cho numpy.median.reduceat


12

Liên quan đến câu trả lời này , có cách nào nhanh chóng để tính trung bình trên một mảng có các nhóm có số phần tử không bằng nhau không ?

Ví dụ:

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67, ... ]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3,    ... ]

Và sau đó tôi muốn tính toán sự khác biệt giữa số lượng và trung vị cho mỗi nhóm (ví dụ: trung vị của nhóm 01.025kết quả đầu tiên là 1.00 - 1.025 = -0.025). Vì vậy, đối với mảng trên, kết quả sẽ xuất hiện như sau:

result = [-0.025, 0.025, 0.05, -0.05, -0.19, 0.29, 0.00, 0.10, -0.10, ...]

np.median.reduceatkhông tồn tại (chưa), có cách nào khác nhanh chóng để đạt được điều này không? Mảng của tôi sẽ chứa hàng triệu hàng nên tốc độ rất quan trọng!

Các chỉ số có thể được coi là tiếp giáp và ra lệnh (thật dễ dàng để chuyển đổi chúng nếu chúng không có).


Dữ liệu mẫu để so sánh hiệu suất:

import numpy as np

np.random.seed(0)
rows = 10000
cols = 500
ngroup = 100

# Create random data and groups (unique per column)
data = np.random.rand(rows,cols)
groups = np.random.randint(ngroup, size=(rows,cols)) + 10*np.tile(np.arange(cols),(rows,1))

# Flatten
data = data.ravel()
groups = groups.ravel()

# Sort by group
idx_sort = groups.argsort()
data = data[idx_sort]
groups = groups[idx_sort]

Bạn đã có thời gian scipy.ndimage.mediangợi ý trong câu trả lời liên kết? Dường như với tôi rằng nó cần một số lượng bằng nhau cho mỗi nhãn. Hay tôi đã bỏ lỡ điều gì?
Andras Deak

Vì vậy, khi bạn nói hàng triệu hàng, tập dữ liệu thực tế của bạn có phải là mảng 2D không và bạn có đang thực hiện thao tác này trên mỗi hàng đó không?
Divakar

@Divakar Xem chỉnh sửa câu hỏi để kiểm tra dữ liệu
Jean-Paul

Bạn đã cho điểm chuẩn trong dữ liệu ban đầu, tôi đã thổi phồng nó để giữ định dạng giống nhau. Tất cả mọi thứ được điểm chuẩn so với tập dữ liệu tăng cao của tôi. Không hợp lý để thay đổi ngay bây giờ
roganjosh

Câu trả lời:


7

Đôi khi bạn cần phải viết mã numpy không thành ngữ nếu bạn thực sự muốn tăng tốc tính toán mà bạn không thể làm với numpy bản địa.

numbabiên dịch mã python của bạn ở mức độ thấp C. Vì rất nhiều numpy thường nhanh như C, nên điều này chủ yếu là hữu ích nếu vấn đề của bạn không cho vay để vector hóa tự nhiên với numpy. Đây là một ví dụ (trong đó tôi giả sử rằng các chỉ số tiếp giáp và được sắp xếp, cũng được phản ánh trong dữ liệu mẫu):

import numpy as np
import numba

# use the inflated example of roganjosh https://stackoverflow.com/a/58788534
data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3] 

data = np.array(data * 500) # using arrays is important for numba!
index = np.sort(np.random.randint(0, 30, 4500))               

# jit-decorate; original is available as .py_func attribute
@numba.njit('f8[:](f8[:], i8[:])') # explicit signature implies ahead-of-time compile
def diffmedian_jit(data, index): 
    res = np.empty_like(data) 
    i_start = 0 
    for i in range(1, index.size): 
        if index[i] == index[i_start]: 
            continue 

        # here: i is the first _next_ index 
        inds = slice(i_start, i)  # i_start:i slice 
        res[inds] = data[inds] - np.median(data[inds]) 

        i_start = i 

    # also fix last label 
    res[i_start:] = data[i_start:] - np.median(data[i_start:])

    return res

Và đây là một số thời gian sử dụng %timeitphép thuật của IPython :

>>> %timeit diffmedian_jit.py_func(data, index)  # non-jitted function
... %timeit diffmedian_jit(data, index)  # jitted function
...
4.27 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
65.2 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Sử dụng dữ liệu ví dụ được cập nhật trong câu hỏi những con số này (tức là thời gian chạy của hàm python so với thời gian chạy của funcio được tăng tốc JIT) là

>>> %timeit diffmedian_jit.py_func(data, groups) 
... %timeit diffmedian_jit(data, groups)
2.45 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93.6 ms ± 518 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Điều này lên tới tốc độ tăng tốc 65 lần trong trường hợp nhỏ hơn và tăng tốc 26 lần trong trường hợp lớn hơn (tất nhiên là so với mã vòng lặp chậm) bằng cách sử dụng mã tăng tốc. Một nhược điểm khác là (không giống như vector hóa thông thường với numpy bản địa), chúng tôi không cần thêm bộ nhớ để đạt được tốc độ này, tất cả là về mã cấp thấp được tối ưu hóa và biên dịch cuối cùng được chạy.


Hàm trên giả định rằng mảng int numpy là int64 theo mặc định, đây không thực sự là trường hợp trên Windows. Vì vậy, một giải pháp thay thế là xóa chữ ký khỏi lệnh gọi đến numba.njit, kích hoạt quá trình biên dịch đúng lúc. Nhưng điều này có nghĩa là hàm sẽ được biên dịch trong lần thực hiện đầu tiên, có thể can thiệp vào kết quả thời gian (chúng ta có thể thực hiện chức năng một lần theo cách thủ công, sử dụng các loại dữ liệu đại diện hoặc chỉ chấp nhận rằng việc thực hiện thời gian đầu tiên sẽ chậm hơn nhiều, nên được bỏ qua). Đây chính xác là những gì tôi đã cố gắng ngăn chặn bằng cách chỉ định một chữ ký, kích hoạt việc biên dịch trước thời hạn.

Dù sao, trong trường hợp JIT đúng cách, người trang trí chúng ta cần chỉ là

@numba.njit
def diffmedian_jit(...):

Lưu ý rằng các thời gian trên mà tôi đã hiển thị cho chức năng biên dịch jit chỉ áp dụng khi chức năng đã được biên dịch. Điều này xảy ra ở định nghĩa (với phần tổng hợp háo hức, khi chữ ký rõ ràng được chuyển đếnnumba.njit ) hoặc trong khi gọi hàm đầu tiên (với biên dịch lười, khi không có chữ ký nào được chuyển đến numba.njit). Nếu hàm chỉ được thực hiện một lần thì thời gian biên dịch cũng cần được xem xét cho tốc độ của phương thức này. Nó thường chỉ có giá trị biên dịch các hàm nếu tổng thời gian biên dịch + thực thi ít hơn thời gian chạy không biên dịch (điều này thực sự đúng trong trường hợp trên, trong đó hàm python gốc rất chậm). Điều này chủ yếu xảy ra khi bạn đang gọi chức năng biên dịch của bạn rất nhiều lần.

Như max9111 lưu ý trong một nhận xét, một tính năng quan trọng của numbacachetừ khóa để jit. Chuyển cache=Trueđến numba.jitsẽ lưu trữ chức năng đã biên dịch vào đĩa, để trong lần thực hiện tiếp theo của mô-đun python đã cho, chức năng sẽ được tải từ đó thay vì được biên dịch lại, điều này một lần nữa có thể giúp bạn chạy trong thời gian dài.


@Divakar thực sự, nó giả sử các chỉ số tiếp giáp và sắp xếp, có vẻ như là một giả định trong dữ liệu của OP và cũng tự động được bao gồm trong indexdữ liệu của roganjosh . Tôi sẽ để lại một lưu ý về điều này, cảm ơn :)
Andras Deak

OK, sự tiếp giáp không được bao gồm tự động ... nhưng tôi khá chắc chắn rằng nó phải tiếp tục liên tục. Hmm ...
Andras Deak

1
@AndrasDeak Thật sự tốt khi giả sử các nhãn liền kề và được sắp xếp (sửa chúng nếu không dễ dàng)
Jean-Paul

1
@AndrasDeak Xem chỉnh sửa để đặt câu hỏi cho dữ liệu thử nghiệm (sao cho so sánh hiệu suất giữa các câu hỏi là nhất quán)
Jean-Paul

1
Bạn có thể đề cập đến từ khóa cache=Trueđể tránh biên dịch lại mỗi lần khởi động lại trình thông dịch.
max9111

5

Một cách tiếp cận sẽ được sử dụng Pandasở đây hoàn toàn để sử dụng groupby. Tôi đã tăng kích thước đầu vào một chút để hiểu rõ hơn về thời gian (vì có chi phí trong việc tạo DF).

import numpy as np
import pandas as pd

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3]

data = data * 500
index = np.sort(np.random.randint(0, 30, 4500))

def df_approach(data, index):
    df = pd.DataFrame({'data': data, 'label': index})
    df['median'] = df.groupby('label')['data'].transform('median')
    df['result'] = df['data'] - df['median']

Cung cấp như sau timeit:

%timeit df_approach(data, index)
5.38 ms ± 50.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Với cùng cỡ mẫu, tôi nhận được cách tiếp cận chính tả của Aryerez là:

%timeit dict_approach(data, index)
8.12 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Tuy nhiên, nếu chúng ta tăng đầu vào thêm 10 lần nữa, thời gian sẽ trở thành:

%timeit df_approach(data, index)
7.72 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit dict_approach(data, index)
30.2 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Tuy nhiên, với chi phí của một số khả năng, câu trả lời bởi Divakar bằng cách sử dụng numpy thuần túy được đưa ra tại:

%timeit bin_median_subtract(data, index)
573 µs ± 7.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Trong ánh sáng của bộ dữ liệu mới (mà thực sự nên được đặt khi bắt đầu):

%timeit df_approach(data, groups)
472 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit bin_median_subtract(data, groups) #https://stackoverflow.com/a/58788623/4799172
3.02 s ± 31.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dict_approach(data, groups) #https://stackoverflow.com/a/58788199/4799172
<I gave up after 1 minute>

# jitted (using @numba.njit('f8[:](f8[:], i4[:]') on Windows) from  https://stackoverflow.com/a/58788635/4799172
%timeit diffmedian_jit(data, groups)
132 ms ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Cảm ơn bạn cho câu trả lời này! Để thống nhất với các câu trả lời khác, bạn có thể kiểm tra các giải pháp của mình trên dữ liệu mẫu được cung cấp trong phần chỉnh sửa cho câu hỏi của tôi không?
Jean-Paul

@ Jean-Paul thời gian đã nhất quán rồi, phải không? Họ đã sử dụng dữ liệu điểm chuẩn ban đầu của tôi và trong trường hợp họ không làm như vậy, tôi đã cung cấp thời gian cho họ với cùng điểm chuẩn
roganjosh

Tôi đã bỏ qua bạn cũng đã thêm một tham chiếu đến câu trả lời của Divakar vì vậy câu trả lời của bạn thực sự đã tạo ra một so sánh tốt đẹp giữa các cách tiếp cận khác nhau, cảm ơn vì điều đó!
Jean-Paul

1
@ Jean-Paul Tôi đã thêm thời gian mới nhất ở phía dưới vì nó thực sự đã thay đổi mọi thứ khá quyết liệt
roganjosh

1
Xin lỗi vì không thêm bộ kiểm tra khi đăng câu hỏi, đánh giá cao rằng bạn vẫn thêm kết quả kiểm tra ngay bây giờ! Cảm ơn!!!
Jean-Paul

4

Có thể bạn đã làm điều này, nhưng nếu không, hãy xem liệu điều đó có đủ nhanh không:

median_dict = {i: np.median(data[index == i]) for i in np.unique(index)}
def myFunc(my_dict, a): 
    return my_dict[a]
vect_func = np.vectorize(myFunc)
median_diff = data - vect_func(median_dict, index)
median_diff

Đầu ra:

array([-0.025,  0.025,  0.05 , -0.05 , -0.19 ,  0.29 ,  0.   ,  0.1  ,
   -0.1  ])

Có nguy cơ nêu rõ ràng, np.vectorizelà một gói rất mỏng cho một vòng lặp, vì vậy tôi không mong đợi phương pháp này sẽ đặc biệt nhanh.
Andras Deak

1
@AndrasDeak Tôi không đồng ý :) Tôi sẽ tiếp tục theo dõi và nếu ai đó đăng một giải pháp tốt hơn, tôi sẽ xóa nó.
Aryerez

1
Tôi không nghĩ rằng bạn phải xóa nó ngay cả khi các phương pháp tiếp cận nhanh hơn bật lên :)
Andras Deak

@roganjosh Đó có thể là vì bạn đã không xác định dataindexnhư np.arrays như trong câu hỏi.
Aryerez

1
@ Jean-Paul roganjosh đã làm một so sánh thời gian giữa tôi và phương pháp của anh ấy, và những người khác ở đây so sánh họ. Nó phụ thuộc vào phần cứng máy tính, vì vậy không có điểm nào mọi người kiểm tra phương pháp riêng của họ, nhưng có vẻ như tôi đã đưa ra giải pháp chậm nhất ở đây.
Aryerez

4

Đây là một cách tiếp cận dựa trên NumPy để có được giá trị trung bình cho các giá trị thùng / chỉ số dương -

def bin_median(a, i):
    sidx = np.lexsort((a,i))

    a = a[sidx]
    i = i[sidx]

    c = np.bincount(i)
    c = c[c!=0]

    s1 = c//2

    e = c.cumsum()
    s1[1:] += e[:-1]

    firstval = a[s1-1]
    secondval = a[s1]
    out = np.where(c%2,secondval,(firstval+secondval)/2.0)
    return out

Để giải quyết trường hợp cụ thể của chúng tôi về những người bị trừ -

def bin_median_subtract(a, i):
    sidx = np.lexsort((a,i))

    c = np.bincount(i)

    valid_mask = c!=0
    c = c[valid_mask]    

    e = c.cumsum()
    s1 = c//2
    s1[1:] += e[:-1]
    ssidx = sidx.argsort()
    starts = c%2+s1-1
    ends = s1

    starts_orgindx = sidx[np.searchsorted(sidx,starts,sorter=ssidx)]
    ends_orgindx  = sidx[np.searchsorted(sidx,ends,sorter=ssidx)]
    val = (a[starts_orgindx] + a[ends_orgindx])/2.
    out = a-np.repeat(val,c)
    return out

Câu trả lời rất hay! Bạn có bất kỳ dấu hiệu nào cho thấy sự cải thiện tốc độ hơn df.groupby('index').transform('median')không?
Jean-Paul

@ Jean-Paul Bạn có thể kiểm tra dữ liệu thực tế của hàng triệu người không?
Divakar

Xem chỉnh sửa để đặt câu hỏi cho dữ liệu thử nghiệm
Jean-Paul

@ Jean-Paul Chỉnh sửa giải pháp của tôi cho đơn giản hơn. Hãy chắc chắn để sử dụng cái này để thử nghiệm, nếu bạn là.
Divakar
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.