Эффективное вычисление медианы

У меня есть массив A длины n. Пусть B будет массивом (который мы никогда не хотим хранить отдельно — это просто для облегчения объяснения), содержащим каждый k-й элемент A. Я хочу найти медиану B, и я хочу переместить этот элемент A в этаж(n/2)-я позиция в A.

Как я могу сделать это эффективно? Я думаю попытаться сделать один вызов std::nth_element, передав указатель на A. Однако мне нужно, чтобы этот указатель увеличивался на k элементов A. Как мне это сделать? По сути:

A2 = (kFloat *)A;
std::nth_element(A2, A2 + (n/k)/2, A2 + (n/k));
swap(A[ ((n/k)/2)*k ], A[n/2]); // This might be redundant

где kFloat будет структурой, которая действует как число с плавающей запятой, но когда вы увеличиваете указатель, он перемещает k*sizeof(float) в памяти.

Примечание. Мне не требуется истинная медиана (среднее значение двух средних, когда n четно).

Редактировать: Другой способ сказать, что я хочу (не компилируется, потому что k не является константой):

std::nth_element((float[k] * )A, ((float[k] * ) A)[(n / k) / 2], ((float[k] * ) A)[n / k]);

Редактировать 2: я изменяю алгоритм.cc, поэтому я не хочу вводить зависимости от такой библиотеки, как Boost. Я хотел бы использовать только основные функции С++ 11 + std.


person PThomasCS    schedule 14.12.2013    source источник
comment
Я думаю, вы могли бы использовать собственный итератор; вероятно, уже есть реализация либо в библиотеке диапазонов, либо в итераторах повышения. Например. см. stackoverflow.com/q/12726466/420683   -  person dyp    schedule 15.12.2013
comment
Вы пробовали boost::accumulators для статистической оценки?   -  person Jepessen    schedule 15.12.2013
comment
Это сработало бы, однако я изменяю алгоритм.cc, поэтому я не могу добавить зависимости, такие как Boost или другую библиотеку. Я должен использовать только основные функции + стандарт.   -  person PThomasCS    schedule 15.12.2013


Ответы (2)


Однажды я реализовал бинарный поиск по пользовательским итераторам — вы можете проверить это на https://gist.github.com/IvanVergiliev/6048716 . Это для конкретной проблемы, но общая идея та же: напишите класс итератора для вашей последовательности и реализуйте необходимые операторы (++, --, +=) для перемещения k позиций за раз.

person Ivan Vergiliev    schedule 15.12.2013

Для всех, кто столкнется с этой проблемой в будущем, я изменил некоторые функции из алгоритма.cc, чтобы включить параметр шага. Многие из них предполагают, что _First и _Last кратны вашему шагу, поэтому я не рекомендую их вызывать. Однако вы можете вызвать следующую функцию:

// Same as _Nth_element, but increments pointers by strides of k
// Takes n, rather than last (needed to avoid confusion about what last should be [see line that computes _Last to see why]
// _First = pointer to start of the array
// _Nth = pointer to the position that we want to find the element for (if it were sorted).
//          This position should be = _First + k*x, for some integer x. That is, it should be a multiple of k.
// n = Length of array, _First, in primitive type (not length / k).
// _Pred = comparison operator. Typically use less<>()
// k = integer specifying the stride. If k = 10, we consider elements 0, 10, 20... only.
template<class _RanIt, class intType, class _Pr> inline
void _Nth_element_strided(_RanIt _First, _RanIt _Nth, intType n, _Pr _Pred, intType k);

Чтобы вызвать эту функцию, вам нужно включить этот заголовок:

#ifndef _NTH_ELEMENT_STRIDED_H_
#define _NTH_ELEMENT_STRIDED_H_

