Боюсь, что нет простого способа обойти это: утилиты случайного преобразования Torchvision построены таким образом, что параметры преобразования будут дискретизированы при вызове. Это уникальные случайные преобразования в том смысле, что используемые параметры (1) недоступны для пользователя, а (2) - такое же случайное преобразование. не повторяется.
Начиная с Torchvision 0.8.0, случайные преобразования обычно создаются с двумя основными функциями:
get_params
: выборка будет основана на гиперпараметрах преобразования (то, что вы указали при инициализации оператора преобразования, а именно диапазон значений параметров)
forward
: функция, которая выполняется при применении преобразования. Важная часть состоит в том, что он получает свои параметры от get_params
, а затем применяет их к входу, используя связанную детерминированную функцию. Для RandomRotation
_ 5_. Точно так же RandomAffine
будет использовать _ 7_.
Одно из решений вашей проблемы - самостоятельно выбрать параметры из get_params
и вместо этого вызвать функциональный - детерминированный - API. Таким образом, вы бы не использовали ни RandomRotation
, RandomAffine
, ни какие-либо другие Random*
преобразования в этом отношении.
Например, давайте посмотрим на T.RandomRotation
( Я удалил комментарии для краткости).
class RandomRotation(torch.nn.Module):
def __init__(
self, degrees, interpolation=InterpolationMode.NEAREST, expand=False,
center=None, fill=None, resample=None):
# ...
@staticmethod
def get_params(degrees: List[float]) -> float:
angle = float(torch.empty(1).uniform_(float(degrees[0]), \
float(degrees[1])).item())
return angle
def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
angle = self.get_params(self.degrees)
return F.rotate(img, angle, self.resample, self.expand, self.center, fill)
def __repr__(self):
# ...
Имея это в виду, вот возможное переопределение для изменения T.RandomRotation
:
class RandomRotation(T.RandomRotation):
def __init__(*args, **kwargs):
super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work
self.angle = self.get_params(self.degrees) # initialize your random parameters
def forward(self): # override T.RandomRotation's forward
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)
Я практически скопировал функцию T.RandomRotation
forward
, с той лишь разницей, что параметры выбираются в __init__
(т.е. один раз), а не внутри forward
(т.е. при каждом вызове). Реализация Torchvision охватывает все случаи, как правило, вам не нужно полностью копировать forward
. В некоторых случаях вы можете сразу вызвать функциональную версию. Например, если вам не нужно устанавливать параметры fill
, вы можете просто отказаться от этой части и использовать только:
class RandomRotation(T.RandomRotation):
def __init__(*args, **kwargs):
super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work
self.angle = self.get_params(self.degrees) # initialize your random parameters
def forward(self): # override T.RandomRotation's forward
return F.rotate(img, self.angle, self.resample, self.expand, self.center)
Если вы хотите переопределить другие случайные преобразования, вы можете посмотреть источник код. API довольно понятен, и у вас не должно возникнуть слишком много проблем с реализацией переопределения для каждого преобразования.
person
Ivan
schedule
26.01.2021