Kiểm tra xem mảng numpy chỉ chứa các số không


92

Chúng tôi khởi tạo một mảng numpy với các số không như sau:

np.zeros((N,N+1))

Nhưng làm cách nào để kiểm tra xem tất cả các phần tử trong ma trận mảng n * n đã cho có bằng không hay không.
Phương thức chỉ cần trả về giá trị True nếu tất cả các giá trị thực sự là 0.

Câu trả lời:



161

Các câu trả lời khác được đăng ở đây sẽ hoạt động, nhưng chức năng rõ ràng và hiệu quả nhất để sử dụng là numpy.any():

>>> all_zeros = not np.any(a)

hoặc là

>>> all_zeros = not a.any()
  • Điều này được ưu tiên hơn numpy.all(a==0)vì nó sử dụng ít RAM hơn. (Nó không yêu cầu mảng tạm thời được tạo bởi a==0thuật ngữ.)
  • Ngoài ra, nó nhanh hơn numpy.count_nonzero(a)vì nó có thể trả về ngay lập tức khi phần tử khác không đầu tiên được tìm thấy.
    • Chỉnh sửa: Như @Rachel đã chỉ ra trong các nhận xét, np.any()không còn sử dụng logic "ngắn mạch", vì vậy bạn sẽ không thấy lợi ích về tốc độ cho các mảng nhỏ.

2
Tính đến một phút trước, numpy của anyalllàm không ngắn mạch. Tôi tin rằng chúng là đường cho logical_or.reducelogical_and.reduce. Hãy so sánh với nhau và ngắn mạch của tôi is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

2
Đó là một điểm tuyệt vời, cảm ơn. Có vẻ như đoản mạch từng là hành vi, nhưng điều đó đã mất đi vào một thời điểm nào đó. Có một số cuộc thảo luận thú vị trong các câu trả lời cho câu hỏi này .
Stuart Berg,

50

Tôi sẽ sử dụng np.all ở đây, nếu bạn có một mảng a:

>>> np.all(a==0)

3
Tôi thích rằng câu trả lời này cũng kiểm tra các giá trị khác 0. Ví dụ, người ta có thể kiểm tra xem tất cả các phần tử trong một mảng có giống nhau hay không bằng cách thực hiện np.all(a==a[0]). Cảm ơn rất nhiều!
aignas

9

Như một câu trả lời khác đã nói, bạn có thể tận dụng các đánh giá true / falsy nếu bạn biết đó 0là phần tử giả duy nhất có thể có trong mảng của bạn. Tất cả các phần tử trong một mảng là giả mạo, không có bất kỳ phần tử trung thực nào trong đó. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Tuy nhiên, câu trả lời cho rằng anynhanh hơn các phương án khác một phần là do chập mạch. Tính đến năm 2018, Numpy's allany không đoản mạch .

Nếu bạn làm điều này thường xuyên, bạn rất dễ tạo ra các phiên bản đoản mạch của riêng mình bằng cách sử dụng numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Các phiên bản này có xu hướng nhanh hơn các phiên bản của Numpy ngay cả khi không bị chập mạch. count_nonzerolà chậm nhất.

Một số đầu vào để kiểm tra hiệu suất:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Kiểm tra:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Tính năng hữu ích allanytương đương:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-9

Nếu bạn đang kiểm tra tất cả các số không để tránh cảnh báo trên một hàm numpy khác thì gói dòng trong một lần thử, ngoại trừ khối sẽ tiết kiệm việc phải thực hiện kiểm tra các số không trước khi thực hiện thao tác bạn quan tâm.

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.