NSFW画像分類Pythonスクリプト

プログラム

import os
import shutil
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import torch.nn.functional as F

# Source folder scanned for images, and the two folders images are sorted into.
INPUT_DIR = r"Y:\output_folder1"
NSFW_DIR = r"Y:\output_folder5"
SAFE_DIR = r"Y:\output_folder6"

# Make sure both destination folders exist; an already-existing folder is fine.
for _dest_dir in (NSFW_DIR, SAFE_DIR):
    os.makedirs(_dest_dir, exist_ok=True)
del _dest_dir

# NSFW probability threshold (0 to 1): scores at or above it are classified as NSFW.
NSFW_THRESHOLD = 0.70

# Prefer the GPU when CUDA is available; otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"使用デバイス: {device}")

# Hugging Face Hub checkpoint shared by the preprocessor and the classifier.
_MODEL_ID = "Falconsai/nsfw_image_detection"
processor = AutoImageProcessor.from_pretrained(_MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(_MODEL_ID)
model = model.to(device)

def classify_image(image_path, threshold=None):
    """Classify one image file as NSFW or normal.

    Args:
        image_path: Path of the image file to classify.
        threshold: NSFW probability (0 to 1) at or above which the image is
            labelled "nsfw". ``None`` (the default) uses the module-level
            NSFW_THRESHOLD.

    Returns:
        A tuple ``(label, nsfw_score, normal_score)``. ``label`` is "nsfw" or
        "normal"; the scores are the model's softmax probabilities for those
        two classes, as Python floats.

    Raises:
        Errors from PIL for a missing or unreadable file (e.g.
        ``FileNotFoundError``, ``PIL.UnidentifiedImageError``), and
        ``KeyError`` if the model config has no "nsfw" / "normal" label.
    """
    if threshold is None:
        threshold = NSFW_THRESHOLD

    # Close the source file deterministically: Pillow keeps multi-frame files
    # (e.g. GIF) open after loading, and the caller moves the file right after.
    # convert() returns an independent copy, so it stays usable after the with.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # One image per batch: take row 0 of the per-class probabilities.
    probs = F.softmax(logits, dim=-1)[0]

    # Resolve class indices through the checkpoint's own label mapping instead
    # of hard-coding 0/1.
    label2id = model.config.label2id
    nsfw_score = probs[int(label2id["nsfw"])].item()
    normal_score = probs[int(label2id["normal"])].item()

    label = "nsfw" if nsfw_score >= threshold else "normal"
    return label, nsfw_score, normal_score


def main():
    """Classify every image in INPUT_DIR and move it to NSFW_DIR or SAFE_DIR.

    Files are processed in name order. Entries without an image extension and
    anything that is not a regular file are skipped. If one file fails (an
    unreadable image, a model error, or a same-named file already present at
    the destination), the error is printed and the file stays in INPUT_DIR.
    Processing then continues with the next file.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp")

    print(f"NSFW 閾値: {NSFW_THRESHOLD}")
    print("--------------------------------------------------")

    # os.listdir order is arbitrary; sorting makes runs and logs reproducible.
    for filename in sorted(os.listdir(INPUT_DIR)):
        if not filename.lower().endswith(image_exts):
            continue

        src_path = os.path.join(INPUT_DIR, filename)
        # A sub-directory named like "x.jpg" is not an image to classify.
        if not os.path.isfile(src_path):
            continue

        print(f"\n=== {filename} ===")
        try:
            label, nsfw_score, normal_score = classify_image(src_path)

            print(f"normal: {normal_score:.4f}")
            print(f"nsfw  : {nsfw_score:.4f}")
            print(f"判定結果: {label.upper()}")

            if label == "nsfw":
                dst_dir, category = NSFW_DIR, "NSFW"
            else:
                dst_dir, category = SAFE_DIR, "SAFE"
            dst_path = os.path.join(dst_dir, filename)
            # Print the configured folder rather than a hard-coded name, so
            # the message stays correct if NSFW_DIR / SAFE_DIR are changed.
            print(f"→ 移動先: {dst_dir} ({category})")

            # shutil.move may silently overwrite an existing destination file
            # (POSIX rename, or the copy fallback across devices) but raises on
            # a same-drive Windows rename; refuse consistently so no file is lost.
            if os.path.exists(dst_path):
                raise FileExistsError(f"移動先に同名のファイルが既に存在します: {dst_path}")

            shutil.move(src_path, dst_path)
            print("移動完了")

        # Per-file boundary of the batch: report the failure and keep going.
        except Exception as e:
            print(f"⚠️ エラー発生: {filename}")
            print(f"内容: {e}")

        print("--------------------------------------------------")

    print("\n=== 全処理完了 ===")


# Run the batch classification (main) only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

Python,画像

Posted by eightban