Почему выполняются обе ветки в tf.cond? И почему tf.while_loop завершает цикл, хотя условие остается верным?

Я использую keras уже некоторое время, но обычно мне не нужно использовать настраиваемые слои или выполнять более сложное управление потоком, поэтому я изо всех сил пытаюсь что-то понять.

Я моделирую нейронную сеть с настраиваемым слоем сверху. Этот настроенный слой вызывает другую функцию (search_sigma), и внутри этой функции я выполняю tf.while_loop, а внутри tf.while_loop я выполняю tf.cond.

Я не могу понять, почему условия не работают.

  • tf.while_loop останавливается, несмотря на то, что условие (l1) остается верным
  • tf.cond executes и f1, и f2 (вызываемые true_fn и false_fn)

Может ли кто-нибудь помочь мне понять, что мне не хватает?

Я уже пытался изменить условия tf.cond и tf.while_loop для истинных тензоров, просто чтобы посмотреть, что произойдет. Поведение (точно такие же ошибки) осталось прежним.

Я также пытался написать этот код без реализации класса (используя только функции). Ничего не изменилось.

Я пытался найти решения, просматривая документацию по тензорному потоку, другие сомнения в переполнении стека и веб-сайты, говорящие о tf.while_loop и tf.cond.

Я оставил несколько print() в теле кода, чтобы попытаться отследить, что происходит.

class find_sigma:
    
    def __init__ (self, t_inputs,  inputs,  expected_perp=10. ):       
        self.sigma, self.cluster = t_inputs
        self.inputs = inputs
        self.expected_perp = expected_perp
        self.min_sigma=tf.constant([0.01],tf.float32)
        self.max_sigma=tf.constant([50.],tf.float32)
 

    def search_sigma(self):

        
        def cond(s,sigma_not_found): return sigma_not_found


        def body(s,sigma_not_found):   

            print('loop')
            pi = K.exp( - K.sum( (K.expand_dims(self.inputs, axis=1) - self.cluster)**2, axis=2  )/(2*s**2) )        
            pi = pi / K.sum(pi)
            MACHINE_EPSILON = np.finfo(np.double).eps
            pi = K.maximum(pi, MACHINE_EPSILON)
            H = - K.sum ( pi*(K.log(pi)/K.log(2.)) , axis=0 )
            perp = 2**H

            print('0')

            l1 = tf.logical_and (tf.less(perp , self.expected_perp), tf.less(0.01, self.max_sigma-s))
            l2 = tf.logical_and (tf.less(  self.expected_perp , perp) , tf.less(0.01, s-self.min_sigma) )
    
            def f1():
                print('f1')
                self.min_sigma = s 
                s2 = (s+self.max_sigma)/2 
                return  [s2, tf.constant([True])]
                

            def f2(l2): 
                tf.cond( l2, true_fn=f3 , false_fn = f4)

            def f3(): 
                print('f3')
                self.max_sigma = s 
                s2 = (s+self.min_sigma)/2
                return [s2, tf.constant([True])]

            def f4(): 
                print('f4')
                return [s, tf.constant([False])]
            
            output = tf.cond( l1, f1 ,  f4 ) #colocar f2 no lugar de f4

            s, sigma_not_found = output
            
            print('sigma_not_found = ',sigma_not_found)
            return [s,sigma_not_found]

        print('01')

        sigma_not_found = tf.constant([True])

        new_sigma,sigma_not_found=sigma_not_found = tf.while_loop(
            cond , body, loop_vars=[self.sigma,sigma_not_found]
        )

        print('saiu')
        
        print(new_sigma)

        return new_sigma

Фрагмент кода, который вызывает приведенный выше код:

self.sigma = tf.map_fn(fn=lambda t: find_sigma(t,  inputs).search_sigma() , elems=(self.sigma,self.clusters), dtype=tf.float32)

'входы' - это тензор размера (None, 10)

'self.sigma' - это тензор размера (10,)

'self.clusters' - это тензор размера (N, 10)


person Luiza Ribeiro Marnet    schedule 04.02.2021    source источник


Ответы (1)


Во-первых, ваш первый вопрос был замечательным! Много информации!

tf.while_loop очень сбивает с толку, и это одна из причин, по которой tf перешел к нетерпеливому выполнению. Вам больше не нужно этого делать.

В любом случае, вернемся к вашим 2 вопросам. Ответ одинаков для обоих: вы никогда не выполняете свой график, вы просто его строите. При построении графика выполнения tensorflow необходимо отслеживать ваш код Python, поэтому вы думаете, что tf.conf запускает f1 и f2. Это своего рода работа, потому что ему нужно зайти внутрь, чтобы выяснить, какие тензоры/операции будут добавлены к графу.

То же самое относится и к вашему вопросу о tf.while_loop. Это никогда не выполняется.

Я рекомендую небольшое изменение, которое может помочь вам понять, о чем я говорю, а также решить вашу проблему. Удалите этот tf.while_loop из метода body. Создайте еще один метод, скажем, run() и переместите туда цикл. вроде этого

def run(self):
   out = tf.while_loop(cond, body, loop_vars)

Затем вызовите run(). Это заставит граф выполниться.

person CrazyBrazilian    schedule 04.02.2021