본문 바로가기
컴퓨터비전(CV)

2. YOLOv8을 활용한 폐 질환 분류

by 곽정우 2024. 7. 14.

!pip install ultralytics
import os
import random
import shutil
import cv2
import glob
import yaml
import matplotlib.pyplot as plt
import ultralytics
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from ultralytics import YOLO
# Run Ultralytics' built-in environment check (prints version/hardware info).
ultralytics.checks()

# Seed Python's `random` module, which drives the shuffles used for the
# train/valid/test split and the preview sample below.
random.seed(2024)
# Download the dataset archive from Kaggle.
# NOTE(review): presumably requires Kaggle API credentials to be configured.
!kaggle datasets download -d hamdallak/the-iqothnccd-lung-cancer-dataset

# Extract quietly into the current directory; the archive is expected to
# produce the folder used as `data_root` below — confirm after unzipping.
!unzip -q /content/the-iqothnccd-lung-cancer-dataset.zip
# Root of the extracted dataset and the folder that holds the class dirs.
data_root = '/content/The IQ-OTHNCCD lung cancer dataset'
file_root = f'{data_root}/data'
project_name = 'lung_cancer'

# Directories the split copies are organized into:
#   <data_root>/<project_name>/{train,valid,test}
train_file_root = f'{data_root}/{project_name}'
train_root, valid_root, test_root = (
    f'{train_file_root}/{split}' for split in ('train', 'valid', 'test')
)

데이터셋 폴더 안에 data 폴더를 수동으로 만들고, 3개의 클래스 디렉토리를 그 안으로 옮겼습니다. (아래 코드는 이 구조를 전제로 합니다.)

# The three class directories were moved into `data` (file_root) by hand.
# Collect the class names from the sub-directories of file_root:
# - sorted, because os.listdir order is arbitrary, while ImageFolder-style
#   classification loaders assign class indices in sorted name order; this
#   keeps cls_list (later written to the YAML `names`) aligned with them;
# - directories only, so stray files in file_root are not treated as classes.
cls_list = sorted(
    entry for entry in os.listdir(file_root)
    if os.path.isdir(os.path.join(file_root, entry))
)
cls_list

# Create <split>/<class> folders for every split.
# exist_ok=True replaces the exists()-then-makedirs() check, so re-running
# the cell is safe and there is no window between the check and the create.
for folder in [train_root, valid_root, test_root]:
    os.makedirs(folder, exist_ok=True)  # also created when cls_list is empty
    for cls in cls_list:
        os.makedirs(f'{folder}/{cls}', exist_ok=True)
# Split each class independently: the first int(n * test_ratio) shuffled files
# go to test, the next int(n * test_ratio) to valid, the remainder to train
# (roughly 10% / 10% / 80%). Files are copied, leaving file_root untouched.
test_ratio = 0.1
for cls in cls_list:
    src_dir = f'{file_root}/{cls}'
    # Sort before shuffling: os.listdir order depends on the filesystem, so
    # without this the seeded shuffle would not give a reproducible split.
    file_list = sorted(os.listdir(src_dir))
    random.shuffle(file_list)
    num_file = len(file_list)
    num_test = int(num_file * test_ratio)  # valid uses the same count
    test_list = file_list[:num_test]
    valid_list = file_list[num_test:num_test * 2]
    train_list = file_list[num_test * 2:]

    for split_root, split_files in (
        (test_root, test_list),
        (valid_root, valid_list),
        (train_root, train_list),
    ):
        for fname in split_files:
            shutil.copyfile(f'{src_dir}/{fname}', f'{split_root}/{cls}/{fname}')
# Collect every test image (<test_root>/<class>/<file>) in random order.
# Sorting first keeps the seeded shuffle reproducible (glob order follows
# the filesystem).
test_file_list = sorted(glob.glob(f'{test_root}/*/*'))
random.shuffle(test_file_list)

# Preview up to 10 images, each titled with its class (parent folder name).
plt.figure(figsize=(20, 10))
for i, test_img_path in enumerate(test_file_list[:10]):
    # glob already returns full paths, so no join with test_root is needed.
    ori_img = Image.open(test_img_path).convert('RGB')
    plt.subplot(2, 5, i + 1)
    plt.title(test_img_path.split('/')[-2])
    plt.imshow(ori_img)

plt.show()