template<class _RanIt, class intType, class _Pr> inline
void _Median_strided(_RanIt _First, _RanIt _Mid, _RanIt _Last, _Pr _Pred, intType k) {
    // sort median element to middle
    if (40 < (_Last - _First)/k) {
        // median of nine
        size_t _Step = k * ((_Last - _First + k) / (k*8));
        _Med3(_First, _First + _Step, _First + 2 * _Step, _Pred);
        _Med3(_Mid - _Step, _Mid, _Mid + _Step, _Pred);
        _Med3(_Last - 2 * _Step, _Last - _Step, _Last, _Pred);
        _Med3(_First + _Step, _Mid, _Last - _Step, _Pred);
    }
    else
        _Med3(_First, _Mid, _Last, _Pred);
}

// Same as _Unguarded_partition, except it increments pointers by k.
template<class _RanIt, class _Pr, class intType> inline
pair<_RanIt, _RanIt> _Unguarded_partition_strided(_RanIt _First, _RanIt _Last, _Pr _Pred, intType k) {
    // partition [_First, _Last), using _Pred
    _RanIt _Mid = _First + (((_Last - _First)/k) / 2)*k;
    _Median_strided(_First, _Mid, _Last - k, _Pred, k);
    _RanIt _Pfirst = _Mid;
    _RanIt _Plast = _Pfirst + k;

    while (_First < _Pfirst
        && !_DEBUG_LT_PRED(_Pred, *(_Pfirst - k), *_Pfirst)
        && !_Pred(*_Pfirst, *(_Pfirst - k)))
        _Pfirst -= k;
    while (_Plast < _Last
        && !_DEBUG_LT_PRED(_Pred, *_Plast, *_Pfirst)
        && !_Pred(*_Pfirst, *_Plast))
        _Plast += k;

    _RanIt _Gfirst = _Plast;
    _RanIt _Glast = _Pfirst;

    for (;;) {
        // partition
        for (; _Gfirst < _Last; _Gfirst += k) {
            if (_DEBUG_LT_PRED(_Pred, *_Pfirst, *_Gfirst))
                ;
            else if (_Pred(*_Gfirst, *_Pfirst))
                break;
            else if (_Plast != _Gfirst) {
                _STD iter_swap(_Plast, _Gfirst);
                _Plast += k;
            }
            else
                _Plast += k;
        }
        for (; _First < _Glast; _Glast -= k) {
            if (_DEBUG_LT_PRED(_Pred, *(_Glast - k), *_Pfirst))
                ;
            else if (_Pred(*_Pfirst, *(_Glast - k)))
                break;
            else {
                _Pfirst -= k;
                if (_Pfirst != _Glast - k)
                    _STD iter_swap(_Pfirst, _Glast - k);
            }
        }

        if (_Glast == _First && _Gfirst == _Last)
            return (pair<_RanIt, _RanIt>(_Pfirst, _Plast));

        if (_Glast == _First) {
            // no room at bottom, rotate pivot upward
            if (_Plast != _Gfirst)
                _STD iter_swap(_Pfirst, _Plast);
            _Plast += k;
            _STD iter_swap(_Pfirst, _Gfirst);
            _Pfirst += k;
            _Gfirst += k;
        }
        else if (_Gfirst == _Last) {
            // no room at top, rotate pivot downward
            _Glast -= k;
            _Pfirst -= k;
            if (_Glast != _Pfirst)
                _STD iter_swap(_Glast, _Pfirst);
            _Plast -= k;
            _STD iter_swap(_Pfirst, _Plast);
        }
        else {
            _Glast -= k;
            _STD iter_swap(_Gfirst, _Glast);
            _Gfirst += k;
        }
    }
}

// TEMPLATE FUNCTION move_backward
template<class _BidIt1, class _BidIt2, class intType> inline
_BidIt2 _Move_backward_strided(_BidIt1 _First, _BidIt1 _Last, _BidIt2 _Dest, intType k) {
    // move [_First, _Last) backwards to [..., _Dest), arbitrary iterators
    while (_First != _Last) {
        _Dest -= k;
        _Last -= k;
        *_Dest = _STD move(*_Last);
    }
    return (_Dest);
}

