Numpy dizisinin yalnızca sıfır içerip içermediğini test edin


94

Aşağıdaki gibi sıfırlarla uyuşmuş bir dizi başlatıyoruz:

np.zeros((N,N+1))

Ancak, belirli bir n * n numpy dizi matrisindeki tüm öğelerin sıfır olup olmadığını nasıl kontrol ederiz?
Tüm değerler gerçekten sıfırsa, yöntemin yalnızca True döndürmesi gerekir.

Yanıtlar:



168

Burada yayınlanan diğer yanıtlar işe yarayacaktır, ancak kullanılacak en net ve en verimli işlev şudur numpy.any():

>>> all_zeros = not np.any(a)

veya

>>> all_zeros = not a.any()
  • Bu, numpy.all(a==0)daha az RAM kullandığı için tercih edilir . ( a==0Terim tarafından oluşturulan geçici diziyi gerektirmez .)
  • Ayrıca, numpy.count_nonzero(a)sıfırdan farklı ilk eleman bulunduğunda hemen dönebileceğinden daha hızlıdır .
    • Düzenleme: @Rachel'in yorumlarda belirttiği gibi, np.any()artık "kısa devre" mantığını kullanmadığından, küçük diziler için bir hız avantajı görmeyeceksiniz.

3
Bir dakika önce itibariyle numpy en anyve alldo not kısa devre. Onların Şekeri olduğuna inanıyoruz logical_or.reduceve logical_and.reduce. Birbirimle ve kısa devre yapmamla karşılaştırın is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel

3
Bu harika bir nokta, teşekkürler. Bu kısa devreye benziyor kullanılan davranışı olmak, ama bu noktada kayboldu. Bu sorunun cevaplarında bazı ilginç tartışmalar var .
Stuart Berg

51

Bir dizi a'nız varsa burada np.all kullanırım:

>>> np.all(a==0)

3
Bu cevabın sıfır olmayan değerleri de kontrol etmesini seviyorum. Örneğin, bir dizideki tüm elemanların aynı olup olmadığını yaparak kontrol edebilir np.all(a==a[0]). Çok teşekkürler!
aignas

Bu çözüm aynı zamanda biraz daha etkilidir np.count_nonzero. % timeit num_of_non_zeros = np.count_nonzero (sıfır_vektör) 18,2 µs ± 386 ns döngü başına (ortalama ± standart dev. 7 çalıştırma, her biri 100000 döngü)% timeit num_of_non_zeros = np.all ((sıfır_vektör == 0)) 7,31 µs ± Döngü başına 41,6 ns (ortalama ± standart sapma 7 çalıştırma, her biri 100000 döngü)
IP

9

Başka bir yanıtın da söylediği gibi, 0dizinizdeki muhtemelen tek yanlış öğenin bu öğe olduğunu biliyorsanız, doğru / yanlış değerlendirmelerden yararlanabilirsiniz . Bir dizideki tüm öğeler yanlıştır, ancak içinde doğru öğeler yoksa. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Ancak cevap any, kısmen kısa devre nedeniyle diğer seçeneklerden daha hızlı olduğunu iddia etti . 2018 itibariyle, Numpy's allve any kısa devre yapmıyor .

Bu tür şeyleri sık sık yaparsanız, aşağıdakileri kullanarak kendi kısa devre sürümlerinizi oluşturmak çok kolaydır numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Bunlar, kısa devre olmasa bile Numpy'nin sürümlerinden daha hızlı olma eğilimindedir. count_nonzeroen yavaş olanıdır.

Performansı kontrol etmek için bazı girdiler:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Kontrol:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Yararlı allve anyeşdeğerler:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))

-8

Başka bir numpy işlevinde bir uyarıdan kaçınmak için tüm sıfırları test ediyorsanız, satırı bir denemede kaydırmak, ilgilendiğiniz işlemden önce sıfırlar için testi yapmak zorunda kalmanızı sağlar, örn.

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.