Menu



Manage

Cord > Project_AI이미지 처리 전체 다운로드
Project_AI이미지 처리 > makenpz/dataset1.py Lines 125 | 5.7 KB
다운로드

                        import subprocess
import sys

try:
    import numpy as np
except ImportError:
    print("NumPy 라이브러리가 설치되어 있지 않습니다. 설치를 시작합니다.")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy"])
    import numpy as np
    
try:
    from PIL import Image
except ImportError:
    print("Pillow 라이브러리가 설치되어 있지 않습니다. 설치를 시작합니다.")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "Pillow"])
    from PIL import Image

try:
    from sklearn.model_selection import train_test_split
except ImportError:
    print("scikit-learn 라이브러리가 설치되어 있지 않습니다. 설치를 시작합니다.")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"])
    from sklearn.model_selection import train_test_split

import os
import random
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

print("설치 작업이 끝났습니다.")

data_dir = r"C:\Users\remil\바탕 화면\productive\dataset"

print("3000개 리스트를 생성성합니다.")

#무작위의 파일 리스트 생성
image_files = []
mask_files = []

# clip 폴더 순회
for date_dir in os.listdir(os.path.join(data_dir, "clip")):
    if os.path.isdir(os.path.join(data_dir, "clip", date_dir)):
        for num_dir in os.listdir(os.path.join(data_dir, "clip", date_dir)):
            if os.path.isdir(os.path.join(data_dir, "clip", date_dir, num_dir)) and num_dir.startswith("clip_"): # 번호 폴더 이름 확인
                for img_file in os.listdir(os.path.join(data_dir, "clip", date_dir, num_dir)):
                    if img_file.endswith(".jpg"):
                        image_path = os.path.join("clip", date_dir, num_dir, img_file)
                        mask_file = img_file.replace(".jpg", ".png")
                        mask_path = os.path.join(data_dir, "matting", date_dir, num_dir.replace("clip_", "matting_"), mask_file) # 마스크 폴더 이름 변경
                        if os.path.exists(os.path.join(data_dir, mask_path)):
                            image_files.append(image_path)
                            mask_files.append(os.path.join("matting", date_dir, num_dir.replace("clip_", "matting_"), mask_file))

# matting_human_half 폴더 순회
for folder in ["clip_img", "matting"]:
    if folder == "clip_img":
        mask_folder = "matting"
    else :
        mask_folder = "clip_img"
    for date_dir in os.listdir(os.path.join(data_dir, "matting_human_half", folder)):
        if os.path.isdir(os.path.join(data_dir, "matting_human_half", folder, date_dir)):
            for num_dir in os.listdir(os.path.join(data_dir, "matting_human_half", folder, date_dir)):
                if os.path.isdir(os.path.join(data_dir, "matting_human_half", folder, date_dir, num_dir)) and num_dir.startswith(folder + "_"): # 번호 폴더 이름 확인
                    for img_file in os.listdir(os.path.join(data_dir, "matting_human_half", folder, date_dir, num_dir)):
                        if img_file.endswith(".jpg"):
                            image_path = os.path.join("matting_human_half", folder, date_dir, num_dir, img_file)
                            mask_file = img_file.replace(".jpg", ".png")
                            mask_path = os.path.join(data_dir, "matting_human_half", "matting", date_dir, num_dir.replace("clip_img_", "matting_"), mask_file) if folder == "clip_img" else os.path.join(data_dir, "matting_human_half", "clip_img", date_dir, num_dir.replace("matting_", "clip_img_"), mask_file)
                            if os.path.exists(os.path.join(data_dir, mask_path)):
                                image_files.append(image_path)
                                mask_files.append(os.path.join("matting_human_half", "matting", date_dir, num_dir.replace("clip_img_", "matting_"), mask_file)) if folder == "clip_img" else mask_files.append(os.path.join("matting_human_half", "clip_img", date_dir, num_dir.replace("matting_", "clip_img_"), mask_file))

print("3000개를 추출합니다.")

# 해당 3000장 추출
num_images = 3000
if num_images > len(image_files):
    print(f"오류: 추출하려는 이미지 개수({num_images})가 실제 이미지 개수({len(image_files)})보다 큽니다.")
    num_images = len(image_files) # num_images를 실제 이미지 개수로 수정
    print(f"추출할 이미지 개수를 {num_images}로 수정합니다.")

random_indices = random.sample(range(len(image_files)), num_images)
selected_image_files = [image_files[i] for i in random_indices]
selected_mask_files = [mask_files[i] for i in random_indices]

print("전처리를 시작합니다.")

# 데이터 로드와 전처리
def load_and_preprocess_image(filepath):
    img = Image.open(os.path.join(data_dir, filepath)).convert("RGB")
    img_array = np.array(img)
    return img_array

images = []
masks = []
for image_file, mask_file in zip(selected_image_files, selected_mask_files):
    image = load_and_preprocess_image(image_file)
    mask = load_and_preprocess_image(mask_file)

    images.append(image)
    masks.append(mask)

images = np.array(images)
masks = np.array(masks)

print("train과 test를 저장합니다.")

# train /test 분리
train_images, test_images, train_masks, test_masks = train_test_split(images, masks, test_size=500, random_state=42)

print("npz파일로 저장합니다.")

#npz로 저장
output_file = "aisegment_3000_train_test.npz"
np.savez(output_file, train_images=train_images, train_masks=train_masks, test_images=test_images, test_masks=test_masks)

print(f"{output_file} 파일에 train({len(train_images)}) 및 test({len(test_images)}) 데이터가 저장되었습니다.")