# Folder that holds the train/valid/test splits. It is derived from the same
# constants as the split paths instead of repeating the literal path.
project_root = f'{data_root}/{project_name}'

# Dataset description: split folders, class count and class names.
# NOTE(review): training below passes the dataset directory (not this file)
# as `data=`, so this YAML is informational in this notebook.
data = {
    'train': train_root,
    'val': valid_root,
    'test': test_root,
    'nc': len(cls_list),
    'names': cls_list,
}

with open(f'{project_root}/lung_cancer.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(data, f)
# Change the working directory to the split dataset folder.
# NOTE(review): the results path used later ({project_root}/runs/classify/...)
# suggests Ultralytics writes its runs relative to this cwd — confirm.
%cd /content/The IQ-OTHNCCD lung cancer dataset/lung_cancer

# Load the pretrained `yolov8s-cls.pt` classification checkpoint to fine-tune.
model = YOLO('yolov8s-cls.pt')

# `data` is the directory containing the train/valid/test class folders
# (not the YAML written above). device=0 selects the first GPU; patience=30
# stops early after 30 epochs without validation improvement. Re-running with
# an existing `name` saves to an auto-numbered folder (e.g. lung_cancer_s2).
results = model.train(data=f'{data_root}/{project_name}', epochs=50, batch=8, device=0, patience=30, name='lung_cancer_s')

 

같은 이름(lung_cancer_s)으로 다시 학습하면 결과가 lung_cancer_s2처럼 번호가 붙은 폴더에 자동 저장됩니다. 그래서 폴더 이름을 수동으로 lung_cancer_s로 바꿨습니다.

# Reload the best checkpoint saved by the training run. The folder name is
# fixed to lung_cancer_s (the auto-numbered run folder was renamed by hand).
result_folder = '/'.join((project_root, 'runs', 'classify', 'lung_cancer_s'))
best_ckpt = os.path.join(result_folder, 'weights', 'best.pt')
model = YOLO(best_ckpt)
model

# Evaluate on the held-out test split.
metrics = model.val(split='test')
metrics

# NOTE(review): with only three classes, top-5 accuracy is presumably always
# 1.0 here, since every class falls inside the top 5.
for metric_name in ('top1', 'top5'):
    print(f'{metric_name} accuracy: ', getattr(metrics, metric_name))

# Classify one test image with a manual preprocessing pipeline and compare
# the prediction with the ground-truth class (its parent folder name).
# NOTE(review): images are resized to 512x512 here; confirm this matches the
# imgsz used for training, or pass the PIL image / file path to `model()`
# directly so Ultralytics applies its own classification preprocessing.
IMG_SIZE = (512, 512)
test_data_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # float tensor (C, H, W) scaled to [0, 1]
])
img = Image.open(test_file_list[0]).convert('RGB')
img_src = test_data_transform(img)
print(img_src.shape)
x_tensor = img_src.unsqueeze(0)  # add a batch dimension -> (1, C, H, W)
print(x_tensor.shape)

result = model(x_tensor)[0]

gt = test_file_list[0].split('/')[-2]
# probs.top1 is the index of the highest-probability class.
pt = model.names[result.probs.top1]
print(gt)
print(pt)

plt.figure(figsize=(3, 3))
plt.title(f'GT: {gt}, Predict: {pt}')
plt.imshow(np.array(img))
plt.show()

# Classify the first (up to) five test images with the same preprocessing
# pipeline and plot each with its ground-truth and predicted class.
plt.figure(figsize=(20, 5))

for idx, img_path in enumerate(test_file_list[:5]):
    img = Image.open(img_path).convert('RGB')
    img_src = test_data_transform(img)
    x_tensor = img_src.unsqueeze(0)  # add a batch dimension
    result = model(x_tensor)[0]
    gt = img_path.split('/')[-2]  # class = parent folder name
    pt = model.names[result.probs.top1]  # highest-probability class
    plt.subplot(1, 5, idx + 1)
    plt.title(f'GT: {gt}, Predict: {pt}')
    plt.imshow(np.array(img))
plt.show()

'컴퓨터비전(CV)' 카테고리의 다른 글

4. YOLO v8을 이용한 이상행동 탐지  (0) 2024.07.14
3. YOLO를 활요한 안전모 탐지  (0) 2024.07.14
1. YOLO  (0) 2024.07.14
8. OpenCV7  (0) 2024.07.14
7. OpenCV6  (0) 2024.07.14