anindya-hf-2002 committed
Commit 634fc83
1 Parent(s): 8e3c016

upload 3 files

Files changed (3)
  1. src/dataset.py +65 -0
  2. src/generate_images.py +96 -0
  3. src/train.py +124 -0
src/dataset.py ADDED
@@ -0,0 +1,65 @@
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ import os
+
+
+ class ClassifierDataset(Dataset):
+     """Labelled dataset for the classifier: images live in '0' and '1' subfolders of root_dir."""
+
+     def __init__(self, root_dir, transform=None):
+         self.root_dir = root_dir
+         self.transform = transform
+
+         self.classes = ['0', '1']
+         self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
+
+         self.samples = self._make_dataset()
+
+     def _make_dataset(self):
+         # Build a list of (image_path, label) pairs from the class subfolders.
+         samples = []
+         for class_name in self.classes:
+             class_dir = os.path.join(self.root_dir, class_name)
+             for img_name in os.listdir(class_dir):
+                 img_path = os.path.join(class_dir, img_name)
+                 samples.append((img_path, self.class_to_idx[class_name]))
+         return samples
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         img_path, label = self.samples[idx]
+         img = Image.open(img_path).convert('L')  # Convert to grayscale
+         if self.transform:
+             img = self.transform(img)
+         return img, label
+
+
+ class CustomDataset(Dataset):
+     """Dataset that yields one negative ('0') and one positive ('1') image per index, paired by directory-listing order."""
+
+     def __init__(self, root_dir, train_N, train_P, img_res):
+         self.root_dir = root_dir
+         self.train_N = train_N
+         self.train_P = train_P
+         self.img_res = img_res
+         self.transforms = transforms.Compose([
+             transforms.Resize(img_res),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.5], std=[0.5])  # Grayscale images, scaled to [-1, 1]
+         ])
+
+     def __len__(self):
+         # The smaller of the two class folders bounds the number of pairs.
+         return min(len(os.listdir(os.path.join(self.root_dir, self.train_N))),
+                    len(os.listdir(os.path.join(self.root_dir, self.train_P))))
+
+     def __getitem__(self, idx):
+         # Pair images by their position in the two directory listings.
+         normal_path = os.path.join(self.root_dir, self.train_N, os.listdir(os.path.join(self.root_dir, self.train_N))[idx])
+         pneumo_path = os.path.join(self.root_dir, self.train_P, os.listdir(os.path.join(self.root_dir, self.train_P))[idx])
+
+         normal_img = Image.open(normal_path).convert("L")  # Load as grayscale
+         pneumo_img = Image.open(pneumo_path).convert("L")  # Load as grayscale
+
+         normal_img = self.transforms(normal_img)
+         pneumo_img = self.transforms(pneumo_img)
+
+         return normal_img, pneumo_img
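For context, a minimal sketch (not part of the commit) of how these two datasets might be wired into DataLoaders. The '0'/'1' subfolder layout and the normalization constants come from the files in this commit, and the rsna-pneumonia path appears in src/generate_images.py below; the resolutions, batch sizes, and worker counts are illustrative assumptions.

    from torch.utils.data import DataLoader
    from torchvision import transforms

    from src.dataset import ClassifierDataset, CustomDataset

    # Transform matching the grayscale normalization used in src/train.py.
    transform = transforms.Compose([
        transforms.Resize((512, 512)),          # resolution is an illustrative choice
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])

    clf_ds = ClassifierDataset(root_dir="data/rsna-pneumonia-dataset/train", transform=transform)
    clf_loader = DataLoader(clf_ds, batch_size=16, shuffle=True)   # yields (image, label) batches

    gan_ds = CustomDataset(root_dir="data/rsna-pneumonia-dataset/train",
                           train_N="0", train_P="1", img_res=(256, 256))
    gan_loader = DataLoader(gan_ds, batch_size=4, shuffle=False)   # yields (normal, pneumonia) pairs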
src/generate_images.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import torch
+ from PIL import Image
+ import numpy as np
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ from src.models import ResUNetGenerator
+
+
+ # Custom dataset that returns each image together with its path,
+ # so translated outputs can be saved under the original filename.
+ class ImageDataset(Dataset):
+     def __init__(self, image_paths, transform=None):
+         self.image_paths = image_paths
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.image_paths)
+
+     def __getitem__(self, idx):
+         img_path = self.image_paths[idx]
+         image = Image.open(img_path).convert('L')
+         if self.transform:
+             image = self.transform(image)
+         return image, img_path
+
+
+ # Save a single image tensor to disk as an 8-bit grayscale (or RGB) image.
+ def save_image(tensor, path):
+     if tensor.is_cuda:
+         tensor = tensor.cpu()
+
+     array = tensor.permute(1, 2, 0).detach().numpy()
+     array = (array * 0.5 + 0.5) * 255            # undo mean=0.5/std=0.5 normalization, scale to [0, 255]
+     array = np.clip(array, 0, 255).astype(np.uint8)
+     if array.shape[2] == 1:
+         array = array.squeeze(2)
+         image = Image.fromarray(array, mode='L')
+     else:
+         image = Image.fromarray(array)
+     image.save(path)
+
+
+ # Load a generator checkpoint and switch it to eval mode.
+ def load_model(checkpoint_path, model_class, device):
+     model = model_class().to(device)
+     model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
+     model.eval()
+     return model
+
+
+ def generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint, output_dir='data/translated_images', batch_size=16):
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load models
+     g_NP = load_model(g_NP_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
+     g_PN = load_model(g_PN_checkpoint, lambda: ResUNetGenerator(gf=32, channels=1), device)
+
+     # Create output directories
+     os.makedirs(os.path.join(output_dir, '0'), exist_ok=True)
+     os.makedirs(os.path.join(output_dir, '1'), exist_ok=True)
+
+     # Collect image paths
+     image_paths_0 = [os.path.join(image_folder, '0', fname) for fname in os.listdir(os.path.join(image_folder, '0')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
+     image_paths_1 = [os.path.join(image_folder, '1', fname) for fname in os.listdir(os.path.join(image_folder, '1')) if fname.endswith(('.png', '.jpg', '.jpeg'))]
+
+     # Prepare datasets and dataloaders
+     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485], std=[0.229])])
+     dataset_0 = ImageDataset(image_paths_0, transform)
+     dataset_1 = ImageDataset(image_paths_1, transform)
+     dataloader_0 = DataLoader(dataset_0, batch_size=batch_size, shuffle=False)
+     dataloader_1 = DataLoader(dataset_1, batch_size=batch_size, shuffle=False)
+
+     with torch.no_grad():
+         # Process images from negative (0) to positive (1)
+         for batch, paths in tqdm(dataloader_0, desc="Converting N to P: "):
+             batch = batch.to(device)
+             translated_images = g_NP(batch)
+             translated_images = g_PN(translated_images)
+             for img, path in zip(translated_images, paths):
+                 save_path = os.path.join(output_dir, '1', os.path.basename(path))
+                 save_image(img, save_path)
+
+         # Process images from positive (1) to negative (0)
+         for batch, paths in tqdm(dataloader_1, desc="Converting P to N: "):
+             batch = batch.to(device)
+             translated_images = g_PN(batch)
+             translated_images = g_NP(translated_images)
+             for img, path in zip(translated_images, paths):
+                 save_path = os.path.join(output_dir, '0', os.path.basename(path))
+                 save_image(img, save_path)
+
+
+ if __name__ == '__main__':
+     image_folder = r'data\rsna-pneumonia-dataset\train'
+     g_NP_checkpoint = r'models\g_NP_best.ckpt'
+     g_PN_checkpoint = r'models\g_PN_best.ckpt'
+
+     generate_images(image_folder, g_NP_checkpoint, g_PN_checkpoint)
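Note that the script imports `src.models`, so it would presumably be launched from the repository root as a module, e.g. `python -m src.generate_images`; this invocation is an assumption based on the package-style imports, since no launcher script is included in this commit.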
src/train.py ADDED
@@ -0,0 +1,124 @@
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
+ from lightning.pytorch.loggers.wandb import WandbLogger
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ import lightning as pl
+ import wandb
+
+ from src.dataset import ClassifierDataset, CustomDataset
+ from src.classifier import Classifier
+ from src.models import CycleGAN
+ from src.config import CFG
+
+
+ def train_classifier(image_size,
+                      batch_size,
+                      epochs,
+                      resume_ckpt_path,
+                      train_dir,
+                      val_dir,
+                      checkpoint_dir,
+                      project,
+                      job_name):
+
+     clf_wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
+
+     transform = transforms.Compose([
+         transforms.Resize((image_size, image_size)),  # Resize to image_size x image_size
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485], std=[0.229])  # Normalize grayscale image
+     ])
+
+     # Define dataset paths
+     # train_dir = "/kaggle/working/CycleGan-CFE/train-data/train"
+     # val_dir = "/kaggle/working/CycleGan-CFE/train-data/val"
+
+     # Create datasets
+     train_dataset = ClassifierDataset(root_dir=train_dir, transform=transform)
+     val_dataset = ClassifierDataset(root_dir=val_dir, transform=transform)
+     print("Total Training Images:", len(train_dataset))
+     print("Total Validation Images:", len(val_dataset))
+
+     # Create data loaders
+     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
+     val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)
+
+     # Instantiate the classifier model
+     clf = Classifier(transfer=True)
+
+     # Keep the three best checkpoints by validation loss
+     checkpoint_callback = ModelCheckpoint(
+         monitor='val_loss',
+         dirpath=checkpoint_dir,
+         filename='efficientnet_b2-epoch{epoch:02d}-val_loss{val_loss:.2f}',
+         auto_insert_metric_name=False,
+         save_weights_only=False,
+         save_top_k=3,
+         mode='min'
+     )
+
+     # Set up the PyTorch Lightning Trainer with mixed precision and gradient accumulation
+     trainer = pl.Trainer(
+         devices="auto",
+         precision="16-mixed",
+         accelerator="auto",
+         max_epochs=epochs,
+         accumulate_grad_batches=10,
+         log_every_n_steps=1,
+         check_val_every_n_epoch=1,
+         benchmark=True,
+         logger=clf_wandb_logger,
+         callbacks=[checkpoint_callback],
+     )
+
+     # Train the classifier
+     trainer.fit(clf, train_loader, val_loader, ckpt_path=resume_ckpt_path)
+     wandb.finish()
+
+
+ def train_cyclegan(image_size,
+                    batch_size,
+                    epochs,
+                    classifier_path,
+                    resume_ckpt_path,
+                    train_dir,
+                    val_dir,
+                    test_dir,
+                    checkpoint_dir,
+                    project,
+                    job_name,
+                    ):
+
+     # Held-out test images from both classes, paired by index for evaluation during training
+     testdata_dir = test_dir
+     train_N = "0"
+     train_P = "1"
+     img_res = (image_size, image_size)
+
+     test_dataset = CustomDataset(root_dir=testdata_dir, train_N=train_N, train_P=train_P, img_res=img_res)
+     test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+
+     wandb_logger = WandbLogger(project=project, name=job_name, log_model="all")
+     print(classifier_path)
+     cyclegan = CycleGAN(train_dir=train_dir, val_dir=val_dir, test_dataloader=test_dataloader, classifier_path=classifier_path, checkpoint_dir=checkpoint_dir, gf=CFG.GAN_FILTERS, df=CFG.DIS_FILTERS)
+
+     # Keep the three best checkpoints by validation generator loss, plus the last epoch
+     gan_checkpoint_callback = ModelCheckpoint(dirpath=checkpoint_dir,
+                                               filename='cyclegan-epoch_{epoch}-vloss_{val_generator_loss:.2f}',
+                                               monitor='val_generator_loss',
+                                               save_top_k=3,
+                                               save_last=True,
+                                               save_weights_only=False,
+                                               verbose=True,
+                                               mode='min')
+
+     # Create the trainer
+     trainer = pl.Trainer(
+         accelerator="auto",
+         precision="16-mixed",
+         max_epochs=epochs,
+         log_every_n_steps=1,
+         benchmark=True,
+         devices="auto",
+         logger=wandb_logger,
+         callbacks=[gan_checkpoint_callback]
+     )
+
+     # Train the CycleGAN model (its train/val dataloaders are expected to be defined inside the CycleGAN module)
+     trainer.fit(cyclegan, ckpt_path=resume_ckpt_path)
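For reference, a minimal sketch (not included in this commit) of how these two entry points might be invoked. Every argument value below is an illustrative placeholder; the real paths, sizes, and project names would come from src/config.CFG or a launcher script outside this diff.

    from src.train import train_classifier, train_cyclegan

    # Hypothetical invocation of the classifier training run.
    train_classifier(image_size=512, batch_size=16, epochs=20,
                     resume_ckpt_path=None,
                     train_dir="train-data/train", val_dir="train-data/val",
                     checkpoint_dir="models/classifier",
                     project="cyclegan-cfe", job_name="classifier-baseline")

    # Hypothetical invocation of the CycleGAN training run, reusing the trained classifier.
    train_cyclegan(image_size=256, batch_size=4, epochs=100,
                   classifier_path="models/classifier/efficientnet_b2_best.ckpt",
                   resume_ckpt_path=None,
                   train_dir="train-data/train", val_dir="train-data/val",
                   test_dir="train-data/test",
                   checkpoint_dir="models/cyclegan",
                   project="cyclegan-cfe", job_name="cyclegan-run")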