Загрузка изображений FITS с помощью PyTorch

Я пытаюсь создать CNN с помощью PyTorch, но мои изображения нужно импортировать из формата FITS, а не из обычных .png или .jpeg и т. Д.

Есть ли способ легко сделать это с помощью torch.utils.data.DataLoader или есть место в исходном коде, где я могу добавить пункт, который будет обрабатывать файлы FITS при загрузке?

Я просмотрел документацию, и самое важное, что я нашел, - это преобразователь ToPILImage, который преобразует тензор или ndarray в изображение PIL.

В настоящее время я использую процедуру загрузки изображений следующим образом:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose(
                   [transforms.Resize((32,32)),
                    transforms.ToTensor(),
                    ])

trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

Astropy: http://www.astropy.org/

Pytorch: https://pytorch.org/

torch.utils: https://pytorch.org/docs/master/data.html

ОБНОВЛЕНИЕ: возможно, используя torchvision.datasets.DatasetFolder вместо DataLoader, вставка в мой собственный обработчик FITS будет работать?

При попытке использовать этот класс я получаю следующую ошибку:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

Поддерживается ли DatasetFolder на данный момент torchvision?


person user8188120    schedule 08.05.2018    source источник


Ответы (3)


Прочитав некоторую комбинацию документов и кода, я не думаю, что вы обязательно хотите использовать ImageFolder, поскольку он ничего не знает о FITS.

Вместо этого вам следует попробовать использовать более общий класс DataSetFolder. (который фактически является родительским классом ImageFolder). Вы должны передать ему список расширений, которые он должен обрабатывать (т.е. ['.fits'] и функцию «загрузчик», которая принимает файл FITS и, по-видимому, должна возвращать PIL.Image.

Вы даже можете создать свой собственный подкласс по примеру ImageFolder . Например.

class FitsFolder(DatasetFolder):

    EXTENSIONS = ['.fits']

    def __init__(self, root, transform=None, target_transform=None,
                 loader=None):
        if loader is None:
            loader = self.__fits_loader

        super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                         transform=transform,
                                         target_transform=target_transform)

    @staticmethod
    def __fits_loader(filename):
        data = fits.getdata(filename)
        return Image.fromarray(data)

Точные детали __fits_loader могут зависеть от деталей ваших файлов FITS. В этом базовом примере просто используется высокоуровневая функция fits.getdata(), которая возвращает первый массив изображений в файле FITS (некоторые файлы FITS могут иметь множество расширений с множеством изображений или иметь таблицы и т. Д.). Так что эта часть будет зависеть от вас.

person Iguananaut    schedule 09.05.2018
comment
Спасибо за ответ. Это наверняка хороший способ сделать это. Однако при попытке реализовать эту идею я получаю следующую ошибку: модуль torchvision.datasets не имеет атрибута DatasetFolder. - person user8188120; 10.05.2018
comment
На самом деле, похоже, что это было добавлено совсем недавно: github.com/pytorch/vision/pull / 444 Так что, если вы не можете использовать новейшую версию пакета, возможно, вам, к сожалению, придется немного заново изобрести колесо, но он, вероятно, по-прежнему будет выглядеть примерно так же (например, вы можете создать подкласс ImageFolder, хотя вам придется реализовать еще немного метода __init__). - person Iguananaut; 10.05.2018
comment
Ах, в этом есть смысл. Думаю, если бы я просто скопировал исходный код локально: github.com / pytorch / vision / blob / master / torchvision / datasets / Тогда я мог бы просто вызвать DatasetFolder и реализовать описанный выше метод? Так было бы проще. - person user8188120; 10.05.2018
comment
Конечно, вы могли бы сделать это в качестве временной меры. Возможно, с напоминанием о том, что его можно будет удалить при обновлении до будущей версии PyTorch. - person Iguananaut; 11.05.2018

Вы можете экспортировать изображение FITS в любой формат, поддерживаемый pyplot.imsave () с помощью этого метода:

from astropy.io import fits
import matplotlib.pyplot as plt

image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")
person MadeOfAir    schedule 08.05.2018
comment
Хорошая идея, к сожалению, мне нужно сохранять данные в формате FITS при архивировании, чтобы я мог быстро и легко использовать их в астрономических конвейерах. - person user8188120; 09.05.2018
comment
Я не уверен, в чем проблема. В этом ответе демонстрируется загрузка данных из файла FITS с последующей их записью в отдельный файл .png. Вы бы вообще не потеряли данные FITS. В противном случае я не знаком с PyTorch, но, возможно, есть способ расширить его для чтения файлов FITS. - person Iguananaut; 09.05.2018
comment
Приношу свои извинения, я хотел сказать, что вместо преобразования FITS в png и сохранения изображений для загрузки в PyTorch я больше сосредоточен на чтении в FITS непосредственно в PyTorch без необходимости в средней стадии, которая дублирует изображения в формат png - как я полагаю адрес вашего недавнего ответа. - person user8188120; 10.05.2018

Несколько недель назад я столкнулся с той же проблемой, что и @ user8188120. Использование ответа @Iguananaut отлично работает при чтении меток из структуры папок. Если кто-то наткнется на это и ему нужно прочитать из файла csv, это также может сработать:

labels = []
transform = transforms.Compose([
    # here go your transforms
    ])


class MyFitsDataset(data.Dataset):
    def __init__(self, csv_path):
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # the rest contain the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1:])  # for multi-label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])  # for single-label
        labels.append(self.label_arr)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        single_image_name = self.image_arr[index]

        data = pyfits.open(single_image_name, axes=2)
        data = data[0].data.astype('float32')
        data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)

        img = transform(data)

        # Get label(class) of the image based on the pandas column
        single_image_label = self.label_arr[index]

        return (img, single_image_label)

    def __len__(self):
        return self.data_len

Это также позволяет избежать использования класса DatasetFolder, который до сих пор недоступен в новейшей версии PyTorch. Я надеюсь, что это помогает кому-то.

person sara    schedule 03.07.2018