Я строю матрицу перехода из массива n1 x n2 x ... x nN x nN
. Для конкретности пусть N = 3
, например,
import numpy as np
# example with N = 3
n1, n2, n3 = 3, 2, 5
dim = (n1, n2, n3)
arr = np.random.random_sample(dim + (n3,))
Здесь arr
содержит вероятности перехода между 2 состояниями, где исходное состояние индексируется по первым 3 измерениям, а состояние до индексируется по первым 2 и последнему измерению. Я хочу построить матрицу перехода, которая выражает эти вероятности, сведенные в разреженную матрицу (n1*n2*n3) x (n1*n2*n3
.
Чтобы уточнить, позвольте мне представить мой текущий подход, который делает то, что я хочу сделать. К сожалению, это медленно и не работает, когда N
и n1, n2, ...
большие. Поэтому я ищу более эффективный способ сделать то же самое, который лучше масштабируется для более крупных задач.
Мой подход
import numpy as np
from scipy import sparse as sparse
## step 1: get the index correponding to each dimension of the from and to state
# ravel axes 1 to 3 into single axis and make sparse
spmat = sparse.coo_matrix(arr.reshape(np.prod(dim), -1))
data = spmat.data
row = spmat.row
col = spmat.col
# use unravel to get idx for
row_unravel = np.array(np.unravel_index(row, dim))
col_unravel = np.array(np.unravel_index(col, n3))
## step 2: combine "to" index with rows 1 and 2 of "from"-index to get "to"-coordinates in full state space
row_unravel[-1, :] = col_unravel # first 2 dimensions of state do not change
colnew = np.ravel_multi_index(row_unravel, dim) # ravel back to 1d
## step 3: assemble transition matrix
out = sparse.coo_matrix((data, (row, colnew)), shape=(np.prod(dim), np.prod(dim)))
Заключительная мысль
Я буду запускать этот код много раз. В ходе итераций данные arr
могут измениться, но размеры останутся прежними. Поэтому я могу сохранить и загрузить row
и colnew
из файла, пропустив все между определением data
(строка 2) и последней строкой моего кода. Как вы думаете, это был бы лучший подход?
Редактировать: одна проблема, которую я вижу с этой стратегией, заключается в том, что если некоторые элементы arr
равны нулю (что возможно), то размер data
будет меняться между итерациями.