Как получить лучшую производительность с помощью алгоритма Штрассена, чем наивный подход без точки отсечки?

Я пытаюсь проверить наивный метод и метод Штрассена для умножения матриц.

Однако алгоритм Штрассена работает намного медленнее, чем наивный подход. Для матрицы размером 1024 наивный подход завершается за 3542 мс, а метод Штрассена — за 83602 мс. (где Штрассен не использует отсеченный/наивный подход) Вот код Штрассена, который я использую. LEAF SIZE — это число, при котором он переключается на наивный подход:

int n = A.length;

if (n <= LEAF_SIZE) {
    return ikjAlgorithm(A, B);
} else {
    // initializing the new sub-matrices
    int newSize = n / 2;
    int[][] a11 = new int[newSize][newSize];
    int[][] a12 = new int[newSize][newSize];
    int[][] a21 = new int[newSize][newSize];
    int[][] a22 = new int[newSize][newSize];

    int[][] b11 = new int[newSize][newSize];
    int[][] b12 = new int[newSize][newSize];
    int[][] b21 = new int[newSize][newSize];
    int[][] b22 = new int[newSize][newSize];

    int[][] aResult = new int[newSize][newSize];
    int[][] bResult = new int[newSize][newSize];

    // dividing the matrices in 4 sub-matrices:
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            a11[i][j] = A[i][j]; // top left
            a12[i][j] = A[i][j + newSize]; // top right
            a21[i][j] = A[i + newSize][j]; // bottom left
            a22[i][j] = A[i + newSize][j + newSize]; // bottom right

            b11[i][j] = B[i][j]; // top left
            b12[i][j] = B[i][j + newSize]; // top right
            b21[i][j] = B[i + newSize][j]; // bottom left
            b22[i][j] = B[i + newSize][j + newSize]; // bottom right
        }
    }

    // Calculating p1 to p7:
    aResult = add(a11, a22);
    bResult = add(b11, b22);
    int[][] p1 = strassenR(aResult, bResult);
    // p1 = (a11+a22) * (b11+b22)

    aResult = add(a21, a22); // a21 + a22
    int[][] p2 = strassenR(aResult, b11); // p2 = (a21+a22) * (b11)

    bResult = subtract(b12, b22); // b12 - b22
    int[][] p3 = strassenR(a11, bResult);
    // p3 = (a11) * (b12 - b22)

    bResult = subtract(b21, b11); // b21 - b11
    int[][] p4 = strassenR(a22, bResult);
    // p4 = (a22) * (b21 - b11)

    aResult = add(a11, a12); // a11 + a12
    int[][] p5 = strassenR(aResult, b22);
    // p5 = (a11+a12) * (b22)

    aResult = subtract(a21, a11); // a21 - a11
    bResult = add(b11, b12); // b11 + b12
    int[][] p6 = strassenR(aResult, bResult);
    // p6 = (a21-a11) * (b11+b12)

    aResult = subtract(a12, a22); // a12 - a22
    bResult = add(b21, b22); // b21 + b22
    int[][] p7 = strassenR(aResult, bResult);
    // p7 = (a12-a22) * (b21+b22)

    // calculating c21, c21, c11 e c22:
    int[][] c12 = add(p3, p5); // c12 = p3 + p5
    int[][] c21 = add(p2, p4); // c21 = p2 + p4

    aResult = add(p1, p4); // p1 + p4
    bResult = add(aResult, p7); // p1 + p4 + p7
    int[][] c11 = subtract(bResult, p5);
    // c11 = p1 + p4 - p5 + p7

    aResult = add(p1, p3); // p1 + p3
    bResult = add(aResult, p6); // p1 + p3 + p6
    int[][] c22 = subtract(bResult, p2);
    // c22 = p1 + p3 - p2 + p6

    // Grouping the results obtained in a single matrix:
    int[][] C = new int[n][n];
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            C[i][j] = c11[i][j];
            C[i][j + newSize] = c12[i][j];
            C[i + newSize][j] = c21[i][j];
            C[i + newSize][j + newSize] = c22[i][j];
        }
    }
    return C;
}
private static int[][] add(int[][] A, int[][] B) {
    int n = A.length;
    int[][] C = new int[n][n];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}

private static int[][] subtract(int[][] A, int[][] B) {
    int n = A.length;
    int[][] C = new int[n][n];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}

С размером листа около 32 он действительно работает быстрее (это точка отсечки, где срабатывает наивный алгоритм).

Это на языке Java. Код взят из Интернета, однако более или менее все реализации похожи.

Разве нельзя победить наивного одним только Штрассеном без точки отсечки? Любые идеи были бы хорошы. Спасибо.

EDIT Добавлены методы сложения и вычитания.

EDIT2 Из кода самые большие накладные расходы связаны с созданием новых подматриц? Если да, то какой метод можно применить для максимально возможного устранения накладных расходов? Если ничего нельзя сделать на java, я не против использования c++.

EDIT3 Может ли кто-нибудь предложить способ уменьшения используемого здесь выделения памяти? Был бы признателен за предложения.


person Svajunas Kavaliauskas    schedule 11.05.2019    source источник
comment
Не по теме: это просто для развлечения или вы ожидаете какой-то выгоды в своей работе? Я не думаю, что какой-либо код общего назначения использует ненаивный матмул: все (включая все разрекламированные приложения для глубокого обучения) будут использовать какой-то BLAS-бэкенд на основе наивного матмул. Наверное, потому что его трудно заставить работать. См. также это. При попытке приблизиться к реальной производительности, вероятно, также не поможет попробовать java. Один только материал int[][] выглядит ужасно (массив объектов) с точки зрения кэширования.   -  person sascha    schedule 11.05.2019
comment
@sascha Это задание для сравнения Штрассена и наивного подхода.   -  person Svajunas Kavaliauskas    schedule 11.05.2019
comment
Вы можете найти некоторые практические измерения на этом сайте, посвященном быстрому умножению матриц.   -  person Axel Kemper    schedule 12.05.2019