вычислить матричное умножение вектора с помощью python в cuda

Я пытаюсь использовать numbapro, чтобы написать простое умножение векторов матрицы ниже:

from numbapro import cuda
from numba import *
import numpy as np
import math
from timeit import default_timer as time

m = 100000 
n = 100

@cuda.jit('void(f4[:,:], f4[:], f4[:])')
def cu_matrix_vector(A, b, c):
    row = cuda.grid(1)
    if (row < m):
        sum = 0

        for i in range(n):
            sum += A[row, i] * b[i]

        c[row] = sum

A = np.array(np.random.random((m, n)), dtype=np.float32)
B = np.array(np.random.random(m), dtype=np.float32)
C = np.empty_like(B)

s = time()
dA = cuda.to_device(A)
dB = cuda.to_device(B)
dC = cuda.to_device(C)

cu_matrix_vector[(m+511)/512, 512](dA, dB, dC)

dC.to_host()

print ( C)

Но когда я начинаю работать, я получаю ошибку в функции ** cu_matrix_vector ** аргумент 2:: неправильный тип

cu_matrix_vector [(m + 511) / 512, 512] (dA, dB, dC) Файл "C: \ Anaconda3 \ lib \ site-packages \ numba \ cuda \ compiler.py", строка 359, в call < / strong> sharedmem = self.sharedmem) Файл "C: \ Anaconda3 \ lib \ site-packages \ numba \ cuda \ compiler.py", строка 433, в _kernel_call cu_func (* kernelargs) Файл "C: \ Anaconda3 \ lib \ site-packages \ numba \ cuda \ cudadrv \ driver.py ", строка 1116, в call self.sharedmem, streamhandle, args) Файл" C: \ Anaconda3 \ lib \ site-packages \ numba \ cuda \ cudadrv \ driver.py ", строка 1160, в launch_kernel Нет) Файл" C: \ Anaconda3 \ lib \ site-packages \ numba \ cuda \ cudadrv \ driver.py ", строка 221, в safe_cuda_api_call retcode = libfn (* args) ctypes.ArgumentError: аргумент 2:: неправильный тип


person ouamane zahra    schedule 27.04.2017    source источник


Ответы (1)


Проблема здесь:

cu_matrix_vector[(m+511)/512, 512](dA, dB, dC)

в Python 3 (m+511)/512 = 196.310546875. Передача значения с плавающей запятой в качестве параметра запуска недопустима, что является источником наблюдаемой ошибки конфликта типов. Вы хотите сделать:

cu_matrix_vector[(m+511)//512, 512](dA, dB, dC)

который выдаст целочисленное значение и должен позволить коду работать правильно.

person Community    schedule 27.04.2017