Как ускорить быструю сортировку с помощью numba?

Я пытаюсь реализовать алгоритм быстрой сортировки с использованием numba в Python.

Кажется, это намного медленнее, чем функция сортировки numpy.

Как я могу улучшить его? Мой код здесь:

import numba as nb

@nb.autojit
def quick_sort(list_):
    """
    Iterative version of quick sort
    """
    #temp_stack = []
    #temp_stack.append((left,right))

    max_depth = 1000

    left = 0
    right = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:
            #temp_stack.append((left,piv-1))

            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        #If items in the right of the pivot push them to the stack
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j

Мой тестовый код здесь:

    x = np.random.random_integers(0,1000,1000000)
    y = x.copy()

    quick_sort( y )

    z = np.sort(x)

    np.testing.assert_array_equal( z, y )

    y = x.copy()
    with Timer( 'nb' ):
        numba_fns.quick_sort( y )

    with Timer( 'np' ):
        x = np.sort(x) 

ОБНОВИТЬ:

Я переписал функцию, чтобы зацикленная часть кода выполнялась в режиме nopython. Цикл while, по-видимому, не вызывает сбой nopython. Однако я не получил никакого улучшения производительности:

@nb.autojit
def quick_sort2(list_):
    """
    Iterative version of quick sort
    """

    max_depth = 1000

    left        = 0
    right       = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    return _quick_sort2( list_, a_temp_stack, left, right )

@nb.autojit( nopython=True )
def _quick_sort2( list_, a_temp_stack, left, right ):

    i_stack_pos = 1
    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:            
            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j

person Ginger    schedule 22.03.2015    source источник
comment
Даже с JIT-компилятором маловероятно, что вы превзойдете алгоритм, реализованный в прямой C. Также может случиться так, что ваш код возвращается к режим объекта.   -  person Seth    schedule 23.03.2015
comment
Вы делаете это упражнение numba или вам нужна быстрая сортировка?   -  person hpaulj    schedule 23.03.2015
comment
Это не упражнение. Мне нужна быстрая сортировка в функции numba, которую я пишу. Функция вызывает быструю сортировку несколько раз. Чтобы работать в режиме nopython, функция не может использовать функцию сортировки numpy, поэтому мне нужно написать свою собственную.   -  person Ginger    schedule 23.03.2015
comment
Альтернатива, которая работает в простом тестовом примере, я пытался поместить материал до и после вызова np.sort в бессмысленных циклах (for ignore_me in range(1):). Он будет выполнять циклы JIT в режиме nopython, и, по-видимому, будут небольшие накладные расходы для входа и выхода из режима nopython. Однако оставьте окончательный оператор return вне цикла.   -  person DavidW    schedule 24.03.2015
comment
Вы также можете использовать вызовы функций (для функций jit'd с nopython=True) по обе стороны от np.sort, если глупые циклы выглядят слишком глупо.   -  person DavidW    schedule 24.03.2015


Ответы (3)


Одно небольшое предложение, которое может помочь (но, как вам правильно сказали в комментариях к вашему вопросу, вы будете изо всех сил пытаться победить чистую реализацию C):

вы хотите убедиться, что большая часть этого выполняется в режиме «nopython» (@jit(nopython=True)). Добавьте это перед вашими функциями и посмотрите, где это сломается. Также вызовите inspect_types() в своей функции и посмотрите, сможет ли она правильно их идентифицировать.

Единственная вещь в вашем коде, которая, скорее всего, заставит его перейти в объектный режим (в отличие от режима nopython), — это выделение массива numpy. Хотя numba может компилировать циклы отдельно в режиме nopython, я не знаю, может ли он сделать это для циклов while. Звонок inspect_types скажет вам.

Мой обычный обходной путь для создания массивов numpy, гарантируя, что все остальное находится в режиме nopython, — это создать функцию-оболочку.

@nb.jit(nopython=True) # make sure it can be done in nopython mode
def _quick_sort_impl(list_,output_array):
   ...most of your code goes here...

@nb.jit
def quick_sort(list_):
   # this code won't compile in nopython mode, but it's
   # short and isolated
   max_depth = 1000
   a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
   _quick_sort_impl(list_,a_temp_stack)
person DavidW    schedule 23.03.2015
comment
Итак... быстрый тест показывает, что он выполняет ваш цикл while в режиме nopython, поэтому мое улучшение не имеет ощутимой разницы. Единственное, что я вижу, это то, что он в конечном итоге компилирует (и вынужден выбирать между) 4 разные версии partion (на основе типа целого числа left и right). Я не могу поверить, что это поможет слишком много, хотя. - person DavidW; 23.03.2015

В общем, если вы не форсируете режим nopython, у вас есть высокие шансы не получить повышения производительности. Цитата из документов о режиме nopython:

Режим [nopython] создает код с наивысшей производительностью, но требует, чтобы можно было вывести собственные типы всех значений в функции и чтобы не выделялись новые объекты.

Поэтому ваш вызов np.ndarray запускает объектный режим и, следовательно, замедляет код. Попробуйте выделить рабочий массив вне функции, например:

def quick_sort(list_):

    max_depth = 1000
    temp_stack_ = np.array( ( max_depth, 2), dtype=np.int32 )

    _quick_sort(list_, temp_stack_)

...

@numba.jit(nopython=True)
def _quick_sort(list_, temp_stack_):
    ...
person astrojuanlu    schedule 23.03.2015

Что бы это ни стоило, numba реализовала как общую функцию sorted, так и метод .sort() numpy-массива, начиная с (я думаю) версии 0.22. Ура!

http://numba.pydata.org/numba-doc/dev/reference/pysupported.html#built-in-functions http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html#other-methods

person Peter M    schedule 10.11.2015