본문 바로가기
컴퓨터비전(CV)

3. YOLO를 활용한 안전모 탐지

by 곽정우 2024. 7. 14.

!pip install ultralytics
import glob
import os
import random
import shutil
import xml.etree.ElementTree as ET

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ultralytics
import yaml
from torchvision import transforms
from tqdm import tqdm
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator
ultralytics.checks()

!kaggle datasets download -d andrewmvd/hard-hat-detection
!unzip -q /content/hard-hat-detection.zip
# Project root; the raw dataset (images/ + annotations/) lives in <data_root>/data.
data_root = '/content/helmet_detection'
file_root = f'{data_root}/data'
project_name = 'shd'

# One folder per split: <data_root>/shd/{train,valid,test}.
train_root, valid_root, test_root = (
    f'{data_root}/{project_name}/{split}' for split in ('train', 'valid', 'test')
)

# Each split gets images/ and labels/ subfolders (makedirs also creates the
# split folder itself when it is missing).
for split_root in (train_root, valid_root, test_root):
    for sub in ('images', 'labels'):
        target = f'{split_root}/{sub}'
        if not os.path.exists(target):
            os.makedirs(target)

# All raw PNG images to be converted and split.
file_list = glob.glob(f'{file_root}/images/*.png')
len(file_list)

def xml_to_yolo_bbox(bbox, w, h):
    """Convert a Pascal-VOC pixel box to the normalized YOLO format.

    Args:
        bbox: Indexable ``[xmin, ymin, xmax, ymax]`` in pixels.
        w: Image width in pixels.
        h: Image height in pixels.

    Returns:
        ``[x_center, y_center, width, height]``, each divided by the
        corresponding image dimension.
    """
    x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
    return [
        (x_max + x_min) / 2 / w,
        (y_max + y_min) / 2 / h,
        (x_max - x_min) / w,
        (y_max - y_min) / h,
    ]
# Convert every Pascal-VOC XML annotation into a YOLO label file in
# <file_root>/labels. Class indices are assigned in first-seen order and
# recorded in `classes`; the dataset YAML must list names in this same order.
labels_dir = f'{file_root}/labels'
os.makedirs(labels_dir, exist_ok=True)  # label files below are written here

classes = []
for file in tqdm(file_list):
    # Derive names from the image stem; str.replace('png', ...) would also
    # rewrite any "png" occurring elsewhere in the file name.
    stem = os.path.splitext(os.path.basename(file))[0]
    file_path = f'{file_root}/annotations/{stem}.xml'
    save_path = f'{labels_dir}/{stem}.txt'

    root = ET.parse(file_path).getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    result = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        if label not in classes:
            classes.append(label)
        index = classes.index(label)
        # Look coordinates up by tag rather than relying on child order;
        # float() also accepts fractional pixel coordinates.
        bndbox = obj.find('bndbox')
        voc_bbox = [float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        yolo_bbox = xml_to_yolo_bbox(voc_bbox, width, height)
        result.append(f'{index} ' + ' '.join(str(v) for v in yolo_bbox))

    # Always write a label file: an empty one marks a background image, and
    # the split step copies one label per image.
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(result))

classes

# Class names as listed by the author. NOTE: label indices were written in
# `classes` (first-seen) order, which is not necessarily this order.
class_list = ['head', 'helmet', 'person']

# Reproducible split: first 10% -> test, next 10% -> valid, remainder -> train.
random.seed(2024)
random.shuffle(file_list)
test_ratio = 0.1
num_file = len(file_list)
num_holdout = int(num_file * test_ratio)

test_list = file_list[:num_holdout]
valid_list = file_list[num_holdout:num_holdout * 2]
train_list = file_list[num_holdout * 2:]


def _copy_split(image_paths, split_root):
    """Copy each image and its YOLO label (if present) into split_root/{images,labels}."""
    for image_path in image_paths:
        img_name = os.path.basename(image_path)
        label_name = os.path.splitext(img_name)[0] + '.txt'
        label_path = f'{file_root}/labels/{label_name}'
        # A missing label means "no objects"; YOLO treats such images as
        # background, so copy the image anyway instead of crashing.
        if os.path.exists(label_path):
            shutil.copyfile(label_path, f'{split_root}/labels/{label_name}')
        shutil.copyfile(image_path, f'{split_root}/images/{img_name}')


for split_files, split_root in (
    (test_list, test_root),
    (valid_list, valid_root),
    (train_list, train_root),
):
    _copy_split(split_files, split_root)
project_root = '/content/helmet_detection'

# Dataset description passed to `YOLO.train(data=...)`. `names` must list the
# classes in the index order used when the label files were written, which is
# the order of `classes` built during annotation conversion.
data = {
    'train': train_root,
    'val': valid_root,
    'test': test_root,
    'nc': len(classes),
    'names': list(classes),
}