template<class _BidIt, class _Pr, class intType, class _Ty> inline
void _Insertion_sort1_strided(_BidIt _First, _BidIt _Last, _Pr _Pred, _Ty *, intType k) {
    // insertion sort [_First, _Last), using _Pred
    if (_First != _Last) {
        for (_BidIt _Next = _First + k; _Next != _Last;) {
            // order next element
            _BidIt _Next1 = _Next;
            _Ty _Val = _Move(*_Next);

            if (_DEBUG_LT_PRED(_Pred, _Val, *_First)) {
                // found new earliest element, move to front
                _Next1 += k;
                _Move_backward_strided(_First, _Next, _Next1, k);
                *_First = _Move(_Val);
            }
            else {
                for (_BidIt _First1 = _Next1 - k; _DEBUG_LT_PRED(_Pred, _Val, *_First1);) {
                    *_Next1 = _Move(*_First1);  // move hole down

                    _Next1 = _First1;
                    _First1 -= k;
                }
                *_Next1 = _Move(_Val);  // insert element in hole
            }
            _Next += k;
        }
    }
}

// _Last should point to the last element being considered (the last k'th element), plus k.
template<class _BidIt, class intType, class _Pr> inline
void _Insertion_sort_strided(_BidIt _First, _BidIt _Last, _Pr _Pred, intType k) {
    // insertion sort [_First, _Last), using _Pred
    _Insertion_sort1_strided(_First,_Last, _Pred, _Val_type(_First), k);
}

// Same as _Nth_element, but increments pointers by strides of k
// Takes n, rather than last (needed to avoid confusion about what last should be [see first line below]
// _First = pointer to start of the array
// _Nth = pointer to the position that we want to find the element for (if it were sorted).
//          This position should be = _First + k*x, for some integer x. That is, it should be a multiple of k.
// n = Length of array, _First, in primitive type (not length / k).
// _Pred = comparison operator. Typically use less<>()
// k = integer specifying the stride. If k = 10, we consider elements 0, 10, 20... only.
template<class _RanIt, class intType, class _Pr> inline
void _Nth_element_strided(_RanIt _First, _RanIt _Nth, intType n, _Pr _Pred, intType k) {

    _RanIt _Last = (n % k == 0 ? _First + n : _First + (n / k + 1)*k);
    // order Nth element, using _Pred
    for (; _ISORT_MAX < (_Last - _First) / k;) {
        // divide and conquer, ordering partition containing Nth
        pair<_RanIt, _RanIt> _Mid = _Unguarded_partition_strided(_First, _Last, _Pred, k);

        if (_Mid.second <= _Nth)
            _First = _Mid.second;
        else if (_Mid.first <= _Nth)
            return; // Nth inside fat pivot, done
        else
            _Last = _Mid.first;
    }

    _Insertion_sort_strided(_First, _Last, _Pred, k);   // sort any remainder
}

#endif

Пример использования этой функции:

    for (int counter = 0; true; counter++) {
        // Test strided methods
        int n = (rand() % 10000) + 1;
        int k = (rand() % n) + 1;
        int * a = new int[n];
        int bLen = (n % k == 0 ? n / k : n / k + 1);
        int * b = new int[bLen];
        for (int i = 0; i < n; i++) // Initialize randomly
            a[i] = rand() % 100;
        for (int i = 0; i < bLen; i++)
            b[i] = a[i*k];

        int index = rand() % (bLen);    // Random index!
        _Nth_element(b, b + index, b + bLen, less<>());
        _Nth_element_strided(a, a + index*k, n, less<>(), k);

        if (b[index] != a[index*k]) {
            cout << "Not equal!" << endl;
            cout << b[index] << '\t' << a[index*k] << endl;
            getchar();
        }
        else
            cout << counter << endl;
    }
person PThomasCS    schedule 15.12.2013