Numba @jit не может оптимизировать простую функцию

У меня есть довольно простая функция, которая использует массивы Numpy и циклы for, но добавление декоратора Numba @jit не дает абсолютно никакого ускорения:

# @jit(float64[:](int32,float64,float64,float64,int32))
@jit
def Ising_model_1D(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    beta = 1/T
    s = randn(N,1) > 10  
    s[N-1] = s[0]
    mag = zeros((n_iter,1))
    aux_idx =  randint(low=0,high=N,size=(n_iter,1))

    for i1 in arange(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx]*2 - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1

        if(delta_E < 0):
            s[rnd_idx] = np.logical_not(s[rnd_idx]) 
        elif(np.exp(-1*beta*delta_E) >= rand()):
            s[rnd_idx] = np.logical_not(s[rnd_idx])
        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N 
    return mag

С другой стороны, MATLAB выполняет это менее чем за 0,5 секунды! Почему Numba не хватает чего-то столь простого?


person KartMan    schedule 18.11.2015    source источник
comment
Вы вызываете функции NumPy для скалярных значений в теле цикла. Эти функции предназначены для эффективной работы с большими массивами, а не с отдельными значениями. Эти вызовы функций не могут быть оптимизированы с помощью numba. Короче говоря, вам нужно векторизовать код, а не JIT-компилировать его.   -  person Alex Riley    schedule 19.11.2015
comment
@ajcr Я думаю, что некоторые из них на самом деле могут любить rand() и ndarray.sum() (по крайней мере, они могут в последней версии numba).   -  person jme    schedule 19.11.2015
comment
@jme: ах, спасибо, я не знал, что это так. Я думал, что повторный вызов np.logical_not (и других скомпилированных функций) замедлит цикл. Я должен копнуть немного глубже в документы numba.   -  person Alex Riley    schedule 19.11.2015


Ответы (1)


Вот переработка вашего кода, который выполняется примерно за 0,4 секунды на моей машине:

def ising_model_1d(N=200,J=1,T=1e-2,H=0,n_iter=1e6):
    n_iter = int(n_iter)
    beta = 1/T
    s = randn(N) > 10
    s[N-1] = s[0]

    mag = zeros(n_iter)
    aux_idx =  randint(low=0,high=N,size=n_iter)

    pre_rand = rand(n_iter)

    _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag)

    return mag


@jit(nopython=True)
def _ising_jitted(n_iter, aux_idx, s, J, N, H, beta, pre_rand, mag):
    for i1 in range(n_iter):
        rnd_idx = aux_idx[i1]
        s_1 = s[rnd_idx*2] - 1
        s_2 = s[(rnd_idx+1)%(N)]*2 - 1
        s_3 = s[(rnd_idx-1)%(N)]*2 - 1
        delta_E = 2.0*J*(s_2+s_3)*s_1 + 2.0*H*s_1
        t = rand()
        if delta_E < 0:
            s[rnd_idx] = not s[rnd_idx]
        elif np.exp(-1*beta*delta_E) >= pre_rand[i1]:
            s[rnd_idx] = not s[rnd_idx]

        s[N-1] = s[0]
        mag[i1] = (s*2-1).sum()*1.0/N

Пожалуйста, убедитесь, что результаты соответствуют ожиданиям! Я изменил многое из того, что у вас было, и не могу гарантировать правильность расчетов!

Работа с numba требует осторожности. Функции Python, как и большинство функций numpy, не могут быть оптимизированы компилятором. Одна вещь, которую я считаю полезной, это использовать опцию nopython для @jit. Это означает, что компилятор будет жаловаться всякий раз, когда вы даете ему код, который он не может оптимизировать. Затем вы можете просмотреть сообщение об ошибке и найти строку, которая, вероятно, замедлит ваш код.

Я считаю, что хитрость заключается в том, чтобы написать функцию «шлюза» на Python, которая выполняет как можно больше работы, используя numpy и его векторизованные функции. Он должен создать пустые массивы, в которых вам нужно будет хранить результаты. Он должен упаковать все данные, которые вам понадобятся во время вычислений. Затем он должен передать все это вашей jit-функции в одном большом длинном списке аргументов.

Показательный пример: обратите внимание, как я обрабатываю генерацию случайных чисел в джиттинг-коде. В исходном коде вы вызвали rand():

elif(np.exp(-1*beta*delta_E) >= rand()):

Но rand() нельзя оптимизировать с помощью numba (по крайней мере, в старых версиях numba. В более новых версиях можно, при условии, что rand вызывается без аргументов). Наблюдение состоит в том, что вам нужно одно случайное число для каждой из n_iter итераций. Итак, мы просто создаем случайный массив, используя numpy в нашей функции-оболочке, а затем передаем этот случайный массив функции jitted. Получить случайное число так же просто, как индексировать этот массив.

Наконец, список numpy функций, которые можно оптимизировать с помощью последней версии компилятора, см. на странице здесь. В моей переработке вашего кода я был агрессивен в удалении вызовов функций numpy, чтобы код работал на большем количестве версий numba.

person jme    schedule 18.11.2015
comment
Превосходно! Спасибо за подробный ответ. Документы Numba немного скудны в деталях. - person KartMan; 19.11.2015