with open(f'{project_root}/safety_helmet.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(data, f)
%cd /content/helmet_detection
model = YOLO('yolov8n.pt')
results = model.train(data='safety_helmet.yaml', epochs=2, batch=8, imgsz=224, device=0, workers=4, amp=False, patience=30, name='safety_n')

model = YOLO('yolov8s.pt')
results = model.train(data='safety_helmet.yaml', epochs=2, batch=8, imgsz=224, device=0, workers=4, amp=False, patience=30, name='safety_s')

model = YOLO('yolov8m.pt')
results = model.train(data='safety_helmet.yaml', epochs=2, batch=8, imgsz=224, device=0, workers=4, amp=False, patience=30, name='safety_m')

%cd /content/helmet_detection

result_folder = f'{project_root}/runs/detect'
model = YOLO(f'{result_folder}/safety_n/weights/best.pt')
metrics = model.val(split='test')
print('map50-95', metrics.box.map)
print('map50', metrics.box.map)

model = YOLO(f'{result_folder}/safety_s/weights/best.pt')
metrics = model.val(split='test')
print('map50-95', metrics.box.map)
print('map50', metrics.box.map)

model = YOLO(f'{result_folder}/safety_m/weights/best.pt')
metrics = model.val(split='test')
print('map50-95', metrics.box.map)
print('map50', metrics.box.map)

# Longer (50-epoch) training of the nano model; patience=30 is Ultralytics'
# early-stopping patience in epochs.
model = YOLO('yolov8n.pt')
results = model.train(data='safety_helmet.yaml', epochs=50, batch=8, imgsz=224, device=0, workers=4, amp=False, patience=30, name='safety')

# NOTE(review): if a run named `safety` already exists, Ultralytics saves to an
# incremented folder (e.g. `safety2`) and this path would load the older
# weights — confirm, or train with exist_ok=True.
model = YOLO(f'{result_folder}/safety/weights/best.pt')
metrics = model.val(split='test')
print('map50-95', metrics.box.map)
print('map50', metrics.box.map50)  # `map` is mAP50-95; `map50` is mAP@0.5

# Locate the test split again (same values as the setup cell).
data_root = '/content/helmet_detection'
project_name = 'shd'
test_root = f'{data_root}/{project_name}/test'

# Test images in random order so the previews below vary.
test_file_list = glob.glob(f'{test_root}/images/*')
random.shuffle(test_file_list)

# Resize + tensor transform; not referenced by the prediction cells shown.
IMG_SIZE = (224, 224)
test_data_transform = transforms.Compose([transforms.Resize(IMG_SIZE), transforms.ToTensor()])
model.names

# One random colour per class id.
color_dict = {}
for class_id in range(len(model.names)):
    color_dict[class_id] = tuple(random.randint(0, 255) for _ in range(3))
color_dict

# Run the detector on one test image and display it.
test_img = cv2.imread(test_file_list[3])  # OpenCV loads BGR
img_src = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)  # RGB copy for matplotlib
# Ultralytics interprets numpy inputs as BGR (OpenCV order), so predict on the
# image as loaded rather than on the RGB conversion.
result = model(test_img)

result[0].boxes

plt.imshow(img_src)
plt.show()

# Fixed per-class colours; boxes are drawn on the RGB image shown by matplotlib.
color_dict = {
    0: (255, 0, 0),
    1: (0, 255, 0),
    2: (0, 0, 255),
}

num_head = 0
test_img = cv2.imread(test_file_list[1])  # BGR
img_src = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
# Ultralytics expects BGR numpy input; annotate the RGB copy for display.
result = model(test_img)[0]

annotator = Annotator(img_src)
for box in result.boxes:
    cls_id = int(box.cls)
    cls_name = model.names[cls_id]
    # A detected bare 'head' is what triggers the warning banner.
    if cls_name == 'head':
        num_head += 1
    annotator.box_label(box.xyxy[0], cls_name, color_dict.get(cls_id, (255, 255, 255)))
img_src = annotator.result()

if num_head > 0:
    cv2.rectangle(img_src, (0, 0), (300, 50), (255, 0, 0), -1, cv2.LINE_AA)
    cv2.putText(img_src, 'No Helmet!', (5, 30), cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), thickness=3, lineType=cv2.LINE_AA)
plt.imshow(img_src)
plt.show()

# 5 x 4 preview grid of (up to) 20 test images, with a 'No Helmet!' banner on
# every image where a bare head was detected.
num_preview = min(20, len(test_file_list))  # avoid IndexError on small test sets
plt.figure(figsize=(20, 16))

for idx in range(num_preview):
    num_head = 0
    test_img = cv2.imread(test_file_list[idx])  # BGR
    img_src = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)
    # Ultralytics expects BGR numpy input; annotate the RGB copy for display.
    result = model(test_img)[0]

    annotator = Annotator(img_src)
    for box in result.boxes:
        cls_id = int(box.cls)
        cls_name = model.names[cls_id]
        if cls_name == 'head':
            num_head += 1
        annotator.box_label(box.xyxy[0], cls_name, color_dict.get(cls_id, (255, 255, 255)))
    img_src = annotator.result()

    if num_head > 0:
        cv2.rectangle(img_src, (0, 0), (300, 50), (255, 0, 0), -1, cv2.LINE_AA)
        cv2.putText(img_src, 'No Helmet!', (5, 30), cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), thickness=3, lineType=cv2.LINE_AA)

    plt.subplot(5, 4, idx + 1)
    plt.imshow(img_src)
plt.show()
len(test_file_list)

'컴퓨터비전(CV)' 카테고리의 다른 글

4. YOLO v8을 이용한 이상행동 탐지  (0) 2024.07.14
2. YOLOv8를 활용한 폐 질환 분류  (0) 2024.07.14
1. YOLO  (0) 2024.07.14
8. OpenCV7  (0) 2024.07.14
7. OpenCV6  (0) 2024.07.14