Заполнитель Tensorflow в пользовательской целевой функции Keras

Мне нужно реализовать настраиваемую целевую функцию для Keras, где мне нужен дополнительный заполнитель тензорного потока для вычислений. В тензорном потоке у меня есть следующее:

pre_cost1 = tf.multiply((self.input_R - self.Decoder) , self.input_mask_R)
cost1 = tf.square(self.l2_norm(pre_cost1))

где input_mask_R - заполнитель тензорного потока. input_R и Decoder - это заполнители, соответствующие y_true и y_pred для функции потерь Keras соответственно. У меня есть функция потери Keras, реализованная как,

def custom_objective(y_true, y_pred):

    pre_cost1 = tf.multiply((y_true - y_pred))
    cost1 = tf.square(l2_norm(pre_cost1))

    return cost1

Мне нужно добавить дополнительную информацию для маски ввода в функцию потерь для keras. (Это должен быть заполнитель тензорного потока, поскольку это маска для ввода, которая отличается для каждой строки входных данных).


person Awais Jafar    schedule 29.04.2017    source источник
comment
Изменяется ли input_mask_R во время обучения? или он предопределен для каждого входного образца?   -  person Pedia    schedule 29.04.2017
comment
его предопределено для каждого образца   -  person Awais Jafar    schedule 29.04.2017
comment
Не могли бы вы выложить модель? хотя бы базовый, с которым вы тестируете?   -  person Pedia    schedule 29.04.2017


Ответы (1)


Используйте бэкэнд keras:

import keras.backend as K

Там есть большинство функций для тензоров, таких как:

input_mask_R = K.placeholder(shape=(yourshape))

Но, возможно, поскольку вам нужна предопределенная маска, вам понадобится:

input_mask_R = K.constant(arrayWithValues, shape=(yourshape))

И вы действительно можете умножать и возводить в квадрат с помощью K.multiply и K.square. Таким образом, если вы когда-нибудь подумаете о смене серверной части, все будет в порядке. (Также я не уверен, будет ли Keras обрабатывать прямые вызовы функций тензорного потока ...)

См. Документацию: https://keras.io/backend/

person Daniel Möller    schedule 05.05.2017
comment
Да, из моих тестов keras handle tf.square вместо K.square - person parsethis; 05.05.2017