Я пытаюсь реализовать алгоритм быстрой сортировки с использованием numba в Python.
Кажется, это намного медленнее, чем функция сортировки numpy.
Как я могу улучшить его? Мой код здесь:
import numba as nb
@nb.autojit
def quick_sort(list_):
"""
Iterative version of quick sort
"""
#temp_stack = []
#temp_stack.append((left,right))
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
#temp_stack.append((left,piv-1))
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
#If items in the right of the pivot push them to the stack
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit( nopython=True )
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
Мой тестовый код здесь:
x = np.random.random_integers(0,1000,1000000)
y = x.copy()
quick_sort( y )
z = np.sort(x)
np.testing.assert_array_equal( z, y )
y = x.copy()
with Timer( 'nb' ):
numba_fns.quick_sort( y )
with Timer( 'np' ):
x = np.sort(x)
ОБНОВИТЬ:
Я переписал функцию, чтобы зацикленная часть кода выполнялась в режиме nopython. Цикл while, по-видимому, не вызывает сбой nopython. Однако я не получил никакого улучшения производительности:
@nb.autojit
def quick_sort2(list_):
"""
Iterative version of quick sort
"""
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
return _quick_sort2( list_, a_temp_stack, left, right )
@nb.autojit( nopython=True )
def _quick_sort2( list_, a_temp_stack, left, right ):
i_stack_pos = 1
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit( nopython=True )
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
numba
или вам нужна быстрая сортировка? - person hpaulj   schedule 23.03.2015for ignore_me in range(1):
). Он будет выполнять циклы JIT в режиме nopython, и, по-видимому, будут небольшие накладные расходы для входа и выхода из режима nopython. Однако оставьте окончательный оператор return вне цикла. - person DavidW   schedule 24.03.2015np.sort
, если глупые циклы выглядят слишком глупо. - person DavidW   schedule 24.03.2015