Mixed-type Wafer Map Defect Dataset

import os
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import Model, layers

data_path = "../data/mixedtype-wafer-defect-datasets/Wafer_Map_Datasets.npz"

# np.load on an .npz archive returns an NpzFile that holds the file open and
# reads members lazily; the context manager closes it once both arrays have
# been materialized into memory.
with np.load(data_path) as data:
    images = data["arr_0"]  # wafer maps
    labels = data["arr_1"]  # multi-hot defect labels

Data Exploration

Labels

# --- Label array overview ---
print("Labels shape:", labels.shape)
print("Total entries:", len(labels))

# Distinct scalar values appearing anywhere in the label matrix
unique_vals = np.unique(labels)
print("Unique values:", unique_vals)
print("Number of unique values:", len(unique_vals))

# NaN sanity check
missing_count = np.isnan(labels).sum()
print("Missing values (NaN) count:", missing_count)

# Preview of the first five label rows
examples_df = pd.DataFrame(labels[0:5])
print("Example label rows:")
display(examples_df)

# Distinct label rows, i.e. distinct defect combinations
unique_rows = np.unique(labels, axis=0)
print("Number of unique label patterns:", len(unique_rows))
Labels shape: (38015, 8)
Total entries: 38015
Unique values: [0 1]
Number of unique values: 2
Missing values (NaN) count: 0
Example label rows:
0 1 2 3 4 5 6 7
0 1 0 1 0 0 0 1 0
1 1 0 1 0 0 0 1 0
2 1 0 1 0 0 0 1 0
3 1 0 1 0 0 0 1 0
4 1 0 1 0 0 0 1 0
Number of unique label patterns: 38
# Maps each 8-bit multi-hot label string (bit order: Center, Donut, Edge_Loc,
# Edge_Ring, Loc, Near_Full, Scratch, Random) to a human-readable name.
label_mapping = {
    "00000000": "Normal",
    "10000000": "Center",
    "01000000": "Donut",
    "00100000": "Edge_Loc",
    "00010000": "Edge_Ring",
    "00001000": "Loc",
    "00000100": "Near_Full",
    "00000010": "Scratch",
    "00000001": "Random",
    "10100000": "Center+Edge_Loc",
    "10010000": "Center+Edge_Ring",
    "10001000": "Center+Loc",
    "10000010": "Center+Scratch",
    "01100000": "Donut+Edge_Loc",
    "01010000": "Donut+Edge_Ring",
    "01001000": "Donut+Loc",
    "01000010": "Donut+Scratch",
    "00101000": "Edge_Loc+Loc",
    "00100010": "Edge_Loc+Scratch",
    "00011000": "Edge_Ring+Loc",
    "00010010": "Edge_Ring+Scratch",
    "00001010": "Loc+Scratch",
    "10101000": "Center+Edge_Loc+Loc",
    "10100010": "Center+Edge_Loc+Scratch",
    "10011000": "Center+Edge_Ring+Loc",
    "10010010": "Center+Edge_Ring+Scratch",
    "10001010": "Center+Loc+Scratch",
    "01101000": "Donut+Edge_Loc+Loc",
    "01100010": "Donut+Edge_Loc+Scratch",
    "01011000": "Donut+Edge_Ring+Loc",
    "01010010": "Donut+Edge_Ring+Scratch",
    "01001010": "Donut+Loc+Scratch",
    "00101010": "Edge_Loc+Loc+Scratch",
    "00011010": "Edge_Ring+Loc+Scratch",
    "10101010": "Center+Edge_Loc+Loc+Scratch",
    "10011010": "Center+Edge_Ring+Loc+Scratch",
    "01101010": "Donut+Edge_Loc+Loc+Scratch",
    "01011010": "Donut+Edge_Ring+Loc+Scratch",
}

# Encode every label row as a bit string such as "10100010", then name it.
label_str = ["".join(str(int(bit)) for bit in row) for row in labels]
label_names = [label_mapping.get(key, "Unknown") for key in label_str]

# How often each bit-string pattern occurs
label_counts = Counter(label_str)

# One row per pattern, ordered by its bit string
_pattern_rows = [
    {"OneHot": key, "Count": count, "LabelName": label_mapping.get(key, "Unknown")}
    for key, count in label_counts.items()
]
label_df = pd.DataFrame(_pattern_rows).sort_values("OneHot").reset_index(drop=True)

# Horizontal bar chart, most frequent pattern first
bar_order = label_df.sort_values("Count", ascending=False)["LabelName"]
sns.set_theme(style="whitegrid")
plt.figure(figsize=(12, 7))
sns.barplot(data=label_df, x="Count", y="LabelName", order=bar_order)
plt.xlabel("Frequency")
plt.ylabel("Defect Type")
plt.title("Distribution of Wafer Defect Types")
plt.tight_layout()
plt.show()

From the plot above, most defect patterns (including the Normal wafers) have 1,000 images each. The exceptions are Center+Edge_Loc+Scratch with 2,000 images, Random with roughly 870 images, and Near_Full, which is heavily under-represented with only about 150 images.

Images

# Basic summary of images array
print("Images shape:", images.shape)
print("Total entries:", images.shape[0])
print("Single image shape:", images[0].shape)
print("Dtype:", images.dtype)

# Missing data check
missing_count_images = np.isnan(images).sum()
print("Missing values (NaN) count:", missing_count_images)

# Unique values across entire images array
unique_vals_images = np.unique(images)
print("Unique values:", unique_vals_images)
print("Number of unique values:", unique_vals_images.size)

# Per-image flag: does the image contain at least one pixel with value 3?
# Computed once and reused by both statistics below.
has_value_3 = np.any(images == 3, axis=(1, 2))

# Count images with at least one pixel value of 3
images_with_3 = int(has_value_3.sum())
print(f"Images with at least one pixel value of 3: {images_with_3}")

# Among the images that contain value 3, the mean number of value-3 pixels
# per image. Guarded so a dataset with no value-3 pixels doesn't divide by 0.
images_with_3_pixels = images[has_value_3]
count_3_pixels = int(np.sum(images_with_3_pixels == 3))
average_3_pixels = count_3_pixels / images_with_3 if images_with_3 else 0.0
print(f"Average number of pixels with value 3 in those images: {average_3_pixels}")

# Show a few example images (as arrays)
sample_images = images[:3]
print("Sample image arrays:")
for i, img in enumerate(sample_images, start=1):
    print(f"Image {i} shape: {img.shape}")
    print(img)
Images shape: (38015, 52, 52)
Total entries: 38015
Single image shape: (52, 52)
Dtype: int32
Missing values (NaN) count: 0
Unique values: [0 1 2 3]
Number of unique values: 4
Images with at least one pixel value of 3: 105
Average number of pixels with value 3 in those images: 2.038095238095238
Sample image arrays:
Image 1 shape: (52, 52)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
Image 2 shape: (52, 52)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
Image 3 shape: (52, 52)
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
# Value 3 is rare (only a handful of images, ~2 pixels each, per the output
# above); treat it as blank (0) so every pixel is one of the states {0, 1, 2}.
# NOTE(review): the meaning of value 3 is not documented here — confirm that
# mapping it to blank is intended. Modifies `images` in place.
np.putmask(images, images == 3, 0)
def convert_to_rgb(image):
    """Render a 2-D wafer map of pixel states as an RGB ``uint8`` image.

    State 0 -> (0, 63, 92), 1 -> (255, 166, 0), 2 -> (188, 80, 144); any other
    value is left black (0, 0, 0).

    Args:
        image: 2-D integer array (height, width) of pixel states.

    Returns:
        ``uint8`` array of shape (height, width, 3).
    """
    palette = {
        0: (0, 63, 92),
        1: (255, 166, 0),
        2: (188, 80, 144),
    }
    rows, cols = image.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)
    for state, color in palette.items():
        canvas[image == state] = color
    return canvas


# Group label names by how many base defects they combine. The key is the
# number of "+" separators, so key 0 holds "Normal" and the single defects,
# key 1 holds two-defect mixtures, and so on.
unique_labels = label_df["LabelName"].tolist()

labels_by_defect_count = {0: [], 1: [], 2: [], 3: []}
for label_name in unique_labels:
    # setdefault tolerates names with more "+" than the pre-seeded keys.
    labels_by_defect_count.setdefault(label_name.count("+"), []).append(label_name)

# Index of the first image carrying each label name, built in one pass
# instead of calling label_names.index() once per label.
first_image_index = {}
for image_idx, image_label in enumerate(label_names):
    first_image_index.setdefault(image_label, image_idx)

# Plot each group separately
sns.set_theme(style="white")

for defect_count, labels_group in labels_by_defect_count.items():
    if not labels_group:
        continue

    n_labels = len(labels_group)
    n_cols = min(5, n_labels)
    n_rows = (n_labels + n_cols - 1) // n_cols

    # squeeze=False always yields a 2-D array of Axes, so flattening works
    # for every grid shape (1x1, 1xN, MxN).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, n_rows * 3), squeeze=False)
    axes = axes.flatten()

    for ax, label_name in zip(axes, labels_group):
        ax.imshow(convert_to_rgb(images[first_image_index[label_name]]))
        ax.set_title(label_name, fontsize=10, fontweight="bold", color="#333333")

    # Hide ticks and frames on every cell, including unused trailing cells.
    for ax in axes:
        ax.axis("off")

    fig.patch.set_facecolor("white")
    plt.tight_layout()

    # defect_count counts "+" separators, so this group mixes
    # defect_count + 1 base defects.
    if defect_count == 0:
        group_title = "Normal and Single Defect Types"
    else:
        group_title = f"Defect Types with {defect_count + 1} Combined Defects"
    fig.suptitle(group_title, fontsize=16, fontweight="bold", y=1.02)

    plt.show()

Step 1 — Data Preprocessing

Convert each (52, 52) image to (52, 52, 3) via one-hot encoding of the 3 pixel states:
- Channel 0: blank (pixel == 0)
- Channel 1: normal die (pixel == 1)
- Channel 2: broken die (pixel == 2)

This preserves the categorical nature of the data — these are discrete states, not intensity values.

# One-hot encode each pixel's categorical state: (N, 52, 52) -> (N, 52, 52, 3).
# Row k of the 3x3 identity matrix is the one-hot vector for state k, so
# fancy-indexing it with the integer image looks up every pixel at once.
# Channel 0 = blank (0), Channel 1 = normal (1), Channel 2 = broken (2).
images_onehot = np.eye(3, dtype=np.float32)[images.astype(int)]

for caption, value in (
    ("Original images shape:", images.shape),
    ("One-hot encoded shape:", images_onehot.shape),
    ("Labels shape:", labels.shape),
    ("Dtype:", images_onehot.dtype),
):
    print(caption, value)

# Verify the encoding: per-channel pixel counts for the first image
sample = images_onehot[0]
print("\nVerification — channel sums for first image:")
for ch, channel_sum in enumerate(sample.sum(axis=(0, 1))):
    print(f"  Channel {ch} (pixel={ch}): {channel_sum:.0f} pixels")
Original images shape: (38015, 52, 52)
One-hot encoded shape: (38015, 52, 52, 3)
Labels shape: (38015, 8)
Dtype: float32

Verification — channel sums for first image:
  Channel 0 (pixel=0): 643 pixels
  Channel 1 (pixel=1): 1441 pixels
  Channel 2 (pixel=2): 620 pixels

Step 2 — Train / Validation / Test Split

Since scikit-learn's train_test_split only supports single-label stratification, we use a label-powerset approach. Each 8-dim multi-hot label vector is converted to a string key, and we stratify on that key so that all 38 unique label patterns are represented in every split. This is plain stratification on label combinations, not iterative stratification.

Split ratio: 70% train / 15% validation / 15% test

# Stratification key: the full 8-bit label pattern as a string, so each
# distinct defect combination forms its own stratum.
label_str_arr = np.array(["".join(str(int(bit)) for bit in row) for row in labels])

# Split 1: 70% train / 30% held out
X_train, X_temp, y_train, y_temp, str_train, str_temp = train_test_split(
    images_onehot, labels, label_str_arr,
    test_size=0.30, random_state=42, stratify=label_str_arr,
)

# Split 2: halve the held-out part -> 15% validation / 15% test
X_val, X_test, y_val, y_test, str_val, str_test = train_test_split(
    X_temp, y_temp, str_temp,
    test_size=0.50, random_state=42, stratify=str_temp,
)

print(f"Train: {X_train.shape[0]} samples")
print(f"Val:   {X_val.shape[0]} samples")
print(f"Test:  {X_test.shape[0]} samples")

# Every split should contain all 38 label patterns
for split_name, split_keys in (("Train", str_train), ("Val", str_val), ("Test", str_test)):
    print(f"  {split_name} — unique patterns: {len(set(split_keys))}")

# Positive count of each base label in each split
base_labels = ["Center", "Donut", "Edge_Loc", "Edge_Ring", "Loc", "Near_Full", "Scratch", "Random"]
print("\nPer-label sample counts:")
header = f"{'Label':<12} {'Train':>6} {'Val':>6} {'Test':>6}"
print(header)
split_totals = [y_train.sum(axis=0), y_val.sum(axis=0), y_test.sum(axis=0)]
for i, name in enumerate(base_labels):
    tr, va, te = (totals[i] for totals in split_totals)
    print(f"{name:<12} {tr:>6.0f} {va:>6.0f} {te:>6.0f}")
Train: 26610 samples
Val:   5702 samples
Test:  5703 samples
  Train — unique patterns: 38
  Val — unique patterns: 38
  Test — unique patterns: 38

Per-label sample counts:
Label         Train    Val   Test
Center         9100   1950   1950
Donut          8400   1800   1800
Edge_Loc       9100   1950   1950
Edge_Ring      8400   1800   1800
Loc           12600   2700   2700
Near_Full       104     22     23
Scratch       13300   2850   2850
Random          606    130    130

Step 3 — Data Augmentation

Apply geometrically valid augmentations (random 90° rotations and horizontal/vertical flips), since wafer maps are approximately rotationally symmetric. We create a tf.data pipeline that applies augmentation only to the training set.

We avoid elastic deformations and color jitter — the pixel values are discrete categorical states, not continuous intensities.

# Mini-batch size used as the default for every pipeline built below.
BATCH_SIZE: int = 32
# Let tf.data choose parallelism and prefetch depth at runtime.
AUTOTUNE = tf.data.AUTOTUNE


def augment(image, label):
    """Randomly rotate (by a multiple of 90 degrees) and flip one training image.

    Only spatial transforms are applied, so each pixel's one-hot channel
    vector is moved but never altered; the label is returned unchanged.

    Args:
        image: a single (height, width, channels) image tensor.
        label: the matching label tensor, passed through untouched.

    Returns:
        The ``(augmented_image, label)`` pair.
    """
    quarter_turns = tf.random.uniform([], 0, 4, dtype=tf.int32)  # k in {0, 1, 2, 3}
    rotated = tf.image.rot90(image, k=quarter_turns)
    # Left-right flip is drawn before up-down flip.
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(rotated))
    return flipped, label


def build_dataset(x, y, *, shuffle=False, augment_fn=None, batch_size=BATCH_SIZE):
    """Wrap paired arrays in a batched, prefetched ``tf.data.Dataset``.

    Args:
        x, y: sample and label arrays sharing their first dimension.
        shuffle: when truthy, shuffle with a buffer as large as ``x`` (seed 42).
        augment_fn: optional per-example ``(image, label) -> (image, label)``
            function, mapped with autotuned parallelism.
        batch_size: number of examples per batch.

    Returns:
        The configured dataset.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    pipeline = pipeline.shuffle(buffer_size=len(x), seed=42) if shuffle else pipeline
    if augment_fn is not None:
        pipeline = pipeline.map(augment_fn, num_parallel_calls=AUTOTUNE)
    return pipeline.batch(batch_size).prefetch(AUTOTUNE)


# Only the training pipeline is shuffled and augmented.
train_ds = build_dataset(X_train, y_train, shuffle=True, augment_fn=augment)
val_ds = build_dataset(X_val, y_val)
test_ds = build_dataset(X_test, y_test)

# Sanity-check one training batch: shapes and pixel value range
for batch_images, batch_labels in train_ds.take(1):
    pixel_values = batch_images.numpy()
    print("Batch images shape:", batch_images.shape)
    print("Batch labels shape:", batch_labels.shape)
    print("Image value range:", pixel_values.min(), "to", pixel_values.max())
Batch images shape: (32, 52, 52, 3)
Batch labels shape: (32, 8)
Image value range: 0.0 to 1.0
2026-02-25 23:11:51.790128: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

Step 4 — Focal Loss & Class Weights

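Focal Loss down-weights easy, well-classified examples and focuses learning on hard ones. It is combined with per-label positive weights, equal to the ratio of negative to positive training samples for each label. Together they counter the severe imbalance between label columns: for example, Scratch has 13,300 positive training samples while Near_Full has only 104.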

\[FL(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)\]

  • \(\alpha_t\): per-label weight (higher for rare classes)
  • \(\gamma\): focusing parameter (typically 2.0)
# Per-label positive weights for the 8 base defect types, computed on the
# training split: weight = (#negatives / #positives) for each label column,
# so rarely-positive labels contribute more when they are positive.
n_samples = y_train.shape[0]
pos_counts = y_train.sum(axis=0)  # shape (8,)
neg_counts = n_samples - pos_counts

# Labels with 0 positives get weight 1.0. np.divide's `where=` skips the
# division for those columns, unlike np.where, which evaluates neg/0 eagerly
# and emits a divide-by-zero RuntimeWarning before discarding the result.
pos_weights = np.divide(
    neg_counts,
    pos_counts,
    out=np.ones(pos_counts.shape, dtype=np.float64),
    where=pos_counts > 0,
)
pos_weights_tensor = tf.constant(pos_weights, dtype=tf.float32)

print("Per-label positive weights:")
base_labels_names = ["Center", "Donut", "Edge_Loc", "Edge_Ring", "Loc", "Near_Full", "Scratch", "Random"]
for name, w, pc in zip(base_labels_names, pos_weights, pos_counts, strict=True):
    print(f"  {name:<12}  weight={w:>6.2f}  (pos_count={pc:.0f})")


class BinaryFocalLoss(tf.keras.losses.Loss):
    """Binary focal loss for multi-label classification.

    For every label independently::

        loss = alpha * (1 - p_t) ** gamma * BCE

    where ``p_t`` is the predicted probability of the true class, ``BCE`` is
    the binary cross-entropy term, and ``alpha`` is ``pos_weight[label]`` for
    positive targets and 1.0 for negative ones. The per-sample loss is the
    mean over the last (label) axis.

    Args:
        gamma: focusing exponent; 0 reduces to (weighted) binary cross-entropy.
        pos_weight: optional per-label weights applied to positive targets,
            one value per label. Stored as a float32 tensor.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    def __init__(self, gamma=2.0, pos_weight=None, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        # Cast once so numpy / float64 weights don't clash with float32 targets.
        self.pos_weight = (
            None if pos_weight is None else tf.cast(tf.convert_to_tensor(pos_weight), tf.float32)
        )

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Compute in float32 even if the model emits float16 (mixed precision).
        y_pred = tf.cast(y_pred, tf.float32)
        # Keep log() finite at the extremes.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)

        # Standard binary cross-entropy components
        bce_pos = -y_true * tf.math.log(y_pred)
        bce_neg = -(1.0 - y_true) * tf.math.log(1.0 - y_pred)

        # Focal modulation: (1 - p_t)^gamma shrinks the loss of confident, correct predictions.
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        focal_weight = tf.pow(1.0 - p_t, self.gamma)

        # Per-label weight on positive targets; negatives keep weight 1.0.
        if self.pos_weight is not None:
            alpha = y_true * self.pos_weight + (1.0 - y_true)
        else:
            alpha = 1.0

        loss = focal_weight * alpha * (bce_pos + bce_neg)
        return tf.reduce_mean(loss, axis=-1)

    def get_config(self):
        """Return a serializable config including ``gamma`` and ``pos_weight``.

        Without this, reloading a saved model would rebuild the loss with
        the constructor defaults (gamma=2.0, no positive weights).
        """
        config = super().get_config()
        config.update(
            {
                "gamma": float(self.gamma),
                "pos_weight": None if self.pos_weight is None else np.asarray(self.pos_weight).tolist(),
            }
        )
        return config


# Focal loss with focusing exponent 2.0 and the per-label positive weights.
focal_loss = BinaryFocalLoss(pos_weight=pos_weights_tensor, gamma=2.0)
print(f"\nFocal Loss initialized with gamma={focal_loss.gamma}")
Per-label positive weights:
  Center        weight=  1.92  (pos_count=9100)
  Donut         weight=  2.17  (pos_count=8400)
  Edge_Loc      weight=  1.92  (pos_count=9100)
  Edge_Ring     weight=  2.17  (pos_count=8400)
  Loc           weight=  1.11  (pos_count=12600)
  Near_Full     weight=254.87  (pos_count=104)
  Scratch       weight=  1.00  (pos_count=13300)
  Random        weight= 42.91  (pos_count=606)

Focal Loss initialized with gamma=2.0

Step 5 — Model 1A: Custom CNN

A lightweight CNN designed for the (52, 52, 3) input size: - 4 Conv2D blocks (32 → 64 → 128 → 256 filters), each with BatchNormalization + ReLU + MaxPooling - GlobalAveragePooling2D (reduces overfitting vs Flatten) - Dense(128) → Dropout(0.5) → Dense(8, sigmoid)

Output: 8 independent sigmoid activations for multi-label classification.

def build_custom_cnn(input_shape=(52, 52, 3), num_labels=8):
    """Build a lightweight CNN for multi-label wafer defect classification.

    Four conv stages (32 -> 64 -> 128 -> 256 filters), each
    Conv3x3 -> BatchNorm -> ReLU -> MaxPool2x2, followed by global average
    pooling, Dense(128, ReLU), Dropout(0.5) and a sigmoid output layer.

    Args:
        input_shape: shape of one input image.
        num_labels: number of independent sigmoid outputs.

    Returns:
        A Keras ``Model`` named ``"custom_cnn"``.
    """

    def conv_stage(tensor, n_filters):
        # One Conv -> BN -> ReLU -> MaxPool stage; halves the spatial size.
        tensor = layers.Conv2D(n_filters, (3, 3), padding="same")(tensor)
        tensor = layers.BatchNormalization()(tensor)
        tensor = layers.Activation("relu")(tensor)
        return layers.MaxPooling2D((2, 2))(tensor)

    inputs = layers.Input(shape=input_shape, name="input_image")

    features = inputs
    for n_filters in (32, 64, 128, 256):
        features = conv_stage(features, n_filters)

    pooled = layers.GlobalAveragePooling2D()(features)
    hidden = layers.Dense(128, activation="relu")(pooled)
    hidden = layers.Dropout(0.5)(hidden)
    outputs = layers.Dense(num_labels, activation="sigmoid", name="output")(hidden)

    return Model(inputs=inputs, outputs=outputs, name="custom_cnn")


# Instantiate the custom CNN with its default input size and label count, then print its layers.
model_cnn = build_custom_cnn(input_shape=(52, 52, 3), num_labels=8)
model_cnn.summary()
Model: "custom_cnn"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_image (InputLayer)        │ (None, 52, 52, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d (Conv2D)                 │ (None, 52, 52, 32)     │           896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization             │ (None, 52, 52, 32)     │           128 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ activation (Activation)         │ (None, 52, 52, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d (MaxPooling2D)    │ (None, 26, 26, 32)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 26, 26, 64)     │        18,496 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_1           │ (None, 26, 26, 64)     │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ activation_1 (Activation)       │ (None, 26, 26, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_1 (MaxPooling2D)  │ (None, 13, 13, 64)     │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 13, 13, 128)    │        73,856 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_2           │ (None, 13, 13, 128)    │           512 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ activation_2 (Activation)       │ (None, 13, 13, 128)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_2 (MaxPooling2D)  │ (None, 6, 6, 128)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_3 (Conv2D)               │ (None, 6, 6, 256)      │       295,168 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_3           │ (None, 6, 6, 256)      │         1,024 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ activation_3 (Activation)       │ (None, 6, 6, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ max_pooling2d_3 (MaxPooling2D)  │ (None, 3, 3, 256)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d        │ (None, 256)            │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense (Dense)                   │ (None, 128)            │        32,896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output (Dense)                  │ (None, 8)              │         1,032 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 424,264 (1.62 MB)
 Trainable params: 423,304 (1.61 MB)
 Non-trainable params: 960 (3.75 KB)

Step 6 — Model 1B: Transfer Learning (MobileNetV2)

Map the (52, 52, 3) one-hot input to 3 image channels with a learnable 1×1 Conv2D, then resize to (224, 224, 3) to match MobileNetV2's expected input. The pre-trained ImageNet backbone is used as a frozen feature extractor, with a custom multi-label classification head on top.

Note: ImageNet features (edges, textures) should still transfer reasonably well to wafer map patterns.

TARGET_SIZE = (224, 224)


def build_mobilenet_model(input_shape=(52, 52, 3), target_size=TARGET_SIZE, num_labels=8):
    """Build a MobileNetV2 transfer-learning model for multi-label classification.

    The input (``input_shape[-1]`` channels, e.g. one-hot pixel states) is
    mapped to 3 channels by a learnable 1x1 convolution, bilinearly resized
    to ``target_size``, passed through a frozen ImageNet MobileNetV2 backbone,
    and classified by a small trainable sigmoid head.

    Args:
        input_shape: shape of one input image.
        target_size: (height, width) fed to MobileNetV2.
        num_labels: number of independent sigmoid outputs.

    Returns:
        ``(model, base_model)``: the full Keras model and the frozen
        MobileNetV2 backbone (returned separately so callers can inspect it).
    """
    inputs = layers.Input(shape=input_shape, name="input_image")

    # Learnable per-pixel projection from input_shape[-1] channels to the 3
    # channels MobileNetV2 expects. Kept linear (no ReLU): with one-hot
    # inputs a ReLU output channel whose initial weights are all negative is
    # 0 everywhere and receives no gradient, and ReLU would also rule out the
    # negative part of the [-1, 1] range the ImageNet weights were trained on.
    x = layers.Conv2D(3, (1, 1), padding="same")(inputs)

    # Resize to MobileNetV2-compatible dimensions
    x = layers.Resizing(target_size[0], target_size[1], interpolation="bilinear")(x)

    # Frozen MobileNetV2 backbone
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(target_size[0], target_size[1], 3),
        include_top=False,
        weights="imagenet",
    )
    base_model.trainable = False

    # training=False keeps the backbone's BatchNorm layers in inference mode.
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_labels, activation="sigmoid", name="output")(x)

    return Model(inputs, outputs, name="mobilenet_transfer"), base_model


model_tl, base_model_tl = build_mobilenet_model()
model_tl.summary()

# Confirm the backbone is frozen: none of its layers should be trainable.
total = len(base_model_tl.layers)
trainable = len([layer for layer in base_model_tl.layers if layer.trainable])
print(f"\nMobileNetV2 base: {trainable}/{total} trainable layers (should be 0/{total})")
Model: "mobilenet_transfer"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_image (InputLayer)        │ (None, 52, 52, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_4 (Conv2D)               │ (None, 52, 52, 3)      │            12 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ resizing (Resizing)             │ (None, 224, 224, 3)    │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ mobilenetv2_1.00_224            │ (None, 7, 7, 1280)     │     2,257,984 │
│ (Functional)                    │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ global_average_pooling2d_1      │ (None, 1280)           │             0 │
│ (GlobalAveragePooling2D)        │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_1 (Dense)                 │ (None, 256)            │       327,936 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 256)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output (Dense)                  │ (None, 8)              │         2,056 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 2,587,988 (9.87 MB)
 Trainable params: 330,004 (1.26 MB)
 Non-trainable params: 2,257,984 (8.61 MB)

MobileNetV2 base: 0/154 trainable layers (should be 0/154)

Step 7 — Training

Both models are trained with: - Loss: Binary Focal Loss (gamma=2.0, per-label pos_weights) - Optimizer: Adam (lr=1e-3) - Callbacks: ReduceLROnPlateau(patience=5), EarlyStopping(patience=10, restore_best_weights=True) - Metrics: AUC (per-label average), binary accuracy - Epochs: up to 100 (early stopping will cut it short)

EPOCHS = 100

# Halve the learning rate after 5 epochs without val_loss improvement (floor 1e-6).
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1
)
# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10, restore_best_weights=True, verbose=1
)
callbacks = [reduce_lr, early_stopping]

# multi_label=True: AUC is computed per label column and then averaged.
auc_metric = tf.keras.metrics.AUC(name="auc", multi_label=True)
binary_acc_metric = tf.keras.metrics.BinaryAccuracy(name="binary_acc")
metrics = [auc_metric, binary_acc_metric]

Train Model 1A — Custom CNN

# Compile the custom CNN with focal loss and Adam (learning rate 1e-3).
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
model_cnn.compile(optimizer=adam, loss=focal_loss, metrics=metrics)

# Train with validation monitoring; the callbacks handle LR decay and early stopping.
history_cnn = model_cnn.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks,
    verbose=1,
)
Epoch 1/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 28s 31ms/step - auc: 0.8536 - binary_acc: 0.7960 - loss: 0.1320 - val_auc: 0.8754 - val_binary_acc: 0.7561 - val_loss: 0.5910 - learning_rate: 0.0010

Epoch 2/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 26s 31ms/step - auc: 0.9065 - binary_acc: 0.8614 - loss: 0.0924 - val_auc: 0.9388 - val_binary_acc: 0.8665 - val_loss: 0.0668 - learning_rate: 0.0010

Epoch 3/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 31ms/step - auc: 0.9287 - binary_acc: 0.8939 - loss: 0.0646 - val_auc: 0.9453 - val_binary_acc: 0.9082 - val_loss: 0.0699 - learning_rate: 0.0010

Epoch 4/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 26s 31ms/step - auc: 0.9368 - binary_acc: 0.9064 - loss: 0.0573 - val_auc: 0.9541 - val_binary_acc: 0.9303 - val_loss: 0.0417 - learning_rate: 0.0010

Epoch 5/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 31ms/step - auc: 0.9429 - binary_acc: 0.9156 - loss: 0.0515 - val_auc: 0.8084 - val_binary_acc: 0.7108 - val_loss: 0.9099 - learning_rate: 0.0010

Epoch 6/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9495 - binary_acc: 0.9237 - loss: 0.0468 - val_auc: 0.9728 - val_binary_acc: 0.9419 - val_loss: 0.0419 - learning_rate: 0.0010

Epoch 7/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9410 - binary_acc: 0.9124 - loss: 0.0564 - val_auc: 0.9432 - val_binary_acc: 0.9027 - val_loss: 0.0579 - learning_rate: 0.0010

Epoch 8/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9459 - binary_acc: 0.9214 - loss: 0.0465 - val_auc: 0.9612 - val_binary_acc: 0.9277 - val_loss: 0.0577 - learning_rate: 0.0010

Epoch 9/100

831/832 ━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - auc: 0.9423 - binary_acc: 0.9186 - loss: 0.0469

Epoch 9: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9482 - binary_acc: 0.9227 - loss: 0.0457 - val_auc: 0.9564 - val_binary_acc: 0.8997 - val_loss: 0.1008 - learning_rate: 0.0010

Epoch 10/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9532 - binary_acc: 0.9294 - loss: 0.0427 - val_auc: 0.9664 - val_binary_acc: 0.9436 - val_loss: 0.0291 - learning_rate: 5.0000e-04

Epoch 11/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9576 - binary_acc: 0.9337 - loss: 0.0418 - val_auc: 0.9809 - val_binary_acc: 0.9603 - val_loss: 0.0280 - learning_rate: 5.0000e-04

Epoch 12/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9632 - binary_acc: 0.9415 - loss: 0.0364 - val_auc: 0.9762 - val_binary_acc: 0.9445 - val_loss: 0.0325 - learning_rate: 5.0000e-04

Epoch 13/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9656 - binary_acc: 0.9427 - loss: 0.0398 - val_auc: 0.9858 - val_binary_acc: 0.9697 - val_loss: 0.0256 - learning_rate: 5.0000e-04

Epoch 14/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9678 - binary_acc: 0.9433 - loss: 0.0415 - val_auc: 0.9889 - val_binary_acc: 0.9719 - val_loss: 0.0245 - learning_rate: 5.0000e-04

Epoch 15/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9745 - binary_acc: 0.9515 - loss: 0.0348 - val_auc: 0.9884 - val_binary_acc: 0.9712 - val_loss: 0.0254 - learning_rate: 5.0000e-04

Epoch 16/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9797 - binary_acc: 0.9567 - loss: 0.0327 - val_auc: 0.9944 - val_binary_acc: 0.9838 - val_loss: 0.0175 - learning_rate: 5.0000e-04

Epoch 17/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9815 - binary_acc: 0.9572 - loss: 0.0346 - val_auc: 0.9951 - val_binary_acc: 0.9833 - val_loss: 0.0176 - learning_rate: 5.0000e-04

Epoch 18/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9850 - binary_acc: 0.9627 - loss: 0.0297 - val_auc: 0.9963 - val_binary_acc: 0.9855 - val_loss: 0.0164 - learning_rate: 5.0000e-04

Epoch 19/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9860 - binary_acc: 0.9648 - loss: 0.0300 - val_auc: 0.9952 - val_binary_acc: 0.9831 - val_loss: 0.0173 - learning_rate: 5.0000e-04

Epoch 20/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9878 - binary_acc: 0.9671 - loss: 0.0278 - val_auc: 0.9978 - val_binary_acc: 0.9895 - val_loss: 0.0150 - learning_rate: 5.0000e-04

Epoch 21/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9895 - binary_acc: 0.9698 - loss: 0.0261 - val_auc: 0.9977 - val_binary_acc: 0.9891 - val_loss: 0.0147 - learning_rate: 5.0000e-04

Epoch 22/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9879 - binary_acc: 0.9662 - loss: 0.0314 - val_auc: 0.9976 - val_binary_acc: 0.9898 - val_loss: 0.0152 - learning_rate: 5.0000e-04

Epoch 23/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9900 - binary_acc: 0.9702 - loss: 0.0271 - val_auc: 0.9977 - val_binary_acc: 0.9887 - val_loss: 0.0142 - learning_rate: 5.0000e-04

Epoch 24/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9901 - binary_acc: 0.9702 - loss: 0.0282 - val_auc: 0.9977 - val_binary_acc: 0.9898 - val_loss: 0.0130 - learning_rate: 5.0000e-04

Epoch 25/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9919 - binary_acc: 0.9727 - loss: 0.0261 - val_auc: 0.9972 - val_binary_acc: 0.9817 - val_loss: 0.0172 - learning_rate: 5.0000e-04

Epoch 26/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9921 - binary_acc: 0.9742 - loss: 0.0246 - val_auc: 0.9982 - val_binary_acc: 0.9819 - val_loss: 0.0209 - learning_rate: 5.0000e-04

Epoch 27/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9928 - binary_acc: 0.9751 - loss: 0.0229 - val_auc: 0.9985 - val_binary_acc: 0.9861 - val_loss: 0.0174 - learning_rate: 5.0000e-04

Epoch 28/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 31ms/step - auc: 0.9933 - binary_acc: 0.9758 - loss: 0.0234 - val_auc: 0.9977 - val_binary_acc: 0.9883 - val_loss: 0.0136 - learning_rate: 5.0000e-04

Epoch 29/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9935 - binary_acc: 0.9771 - loss: 0.0223 - val_auc: 0.9984 - val_binary_acc: 0.9923 - val_loss: 0.0126 - learning_rate: 5.0000e-04

Epoch 30/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9933 - binary_acc: 0.9765 - loss: 0.0230 - val_auc: 0.9987 - val_binary_acc: 0.9927 - val_loss: 0.0113 - learning_rate: 5.0000e-04

Epoch 31/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9940 - binary_acc: 0.9776 - loss: 0.0221 - val_auc: 0.9985 - val_binary_acc: 0.9930 - val_loss: 0.0125 - learning_rate: 5.0000e-04

Epoch 32/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9921 - binary_acc: 0.9744 - loss: 0.0255 - val_auc: 0.9988 - val_binary_acc: 0.9939 - val_loss: 0.0104 - learning_rate: 5.0000e-04

Epoch 33/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9934 - binary_acc: 0.9765 - loss: 0.0242 - val_auc: 0.9987 - val_binary_acc: 0.9941 - val_loss: 0.0102 - learning_rate: 5.0000e-04

Epoch 34/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9949 - binary_acc: 0.9797 - loss: 0.0203 - val_auc: 0.9988 - val_binary_acc: 0.9937 - val_loss: 0.0098 - learning_rate: 5.0000e-04

Epoch 35/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9952 - binary_acc: 0.9803 - loss: 0.0201 - val_auc: 0.9985 - val_binary_acc: 0.9891 - val_loss: 0.0150 - learning_rate: 5.0000e-04

Epoch 36/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9947 - binary_acc: 0.9791 - loss: 0.0229 - val_auc: 0.9933 - val_binary_acc: 0.9546 - val_loss: 0.0719 - learning_rate: 5.0000e-04

Epoch 37/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9951 - binary_acc: 0.9801 - loss: 0.0205 - val_auc: 0.9987 - val_binary_acc: 0.9904 - val_loss: 0.0136 - learning_rate: 5.0000e-04

Epoch 38/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9948 - binary_acc: 0.9793 - loss: 0.0216 - val_auc: 0.9978 - val_binary_acc: 0.9831 - val_loss: 0.0155 - learning_rate: 5.0000e-04

Epoch 39/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9953 - binary_acc: 0.9810 - loss: 0.0196 - val_auc: 0.9990 - val_binary_acc: 0.9943 - val_loss: 0.0084 - learning_rate: 5.0000e-04

Epoch 40/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9958 - binary_acc: 0.9819 - loss: 0.0183 - val_auc: 0.9989 - val_binary_acc: 0.9947 - val_loss: 0.0097 - learning_rate: 5.0000e-04

Epoch 41/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9955 - binary_acc: 0.9815 - loss: 0.0198 - val_auc: 0.9947 - val_binary_acc: 0.9675 - val_loss: 0.0264 - learning_rate: 5.0000e-04

Epoch 42/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9958 - binary_acc: 0.9818 - loss: 0.0185 - val_auc: 0.9990 - val_binary_acc: 0.9930 - val_loss: 0.0108 - learning_rate: 5.0000e-04

Epoch 43/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9960 - binary_acc: 0.9821 - loss: 0.0181 - val_auc: 0.9989 - val_binary_acc: 0.9935 - val_loss: 0.0102 - learning_rate: 5.0000e-04

Epoch 44/100

830/832 ━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - auc: 0.9953 - binary_acc: 0.9821 - loss: 0.0184

Epoch 44: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9959 - binary_acc: 0.9818 - loss: 0.0189 - val_auc: 0.9991 - val_binary_acc: 0.9955 - val_loss: 0.0087 - learning_rate: 5.0000e-04

Epoch 45/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9966 - binary_acc: 0.9837 - loss: 0.0167 - val_auc: 0.9991 - val_binary_acc: 0.9954 - val_loss: 0.0088 - learning_rate: 2.5000e-04

Epoch 46/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9966 - binary_acc: 0.9836 - loss: 0.0167 - val_auc: 0.9991 - val_binary_acc: 0.9941 - val_loss: 0.0089 - learning_rate: 2.5000e-04

Epoch 47/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9967 - binary_acc: 0.9841 - loss: 0.0161 - val_auc: 0.9991 - val_binary_acc: 0.9949 - val_loss: 0.0102 - learning_rate: 2.5000e-04

Epoch 48/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9967 - binary_acc: 0.9845 - loss: 0.0158 - val_auc: 0.9991 - val_binary_acc: 0.9935 - val_loss: 0.0086 - learning_rate: 2.5000e-04

Epoch 49/100

831/832 ━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - auc: 0.9942 - binary_acc: 0.9852 - loss: 0.0149

Epoch 49: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.

832/832 ━━━━━━━━━━━━━━━━━━━━ 25s 30ms/step - auc: 0.9968 - binary_acc: 0.9846 - loss: 0.0158 - val_auc: 0.9991 - val_binary_acc: 0.9956 - val_loss: 0.0085 - learning_rate: 2.5000e-04

Epoch 49: early stopping

Restoring model weights from the end of the best epoch: 39.

Train Model 1B — MobileNetV2 Transfer Learning

# Model 1B gets brand-new metric objects: Keras metrics accumulate internal
# state, so the instances already attached to Model 1A must not be reused.
metrics_tl = [tf.keras.metrics.AUC(multi_label=True, name="auc"), tf.keras.metrics.BinaryAccuracy(name="binary_acc")]

model_tl.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=BinaryFocalLoss(pos_weight=pos_weights_tensor, gamma=2.0),
    metrics=metrics_tl,
)

# Callbacks also keep state between fits (best value seen, wait counters),
# so this run builds its own LR scheduler and early-stopping instances.
callbacks_tl = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True, verbose=1),
]

history_tl = model_tl.fit(train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=callbacks_tl, verbose=1)
Epoch 1/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 200s 237ms/step - auc: 0.8527 - binary_acc: 0.8056 - loss: 0.1289 - val_auc: 0.9175 - val_binary_acc: 0.8686 - val_loss: 0.0841 - learning_rate: 0.0010

Epoch 2/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 234ms/step - auc: 0.8791 - binary_acc: 0.8255 - loss: 0.1177 - val_auc: 0.9345 - val_binary_acc: 0.8660 - val_loss: 0.0771 - learning_rate: 0.0010

Epoch 3/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8902 - binary_acc: 0.8374 - loss: 0.1176 - val_auc: 0.9456 - val_binary_acc: 0.8943 - val_loss: 0.0790 - learning_rate: 0.0010

Epoch 4/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8890 - binary_acc: 0.8339 - loss: 0.1045 - val_auc: 0.9419 - val_binary_acc: 0.8859 - val_loss: 0.0827 - learning_rate: 0.0010

Epoch 5/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 234ms/step - auc: 0.8899 - binary_acc: 0.8352 - loss: 0.1123 - val_auc: 0.9414 - val_binary_acc: 0.8934 - val_loss: 0.0831 - learning_rate: 0.0010

Epoch 6/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 237ms/step - auc: 0.8891 - binary_acc: 0.8374 - loss: 0.1006 - val_auc: 0.9497 - val_binary_acc: 0.8972 - val_loss: 0.0742 - learning_rate: 0.0010

Epoch 7/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.8989 - binary_acc: 0.8456 - loss: 0.0965 - val_auc: 0.9539 - val_binary_acc: 0.9046 - val_loss: 0.0756 - learning_rate: 0.0010

Epoch 8/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8871 - binary_acc: 0.8331 - loss: 0.1009 - val_auc: 0.9494 - val_binary_acc: 0.9078 - val_loss: 0.0748 - learning_rate: 0.0010

Epoch 9/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.8932 - binary_acc: 0.8436 - loss: 0.0898 - val_auc: 0.9516 - val_binary_acc: 0.9001 - val_loss: 0.0727 - learning_rate: 0.0010

Epoch 10/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8835 - binary_acc: 0.8316 - loss: 0.1015 - val_auc: 0.9463 - val_binary_acc: 0.8953 - val_loss: 0.0829 - learning_rate: 0.0010

Epoch 11/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.8871 - binary_acc: 0.8325 - loss: 0.0982 - val_auc: 0.9504 - val_binary_acc: 0.8992 - val_loss: 0.0716 - learning_rate: 0.0010

Epoch 12/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.8952 - binary_acc: 0.8430 - loss: 0.0898 - val_auc: 0.9603 - val_binary_acc: 0.9114 - val_loss: 0.0694 - learning_rate: 0.0010

Epoch 13/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 198s 238ms/step - auc: 0.9003 - binary_acc: 0.8505 - loss: 0.0863 - val_auc: 0.9581 - val_binary_acc: 0.9147 - val_loss: 0.0684 - learning_rate: 0.0010

Epoch 14/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8974 - binary_acc: 0.8484 - loss: 0.0946 - val_auc: 0.9607 - val_binary_acc: 0.9128 - val_loss: 0.0717 - learning_rate: 0.0010

Epoch 15/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.8842 - binary_acc: 0.8297 - loss: 0.1024 - val_auc: 0.9475 - val_binary_acc: 0.8982 - val_loss: 0.0806 - learning_rate: 0.0010

Epoch 16/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 237ms/step - auc: 0.8805 - binary_acc: 0.8272 - loss: 0.0945 - val_auc: 0.9211 - val_binary_acc: 0.8801 - val_loss: 0.0782 - learning_rate: 0.0010

Epoch 17/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.8781 - binary_acc: 0.8258 - loss: 0.0964 - val_auc: 0.9386 - val_binary_acc: 0.8855 - val_loss: 0.0795 - learning_rate: 0.0010

Epoch 18/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 215ms/step - auc: 0.8826 - binary_acc: 0.8314 - loss: 0.0955

Epoch 18: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.8848 - binary_acc: 0.8322 - loss: 0.0927 - val_auc: 0.9366 - val_binary_acc: 0.8798 - val_loss: 0.0779 - learning_rate: 0.0010

Epoch 19/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.8928 - binary_acc: 0.8414 - loss: 0.0902 - val_auc: 0.9542 - val_binary_acc: 0.9056 - val_loss: 0.0728 - learning_rate: 5.0000e-04

Epoch 20/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9012 - binary_acc: 0.8509 - loss: 0.0848 - val_auc: 0.9526 - val_binary_acc: 0.9077 - val_loss: 0.0702 - learning_rate: 5.0000e-04

Epoch 21/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 237ms/step - auc: 0.9047 - binary_acc: 0.8546 - loss: 0.0867 - val_auc: 0.9595 - val_binary_acc: 0.9111 - val_loss: 0.0697 - learning_rate: 5.0000e-04

Epoch 22/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 237ms/step - auc: 0.9070 - binary_acc: 0.8568 - loss: 0.0866 - val_auc: 0.9599 - val_binary_acc: 0.9057 - val_loss: 0.0737 - learning_rate: 5.0000e-04

Epoch 23/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9107 - binary_acc: 0.8592 - loss: 0.0811 - val_auc: 0.9597 - val_binary_acc: 0.9225 - val_loss: 0.0648 - learning_rate: 5.0000e-04

Epoch 24/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9147 - binary_acc: 0.8638 - loss: 0.0815 - val_auc: 0.9652 - val_binary_acc: 0.9189 - val_loss: 0.0644 - learning_rate: 5.0000e-04

Epoch 25/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9208 - binary_acc: 0.8710 - loss: 0.0774 - val_auc: 0.9645 - val_binary_acc: 0.9223 - val_loss: 0.0598 - learning_rate: 5.0000e-04

Epoch 26/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 236ms/step - auc: 0.9165 - binary_acc: 0.8671 - loss: 0.0925 - val_auc: 0.9585 - val_binary_acc: 0.9116 - val_loss: 0.0704 - learning_rate: 5.0000e-04

Epoch 27/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9096 - binary_acc: 0.8584 - loss: 0.0878 - val_auc: 0.9581 - val_binary_acc: 0.9107 - val_loss: 0.0666 - learning_rate: 5.0000e-04

Epoch 28/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9157 - binary_acc: 0.8640 - loss: 0.0789 - val_auc: 0.9644 - val_binary_acc: 0.9192 - val_loss: 0.0637 - learning_rate: 5.0000e-04

Epoch 29/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9202 - binary_acc: 0.8692 - loss: 0.0780 - val_auc: 0.9648 - val_binary_acc: 0.9175 - val_loss: 0.0639 - learning_rate: 5.0000e-04

Epoch 30/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 215ms/step - auc: 0.9213 - binary_acc: 0.8696 - loss: 0.0777

Epoch 30: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9206 - binary_acc: 0.8690 - loss: 0.0777 - val_auc: 0.9661 - val_binary_acc: 0.9247 - val_loss: 0.0618 - learning_rate: 5.0000e-04

Epoch 31/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 237ms/step - auc: 0.9234 - binary_acc: 0.8729 - loss: 0.0745 - val_auc: 0.9650 - val_binary_acc: 0.9239 - val_loss: 0.0621 - learning_rate: 2.5000e-04

Epoch 32/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9240 - binary_acc: 0.8725 - loss: 0.0761 - val_auc: 0.9664 - val_binary_acc: 0.9224 - val_loss: 0.0620 - learning_rate: 2.5000e-04

Epoch 33/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9250 - binary_acc: 0.8739 - loss: 0.0737 - val_auc: 0.9655 - val_binary_acc: 0.9230 - val_loss: 0.0619 - learning_rate: 2.5000e-04

Epoch 34/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9248 - binary_acc: 0.8738 - loss: 0.0732 - val_auc: 0.9672 - val_binary_acc: 0.9222 - val_loss: 0.0590 - learning_rate: 2.5000e-04

Epoch 35/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9258 - binary_acc: 0.8743 - loss: 0.0740 - val_auc: 0.9669 - val_binary_acc: 0.9204 - val_loss: 0.0602 - learning_rate: 2.5000e-04

Epoch 36/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9251 - binary_acc: 0.8724 - loss: 0.0760 - val_auc: 0.9687 - val_binary_acc: 0.9304 - val_loss: 0.0571 - learning_rate: 2.5000e-04

Epoch 37/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9263 - binary_acc: 0.8744 - loss: 0.0735 - val_auc: 0.9685 - val_binary_acc: 0.9241 - val_loss: 0.0578 - learning_rate: 2.5000e-04

Epoch 38/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9285 - binary_acc: 0.8784 - loss: 0.0717 - val_auc: 0.9674 - val_binary_acc: 0.9264 - val_loss: 0.0562 - learning_rate: 2.5000e-04

Epoch 39/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9257 - binary_acc: 0.8750 - loss: 0.0764 - val_auc: 0.9679 - val_binary_acc: 0.9265 - val_loss: 0.0584 - learning_rate: 2.5000e-04

Epoch 40/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9297 - binary_acc: 0.8781 - loss: 0.0740 - val_auc: 0.9677 - val_binary_acc: 0.9219 - val_loss: 0.0570 - learning_rate: 2.5000e-04

Epoch 41/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9295 - binary_acc: 0.8793 - loss: 0.0722 - val_auc: 0.9672 - val_binary_acc: 0.9265 - val_loss: 0.0573 - learning_rate: 2.5000e-04

Epoch 42/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 234ms/step - auc: 0.9301 - binary_acc: 0.8788 - loss: 0.0723 - val_auc: 0.9705 - val_binary_acc: 0.9298 - val_loss: 0.0569 - learning_rate: 2.5000e-04

Epoch 43/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9317 - binary_acc: 0.8819 - loss: 0.0707 - val_auc: 0.9688 - val_binary_acc: 0.9310 - val_loss: 0.0555 - learning_rate: 2.5000e-04

Epoch 44/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9297 - binary_acc: 0.8794 - loss: 0.0721 - val_auc: 0.9710 - val_binary_acc: 0.9328 - val_loss: 0.0573 - learning_rate: 2.5000e-04

Epoch 45/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 234ms/step - auc: 0.9309 - binary_acc: 0.8809 - loss: 0.0710 - val_auc: 0.9703 - val_binary_acc: 0.9278 - val_loss: 0.0581 - learning_rate: 2.5000e-04

Epoch 46/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9323 - binary_acc: 0.8817 - loss: 0.0712 - val_auc: 0.9707 - val_binary_acc: 0.9289 - val_loss: 0.0572 - learning_rate: 2.5000e-04

Epoch 47/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 197s 236ms/step - auc: 0.9314 - binary_acc: 0.8811 - loss: 0.0708 - val_auc: 0.9714 - val_binary_acc: 0.9279 - val_loss: 0.0557 - learning_rate: 2.5000e-04

Epoch 48/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 215ms/step - auc: 0.9311 - binary_acc: 0.8805 - loss: 0.0781

Epoch 48: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9319 - binary_acc: 0.8809 - loss: 0.0740 - val_auc: 0.9715 - val_binary_acc: 0.9299 - val_loss: 0.0563 - learning_rate: 2.5000e-04

Epoch 49/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9328 - binary_acc: 0.8828 - loss: 0.0706 - val_auc: 0.9715 - val_binary_acc: 0.9314 - val_loss: 0.0557 - learning_rate: 1.2500e-04

Epoch 50/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9334 - binary_acc: 0.8824 - loss: 0.0698 - val_auc: 0.9716 - val_binary_acc: 0.9320 - val_loss: 0.0563 - learning_rate: 1.2500e-04

Epoch 51/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 195s 235ms/step - auc: 0.9331 - binary_acc: 0.8824 - loss: 0.0697 - val_auc: 0.9706 - val_binary_acc: 0.9276 - val_loss: 0.0560 - learning_rate: 1.2500e-04

Epoch 52/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 236ms/step - auc: 0.9344 - binary_acc: 0.8842 - loss: 0.0689 - val_auc: 0.9710 - val_binary_acc: 0.9301 - val_loss: 0.0571 - learning_rate: 1.2500e-04

Epoch 53/100

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 216ms/step - auc: 0.9328 - binary_acc: 0.8841 - loss: 0.0693

Epoch 53: ReduceLROnPlateau reducing learning rate to 6.25000029685907e-05.

832/832 ━━━━━━━━━━━━━━━━━━━━ 196s 235ms/step - auc: 0.9337 - binary_acc: 0.8842 - loss: 0.0694 - val_auc: 0.9717 - val_binary_acc: 0.9306 - val_loss: 0.0556 - learning_rate: 1.2500e-04

Epoch 53: early stopping

Restoring model weights from the end of the best epoch: 43.

Step 8 — Phase 1 Evaluation & Comparison

Compare both models using training-history plots (loss and AUC) and, on the test set, per-label precision, recall, and F1 plus a macro-F1 summary.

def plot_training_history(history, model_name, metric="auc", loss_label="Focal Loss", metric_label=None):
    """Plot training vs. validation curves for the loss and one metric.

    Args:
        history: The ``History`` object returned by ``model.fit``. Its
            ``history`` dict must contain ``"loss"``, ``"val_loss"``,
            ``metric`` and ``f"val_{metric}"``.
        model_name: Prefix used in both subplot titles.
        metric: Key of the metric drawn in the right-hand panel. Defaults to
            ``"auc"`` (the Phase 1 models); pass e.g. ``"accuracy"`` for
            models compiled with ``metrics=["accuracy"]``.
        loss_label: Y-axis label of the loss panel. Defaults to
            ``"Focal Loss"``, the Phase 1 training loss.
        metric_label: Display name of the metric. Defaults to ``"AUC"`` for
            ``"auc"`` and to a title-cased form of ``metric`` otherwise.
    """
    if metric_label is None:
        metric_label = "AUC" if metric == "auc" else metric.replace("_", " ").title()

    curves = history.history
    # Keras' training log numbers epochs from 1 ("Epoch 1/100"); use the same
    # numbering on the x-axis so the plots line up with the log.
    epochs = range(1, len(curves["loss"]) + 1)

    _fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (axis, history key, name used in legend/title, y-axis label)
    panels = (
        (axes[0], "loss", "Loss", loss_label),
        (axes[1], metric, metric_label, metric_label),
    )
    for ax, key, display_name, y_label in panels:
        ax.plot(epochs, curves[key], label=f"Train {display_name}")
        ax.plot(epochs, curves[f"val_{key}"], label=f"Val {display_name}")
        ax.set_title(f"{model_name} — {display_name}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


# Loss / AUC curves for both Phase 1 models.
plot_training_history(history_cnn, model_name="Custom CNN")
plot_training_history(history_tl, model_name="MobileNetV2 Transfer")

from sklearn.metrics import classification_report, f1_score

# Probability cut-off at which a per-label prediction counts as positive.
THRESHOLD = 0.5
# The 8 base defect names, in the same bit order as the label_mapping keys
# ("10000000" -> Center, ..., "00000001" -> Random).
base_labels_all = "Center Donut Edge_Loc Edge_Ring Loc Near_Full Scratch Random".split()


def evaluate_multilabel(model, dataset, y_true, model_name, threshold=None, label_names=None):
    """Evaluate a multi-label model: per-label precision/recall/F1 + macro F1.

    Prints a header, sklearn's text classification report, and the macro and
    weighted F1 scores.

    Args:
        model: Trained Keras model; ``model.predict(dataset)`` must return one
            probability column per entry of ``label_names``.
        dataset: Input pipeline passed to ``model.predict``.
            NOTE(review): it must yield samples in the same order as
            ``y_true`` (i.e. it must not be shuffled). Confirm this for every
            dataset passed in.
        y_true: Ground-truth binary label matrix, one column per label.
        model_name: Name shown in the printed header.
        threshold: Cut-off for a positive prediction, either a scalar or a
            per-label array broadcast over the columns. Defaults to the
            module-level ``THRESHOLD``, read at call time.
        label_names: Column names for the report. Defaults to
            ``base_labels_all``, read at call time.

    Returns:
        ``(y_pred, y_pred_prob, report)``: the thresholded 0/1 predictions,
        the raw predicted probabilities, and the ``classification_report``
        dict (``output_dict=True``).
    """
    if threshold is None:
        threshold = THRESHOLD
    if label_names is None:
        label_names = base_labels_all

    y_pred_prob = model.predict(dataset, verbose=0)
    # Each label is decided independently: present iff its probability reaches the cut-off.
    y_pred = (y_pred_prob >= threshold).astype(int)

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"  {model_name} — Test Set Results (threshold={threshold})")
    print(separator)

    # Identical settings for the returned dict report and the printed text report.
    report_kwargs = {"target_names": label_names, "zero_division": 0}
    report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    print(classification_report(y_true, y_pred, **report_kwargs))

    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    print(f"Macro F1:    {macro_f1:.4f}")
    print(f"Weighted F1: {weighted_f1:.4f}")

    return y_pred, y_pred_prob, report


# Test-set evaluation of both Phase 1 models (predictions, probabilities, report dict).
y_pred_cnn, y_prob_cnn, report_cnn = evaluate_multilabel(model_cnn, test_ds, y_test, model_name="Custom CNN")
y_pred_tl, y_prob_tl, report_tl = evaluate_multilabel(model_tl, test_ds, y_test, model_name="MobileNetV2 Transfer")

============================================================
  Custom CNN — Test Set Results (threshold=0.5)
============================================================
              precision    recall  f1-score   support

      Center       1.00      1.00      1.00      1950
       Donut       1.00      1.00      1.00      1800
    Edge_Loc       0.98      0.99      0.99      1950
   Edge_Ring       0.99      0.99      0.99      1800
         Loc       1.00      0.99      0.99      2700
   Near_Full       0.82      1.00      0.90        23
     Scratch       0.97      0.98      0.98      2850
      Random       0.94      1.00      0.97       130

   micro avg       0.99      0.99      0.99     13203
   macro avg       0.96      0.99      0.98     13203
weighted avg       0.99      0.99      0.99     13203
 samples avg       0.96      0.97      0.96     13203

Macro F1:    0.9779
Weighted F1: 0.9905

============================================================
  MobileNetV2 Transfer — Test Set Results (threshold=0.5)
============================================================
              precision    recall  f1-score   support

      Center       0.97      0.97      0.97      1950
       Donut       1.00      1.00      1.00      1800
    Edge_Loc       0.67      0.83      0.75      1950
   Edge_Ring       0.80      0.97      0.88      1800
         Loc       0.87      0.91      0.89      2700
   Near_Full       0.79      1.00      0.88        23
     Scratch       0.84      0.85      0.84      2850
      Random       0.98      1.00      0.99       130

   micro avg       0.85      0.91      0.88     13203
   macro avg       0.87      0.94      0.90     13203
weighted avg       0.86      0.91      0.88     13203
 samples avg       0.85      0.89      0.85     13203

Macro F1:    0.8998
Weighted F1: 0.8847
# Per-label F1 for both Phase 1 models side by side.
comparison_data = [
    {
        "Defect Type": label_name,
        "Custom CNN F1": report_cnn[label_name]["f1-score"],
        "MobileNetV2 F1": report_tl[label_name]["f1-score"],
    }
    for label_name in base_labels_all
]

# The table is the per-label rows followed by one macro-average summary row.
comparison_df = pd.DataFrame(
    [
        *comparison_data,
        {
            "Defect Type": "MACRO AVG",
            "Custom CNN F1": report_cnn["macro avg"]["f1-score"],
            "MobileNetV2 F1": report_tl["macro avg"]["f1-score"],
        },
    ]
)

# Highlight, per row, whichever model scored higher.
display(
    comparison_df.style
    .format({"Custom CNN F1": "{:.4f}", "MobileNetV2 F1": "{:.4f}"})
    .highlight_max(subset=["Custom CNN F1", "MobileNetV2 F1"], axis=1, color="lightgreen")
)

# Grouped bar chart: one pair of bars per defect type (CNN left, MobileNetV2 right).
_fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(base_labels_all))
width = 0.35
ax.bar(
    x - width / 2,
    [report_cnn[lbl]["f1-score"] for lbl in base_labels_all],
    width=width,
    label="Custom CNN",
    color="#4C78A8",
)
ax.bar(
    x + width / 2,
    [report_tl[lbl]["f1-score"] for lbl in base_labels_all],
    width=width,
    label="MobileNetV2",
    color="#F58518",
)
ax.set(xlabel="Defect Type", ylabel="F1 Score", title="Phase 1 — Per-Label F1 Score Comparison", xticks=x)
ax.set_xticklabels(base_labels_all, rotation=45, ha="right")
ax.legend()
ax.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.show()
  Defect Type Custom CNN F1 MobileNetV2 F1
0 Center 1.0000 0.9704
1 Donut 1.0000 0.9964
2 Edge_Loc 0.9870 0.7456
3 Edge_Ring 0.9936 0.8773
4 Loc 0.9941 0.8927
5 Near_Full 0.9020 0.8846
6 Scratch 0.9768 0.8430
7 Random 0.9701 0.9886
8 MACRO AVG 0.9779 0.8998

Phase 2 — Multi-Class Classification (All 38 Defect Patterns)

Reuse the best-performing Phase 1 convolutional backbone as a frozen feature extractor and replace the classification head with a 38-class softmax output.

Strategy: (1) select the best Phase 1 model based on macro F1; (2) freeze its convolutional layers and replace the head; (3) train in two stages: head only (lr=1e-3), then fine-tune the last convolutional blocks (lr=1e-5).

Step 9 — Encode All 38 Pattern Labels

Convert each 8-dim multi-hot label vector (one binary flag per base defect type) into a single integer class ID (0–37) for multi-class softmax classification. Compute class weights with scikit-learn to handle the severe class imbalance.

from sklearn.utils.class_weight import compute_class_weight

# Give every 8-character label string an integer class ID; IDs follow the
# lexicographic order of the label strings ("00000000" -> 0, ...).
unique_patterns = sorted(label_mapping)
pattern_to_id = dict(zip(unique_patterns, range(len(unique_patterns))))
id_to_pattern = {class_id: label_mapping[key] for key, class_id in pattern_to_id.items()}
NUM_CLASSES = len(unique_patterns)

print(f"Number of unique defect patterns: {NUM_CLASSES}")
print("\nClass ID → Pattern Name:")
for class_id, pattern_name in id_to_pattern.items():
    print(f"  {class_id:>2}: {pattern_name}")

# Class ID of every wafer: each row of 0/1 flags is rebuilt into its label string.
y_class_all = np.array([pattern_to_id["".join(str(int(flag)) for flag in row)] for row in labels])

# Per-split class IDs, looked up from the split label strings str_train / str_val / str_test.
y_class_train = np.array([pattern_to_id[key] for key in str_train])
y_class_val = np.array([pattern_to_id[key] for key in str_val])
y_class_test = np.array([pattern_to_id[key] for key in str_test])

print(f"\nTrain class IDs shape: {y_class_train.shape}")
print(f"Val class IDs shape:   {y_class_val.shape}")
print(f"Test class IDs shape:  {y_class_test.shape}")

# "balanced" weights are inversely proportional to each class's frequency in the training split.
class_weights_arr = compute_class_weight(class_weight="balanced", classes=np.arange(NUM_CLASSES), y=y_class_train)
class_weights_dict = dict(enumerate(class_weights_arr))

print("\nClass weights (top 5 highest):")
for class_id, weight in sorted(class_weights_dict.items(), key=lambda item: item[1], reverse=True)[:5]:
    print(f"  {id_to_pattern[class_id]:<35} weight={weight:.3f}")
Number of unique defect patterns: 38

Class ID → Pattern Name:
   0: Normal
   1: Random
   2: Scratch
   3: Near_Full
   4: Loc
   5: Loc+Scratch
   6: Edge_Ring
   7: Edge_Ring+Scratch
   8: Edge_Ring+Loc
   9: Edge_Ring+Loc+Scratch
  10: Edge_Loc
  11: Edge_Loc+Scratch
  12: Edge_Loc+Loc
  13: Edge_Loc+Loc+Scratch
  14: Donut
  15: Donut+Scratch
  16: Donut+Loc
  17: Donut+Loc+Scratch
  18: Donut+Edge_Ring
  19: Donut+Edge_Ring+Scratch
  20: Donut+Edge_Ring+Loc
  21: Donut+Edge_Ring+Loc+Scratch
  22: Donut+Edge_Loc
  23: Donut+Edge_Loc+Scratch
  24: Donut+Edge_Loc+Loc
  25: Donut+Edge_Loc+Loc+Scratch
  26: Center
  27: Center+Scratch
  28: Center+Loc
  29: Center+Loc+Scratch
  30: Center+Edge_Ring
  31: Center+Edge_Ring+Scratch
  32: Center+Edge_Ring+Loc
  33: Center+Edge_Ring+Loc+Scratch
  34: Center+Edge_Loc
  35: Center+Edge_Loc+Scratch
  36: Center+Edge_Loc+Loc
  37: Center+Edge_Loc+Loc+Scratch

Train class IDs shape: (26610,)
Val class IDs shape:   (5702,)
Test class IDs shape:  (5703,)

Class weights (top 5 highest):
  Near_Full                           weight=6.733
  Random                              weight=1.156
  Normal                              weight=1.000
  Scratch                             weight=1.000
  Loc                                 weight=1.000

Step 10 — Phase 2 Models: Backbone Reuse

Build Phase 2 models by reusing the convolutional backbones from both Phase 1 models: (1) freeze all convolutional layers taken from the Phase 1 model; (2) replace the Dense(8, sigmoid) head with Dense(128, relu) → Dropout → Dense(38, softmax); (3) build separate datasets with integer class labels for sparse_categorical_crossentropy.

# Phase 2 pipelines reuse the X_train / X_val / X_test image arrays, but each
# target is now a single integer class ID (for sparse categorical cross-entropy).
# Only the training pipeline is shuffled and augmented.
train_ds_p2 = build_dataset(X_train, y_class_train, augment_fn=augment, shuffle=True)
val_ds_p2 = build_dataset(X_val, y_class_val)
test_ds_p2 = build_dataset(X_test, y_class_test)

# Sanity check: inspect a single batch from the training pipeline.
for batch_images, batch_labels in train_ds_p2.take(1):
    print("Phase 2 batch images shape:", batch_images.shape)
    print("Phase 2 batch labels shape:", batch_labels.shape)
    print("Sample labels:", batch_labels[:5].numpy())


def build_phase2_from_backbone(phase1_model, num_classes=NUM_CLASSES, model_name="phase2", *, clone_backbone=True):
    """Reuse a Phase 1 conv backbone, freeze it, and add a new softmax head.

    The feature extractor spans the model input up to the output of the first
    top-level ``GlobalAveragePooling2D`` layer. A new
    ``Dense(128, relu) -> Dropout(0.5) -> Dense(num_classes, softmax)`` head is
    attached on top of it.

    Args:
        phase1_model: Trained Phase 1 Keras model containing a top-level
            ``GlobalAveragePooling2D`` layer.
        num_classes: Size of the softmax output (defaults to ``NUM_CLASSES``).
        model_name: Name of the returned model; the backbone is named
            ``f"{model_name}_backbone"``.
        clone_backbone: If True (default), the backbone is built from a copy
            of ``phase1_model`` with identical weights. Freezing it here and
            fine-tuning it later then leave ``phase1_model`` untouched. Pass
            False to share layers and weights with ``phase1_model`` directly.

    Returns:
        ``(phase2_model, feature_extractor)``. Every layer of
        ``feature_extractor`` has ``trainable = False``.

    Raises:
        ValueError: If ``phase1_model`` has no top-level
            ``GlobalAveragePooling2D`` layer.
    """
    if clone_backbone:
        # Without a copy, the backbone's layers ARE phase1_model's layer objects.
        # The freeze below and any later fine-tuning would then silently modify
        # (and retrain) the Phase 1 model as well.
        source_model = tf.keras.models.clone_model(phase1_model)
        source_model.set_weights(phase1_model.get_weights())
    else:
        source_model = phase1_model

    # First top-level GlobalAveragePooling2D marks the end of the conv backbone.
    gap_layer = next(
        (layer for layer in source_model.layers if isinstance(layer, layers.GlobalAveragePooling2D)),
        None,
    )
    if gap_layer is None:
        msg = "Could not find GlobalAveragePooling2D in Phase 1 model"
        raise ValueError(msg)

    # Feature extractor: model input -> pooled feature vector.
    feature_extractor = Model(
        inputs=source_model.input,
        outputs=gap_layer.output,
        name=f"{model_name}_backbone",
    )

    # Freeze every backbone layer individually.
    for layer in feature_extractor.layers:
        layer.trainable = False

    # New classification head on top of the frozen backbone.
    inputs = layers.Input(shape=source_model.input_shape[1:], name="input_image")
    # training=False runs the backbone in inference mode (relevant for any
    # BatchNormalization / Dropout layers it contains).
    x = feature_extractor(inputs, training=False)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="output_38")(x)

    return Model(inputs, outputs, name=model_name), feature_extractor


# Phase 2A: the Custom CNN backbone under a new 38-class head.
model_p2_cnn, backbone_cnn = build_phase2_from_backbone(model_cnn, model_name="phase2_custom_cnn")
print("Phase 2A — Custom CNN backbone:")
model_p2_cnn.summary()

# Phase 2B: the same treatment for the MobileNetV2 backbone.
model_p2_tl, backbone_tl = build_phase2_from_backbone(model_tl, model_name="phase2_mobilenet")
print("\nPhase 2B — MobileNetV2 backbone:")
model_p2_tl.summary()

# Report how many backbone layers are still trainable (0 expected after freezing).
for name, bb in (("CNN", backbone_cnn), ("MobileNet", backbone_tl)):
    total = len(bb.layers)
    n_trainable = sum(layer.trainable for layer in bb.layers)
    print(f"\n{name} backbone: {n_trainable}/{total} trainable layers")
Phase 2 batch images shape: (32, 52, 52, 3)
Phase 2 batch labels shape: (32,)
Sample labels: [ 6 12 28 27 24]
Phase 2A — Custom CNN backbone:
2026-02-26 02:25:54.004939: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
Model: "phase2_custom_cnn"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_image (InputLayer)        │ (None, 52, 52, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ phase2_custom_cnn_backbone      │ (None, 256)            │       390,336 │
│ (Functional)                    │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_2 (Dense)                 │ (None, 128)            │        32,896 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_2 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output_38 (Dense)               │ (None, 38)             │         4,902 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 428,134 (1.63 MB)
 Trainable params: 37,798 (147.65 KB)
 Non-trainable params: 390,336 (1.49 MB)

Phase 2B — MobileNetV2 backbone:
Model: "phase2_mobilenet"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ input_image (InputLayer)        │ (None, 52, 52, 3)      │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ phase2_mobilenet_backbone       │ (None, 1280)           │     2,257,996 │
│ (Functional)                    │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dense_3 (Dense)                 │ (None, 128)            │       163,968 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_3 (Dropout)             │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ output_38 (Dense)               │ (None, 38)             │         4,902 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 2,426,866 (9.26 MB)
 Trainable params: 168,870 (659.65 KB)
 Non-trainable params: 2,257,996 (8.61 MB)

CNN backbone: 0/18 trainable layers

MobileNet backbone: 0/5 trainable layers

Step 11 — Phase 2 Training (Two-Stage)

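Stage A: train only the new head (backbone frozen) with lr=1e-3 for up to 30 epochs. Stage B: unfreeze the last ~25% of backbone layers (BatchNormalization layers stay frozen) and fine-tune end-to-end with lr=1e-5 for up to 30 epochs. Early stopping may end either stage sooner.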

Loss: sparse_categorical_crossentropy with class weights.

def _iter_leaf_layers(model):
    """Yield the leaf layers of a Keras model in order, descending into nested models.

    Any layer exposing a non-empty ``layers`` attribute (e.g. a nested
    ``Model``/``Sequential`` such as a pretrained application backbone) is
    treated as a container and replaced by its own leaf layers.
    """
    for layer in model.layers:
        if getattr(layer, "layers", None):
            yield from _iter_leaf_layers(layer)
        else:
            yield layer


def train_phase2_two_stage(
    model,
    backbone,
    train_ds,
    val_ds,
    class_weights,
    stage_a_epochs=30,
    stage_b_epochs=30,
    model_name="Phase2",
    unfreeze_fraction=0.25,
):
    """Two-stage Phase 2 training: head-only, then partial fine-tuning.

    Stage A: the backbone is expected to already be frozen by the caller; only
        the new head is trained (Adam, lr=1e-3).
    Stage B: the last ``unfreeze_fraction`` of the backbone's leaf layers
        (nested models are flattened first) is unfrozen and fine-tuned together
        with the head (Adam, lr=1e-5). BatchNormalization layers stay frozen.

    Both stages use sparse categorical cross-entropy with ``class_weights``,
    halve the learning rate after 5 epochs without val_loss improvement, and
    stop after 10 such epochs, restoring the best weights.

    Args:
        model: Keras model that contains ``backbone`` followed by the new head.
        backbone: The backbone sub-model inside ``model``.
        train_ds: Training dataset passed to ``model.fit``.
        val_ds: Validation dataset passed to ``model.fit``.
        class_weights: ``{class_id: weight}`` mapping passed as ``class_weight``.
        stage_a_epochs: Maximum epochs for Stage A.
        stage_b_epochs: Maximum epochs for Stage B.
        model_name: Label used in the printed stage banners.
        unfreeze_fraction: Fraction (0-1) of the backbone's leaf layers, counted
            from the end, to unfreeze in Stage B.

    Returns:
        ``(history_a, history_b)``: the ``History`` objects of both ``fit`` calls.

    Raises:
        ValueError: If ``unfreeze_fraction`` is outside ``[0, 1]``.
    """
    # Validate before spending time on Stage A.
    if not 0.0 <= unfreeze_fraction <= 1.0:
        raise ValueError(f"unfreeze_fraction must be in [0, 1], got {unfreeze_fraction!r}")

    separator = "=" * 60

    # ── Stage A: Head-only training ──
    print(f"\n{separator}")
    print(f"  {model_name} — Stage A: Head-only training")
    print(separator)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    callbacks_a = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-6, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True, verbose=1),
    ]

    history_a = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=stage_a_epochs,
        callbacks=callbacks_a,
        class_weight=class_weights,
        verbose=1,
    )

    # ── Stage B: Partial fine-tuning ──
    print(f"\n{separator}")
    print(f"  {model_name} — Stage B: Fine-tuning last {unfreeze_fraction:.0%} of backbone layers")
    print(separator)

    # Setting a container's `trainable` flag propagates to all of its
    # sub-layers, so first mark the backbone trainable (a frozen parent would
    # hide its children's weights), then assign every leaf layer's flag
    # explicitly. Flattening nested models means a pretrained sub-model
    # (e.g. MobileNetV2) is unfrozen layer by layer rather than all at once,
    # and its BatchNormalization layers stay frozen (frozen BN layers run in
    # inference mode, keeping their pretrained statistics).
    backbone.trainable = True
    leaf_layers = list(_iter_leaf_layers(backbone))
    unfreeze_from = int(len(leaf_layers) * (1 - unfreeze_fraction))
    for idx, layer in enumerate(leaf_layers):
        layer.trainable = idx >= unfreeze_from and not isinstance(layer, layers.BatchNormalization)

    trainable_count = sum(1 for layer in leaf_layers if layer.trainable)
    print(f"  Unfroze layers from index {unfreeze_from}: {trainable_count}/{len(leaf_layers)} trainable")

    # Recompile so the new trainable flags take effect, with a much lower LR.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    callbacks_b = [
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5, min_lr=1e-7, verbose=1),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True, verbose=1),
    ]

    history_b = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=stage_b_epochs,
        callbacks=callbacks_b,
        class_weight=class_weights,
        verbose=1,
    )

    return history_a, history_b

Train Phase 2A — Custom CNN Backbone

# Two-stage training of the Phase 2 head on top of the custom CNN backbone
hist_p2_cnn_a, hist_p2_cnn_b = train_phase2_two_stage(
    model=model_p2_cnn,
    backbone=backbone_cnn,
    train_ds=train_ds_p2,
    val_ds=val_ds_p2,
    class_weights=class_weights_dict,
    model_name="Phase2 Custom CNN",
)
============================================================

  Phase2 Custom CNN — Stage A: Head-only training

============================================================

Epoch 1/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 9s 10ms/step - accuracy: 0.6489 - loss: 1.1442 - val_accuracy: 0.9318 - val_loss: 0.2510 - learning_rate: 0.0010

Epoch 2/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.8913 - loss: 0.3612 - val_accuracy: 0.9558 - val_loss: 0.1583 - learning_rate: 0.0010

Epoch 3/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.9196 - loss: 0.2634 - val_accuracy: 0.9611 - val_loss: 0.1393 - learning_rate: 0.0010

Epoch 4/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9318 - loss: 0.2301 - val_accuracy: 0.9616 - val_loss: 0.1260 - learning_rate: 0.0010

Epoch 5/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9386 - loss: 0.2087 - val_accuracy: 0.9637 - val_loss: 0.1261 - learning_rate: 0.0010

Epoch 6/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9410 - loss: 0.1934 - val_accuracy: 0.9633 - val_loss: 0.1255 - learning_rate: 0.0010

Epoch 7/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.9447 - loss: 0.1866 - val_accuracy: 0.9655 - val_loss: 0.1242 - learning_rate: 0.0010

Epoch 8/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9463 - loss: 0.1783 - val_accuracy: 0.9625 - val_loss: 0.1398 - learning_rate: 0.0010

Epoch 9/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9469 - loss: 0.1746 - val_accuracy: 0.9647 - val_loss: 0.1267 - learning_rate: 0.0010

Epoch 10/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9483 - loss: 0.1687 - val_accuracy: 0.9653 - val_loss: 0.1292 - learning_rate: 0.0010

Epoch 11/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9548 - loss: 0.1603 - val_accuracy: 0.9676 - val_loss: 0.1259 - learning_rate: 0.0010

Epoch 12/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9528 - loss: 0.1614 - val_accuracy: 0.9670 - val_loss: 0.1197 - learning_rate: 0.0010

Epoch 13/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.9540 - loss: 0.1613 - val_accuracy: 0.9686 - val_loss: 0.1325 - learning_rate: 0.0010

Epoch 14/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9528 - loss: 0.1564 - val_accuracy: 0.9660 - val_loss: 0.1319 - learning_rate: 0.0010

Epoch 15/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9525 - loss: 0.1565 - val_accuracy: 0.9670 - val_loss: 0.1335 - learning_rate: 0.0010

Epoch 16/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9550 - loss: 0.1549 - val_accuracy: 0.9656 - val_loss: 0.1385 - learning_rate: 0.0010

Epoch 17/30

830/832 ━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9524 - loss: 0.1560

Epoch 17: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9542 - loss: 0.1498 - val_accuracy: 0.9662 - val_loss: 0.1284 - learning_rate: 0.0010

Epoch 18/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9596 - loss: 0.1377 - val_accuracy: 0.9700 - val_loss: 0.1210 - learning_rate: 5.0000e-04

Epoch 19/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 8s 9ms/step - accuracy: 0.9615 - loss: 0.1330 - val_accuracy: 0.9681 - val_loss: 0.1303 - learning_rate: 5.0000e-04

Epoch 20/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9625 - loss: 0.1325 - val_accuracy: 0.9693 - val_loss: 0.1244 - learning_rate: 5.0000e-04

Epoch 21/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9626 - loss: 0.1310 - val_accuracy: 0.9698 - val_loss: 0.1218 - learning_rate: 5.0000e-04

Epoch 22/30

829/832 ━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9616 - loss: 0.1321

Epoch 22: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.

832/832 ━━━━━━━━━━━━━━━━━━━━ 7s 9ms/step - accuracy: 0.9619 - loss: 0.1341 - val_accuracy: 0.9704 - val_loss: 0.1225 - learning_rate: 5.0000e-04

Epoch 22: early stopping

Restoring model weights from the end of the best epoch: 12.



============================================================



============================================================

============================================================

============================================================

  Unfroze layers from index 13: 4/18 trainable

Epoch 1/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 11s 12ms/step - accuracy: 0.9565 - loss: 0.1454 - val_accuracy: 0.9683 - val_loss: 0.1171 - learning_rate: 1.0000e-05

Epoch 2/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9580 - loss: 0.1353 - val_accuracy: 0.9695 - val_loss: 0.1138 - learning_rate: 1.0000e-05

Epoch 3/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9598 - loss: 0.1360 - val_accuracy: 0.9705 - val_loss: 0.1133 - learning_rate: 1.0000e-05

Epoch 4/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9607 - loss: 0.1339 - val_accuracy: 0.9702 - val_loss: 0.1139 - learning_rate: 1.0000e-05

Epoch 5/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9624 - loss: 0.1353 - val_accuracy: 0.9705 - val_loss: 0.1116 - learning_rate: 1.0000e-05

Epoch 6/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9635 - loss: 0.1271 - val_accuracy: 0.9700 - val_loss: 0.1127 - learning_rate: 1.0000e-05

Epoch 7/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9643 - loss: 0.1265 - val_accuracy: 0.9712 - val_loss: 0.1112 - learning_rate: 1.0000e-05

Epoch 8/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9639 - loss: 0.1257 - val_accuracy: 0.9723 - val_loss: 0.1095 - learning_rate: 1.0000e-05

Epoch 9/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9643 - loss: 0.1276 - val_accuracy: 0.9721 - val_loss: 0.1101 - learning_rate: 1.0000e-05

Epoch 10/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9622 - loss: 0.1294 - val_accuracy: 0.9721 - val_loss: 0.1098 - learning_rate: 1.0000e-05

Epoch 11/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9646 - loss: 0.1252 - val_accuracy: 0.9721 - val_loss: 0.1097 - learning_rate: 1.0000e-05

Epoch 12/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9629 - loss: 0.1274 - val_accuracy: 0.9716 - val_loss: 0.1080 - learning_rate: 1.0000e-05

Epoch 13/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9658 - loss: 0.1192 - val_accuracy: 0.9725 - val_loss: 0.1093 - learning_rate: 1.0000e-05

Epoch 14/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9658 - loss: 0.1251 - val_accuracy: 0.9719 - val_loss: 0.1097 - learning_rate: 1.0000e-05

Epoch 15/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9652 - loss: 0.1249 - val_accuracy: 0.9719 - val_loss: 0.1089 - learning_rate: 1.0000e-05

Epoch 16/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9649 - loss: 0.1217 - val_accuracy: 0.9719 - val_loss: 0.1083 - learning_rate: 1.0000e-05

Epoch 17/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9653 - loss: 0.1246

Epoch 17: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-06.

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9646 - loss: 0.1243 - val_accuracy: 0.9716 - val_loss: 0.1091 - learning_rate: 1.0000e-05

Epoch 18/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9651 - loss: 0.1230 - val_accuracy: 0.9721 - val_loss: 0.1083 - learning_rate: 5.0000e-06

Epoch 19/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9678 - loss: 0.1214 - val_accuracy: 0.9721 - val_loss: 0.1080 - learning_rate: 5.0000e-06

Epoch 20/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9672 - loss: 0.1194 - val_accuracy: 0.9723 - val_loss: 0.1085 - learning_rate: 5.0000e-06

Epoch 21/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 11ms/step - accuracy: 0.9650 - loss: 0.1216 - val_accuracy: 0.9723 - val_loss: 0.1084 - learning_rate: 5.0000e-06

Epoch 22/30

827/832 ━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - accuracy: 0.9652 - loss: 0.1215

Epoch 22: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-06.

832/832 ━━━━━━━━━━━━━━━━━━━━ 10s 12ms/step - accuracy: 0.9662 - loss: 0.1177 - val_accuracy: 0.9726 - val_loss: 0.1082 - learning_rate: 5.0000e-06

Epoch 22: early stopping

Restoring model weights from the end of the best epoch: 12.

Train Phase 2B — MobileNetV2 Backbone

# Two-stage training of the Phase 2 head on top of the MobileNetV2 backbone
hist_p2_tl_a, hist_p2_tl_b = train_phase2_two_stage(
    model=model_p2_tl,
    backbone=backbone_tl,
    train_ds=train_ds_p2,
    val_ds=val_ds_p2,
    class_weights=class_weights_dict,
    model_name="Phase2 MobileNetV2",
)
============================================================

  Phase2 MobileNetV2 — Stage A: Head-only training

============================================================

Epoch 1/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 99s 115ms/step - accuracy: 0.2201 - loss: 2.4609 - val_accuracy: 0.4660 - val_loss: 1.6635 - learning_rate: 0.0010

Epoch 2/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.3773 - loss: 1.7501 - val_accuracy: 0.5558 - val_loss: 1.3109 - learning_rate: 0.0010

Epoch 3/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.4292 - loss: 1.5843 - val_accuracy: 0.5680 - val_loss: 1.2359 - learning_rate: 0.0010

Epoch 4/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.4573 - loss: 1.4787 - val_accuracy: 0.5896 - val_loss: 1.1411 - learning_rate: 0.0010

Epoch 5/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.4760 - loss: 1.4259 - val_accuracy: 0.6001 - val_loss: 1.1324 - learning_rate: 0.0010

Epoch 6/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.4892 - loss: 1.3843 - val_accuracy: 0.6093 - val_loss: 1.0835 - learning_rate: 0.0010

Epoch 7/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.4985 - loss: 1.3570 - val_accuracy: 0.6114 - val_loss: 1.0606 - learning_rate: 0.0010

Epoch 8/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 115ms/step - accuracy: 0.5113 - loss: 1.3226 - val_accuracy: 0.6142 - val_loss: 1.0373 - learning_rate: 0.0010

Epoch 9/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5134 - loss: 1.3194 - val_accuracy: 0.6363 - val_loss: 1.0381 - learning_rate: 0.0010

Epoch 10/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5212 - loss: 1.2836 - val_accuracy: 0.6506 - val_loss: 0.9742 - learning_rate: 0.0010

Epoch 11/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5246 - loss: 1.2738 - val_accuracy: 0.6312 - val_loss: 1.0046 - learning_rate: 0.0010

Epoch 12/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 114ms/step - accuracy: 0.5305 - loss: 1.2639 - val_accuracy: 0.6478 - val_loss: 0.9646 - learning_rate: 0.0010

Epoch 13/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5356 - loss: 1.2431 - val_accuracy: 0.6452 - val_loss: 0.9677 - learning_rate: 0.0010

Epoch 14/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5372 - loss: 1.2391 - val_accuracy: 0.6512 - val_loss: 0.9617 - learning_rate: 0.0010

Epoch 15/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5381 - loss: 1.2335 - val_accuracy: 0.6445 - val_loss: 0.9929 - learning_rate: 0.0010

Epoch 16/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5478 - loss: 1.2089 - val_accuracy: 0.6626 - val_loss: 0.9524 - learning_rate: 0.0010

Epoch 17/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5483 - loss: 1.2176 - val_accuracy: 0.6640 - val_loss: 0.9187 - learning_rate: 0.0010

Epoch 18/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5484 - loss: 1.2098 - val_accuracy: 0.6661 - val_loss: 0.9082 - learning_rate: 0.0010

Epoch 19/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5531 - loss: 1.2042 - val_accuracy: 0.6575 - val_loss: 0.9296 - learning_rate: 0.0010

Epoch 20/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5555 - loss: 1.1974 - val_accuracy: 0.6596 - val_loss: 0.9247 - learning_rate: 0.0010

Epoch 21/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5581 - loss: 1.1962 - val_accuracy: 0.6410 - val_loss: 0.9619 - learning_rate: 0.0010

Epoch 22/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5613 - loss: 1.1805 - val_accuracy: 0.6663 - val_loss: 0.9199 - learning_rate: 0.0010

Epoch 23/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 0s 93ms/step - accuracy: 0.5605 - loss: 1.1922

Epoch 23: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5535 - loss: 1.1934 - val_accuracy: 0.6578 - val_loss: 0.9575 - learning_rate: 0.0010

Epoch 24/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5810 - loss: 1.1299 - val_accuracy: 0.6791 - val_loss: 0.8745 - learning_rate: 5.0000e-04

Epoch 25/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5842 - loss: 1.1182 - val_accuracy: 0.6864 - val_loss: 0.8703 - learning_rate: 5.0000e-04

Epoch 26/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5908 - loss: 1.1094 - val_accuracy: 0.6926 - val_loss: 0.8568 - learning_rate: 5.0000e-04

Epoch 27/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 95s 114ms/step - accuracy: 0.5904 - loss: 1.1006 - val_accuracy: 0.6878 - val_loss: 0.8455 - learning_rate: 5.0000e-04

Epoch 28/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5914 - loss: 1.1067 - val_accuracy: 0.6848 - val_loss: 0.8505 - learning_rate: 5.0000e-04

Epoch 29/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5959 - loss: 1.1009 - val_accuracy: 0.6787 - val_loss: 0.8737 - learning_rate: 5.0000e-04

Epoch 30/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 94s 113ms/step - accuracy: 0.5910 - loss: 1.0997 - val_accuracy: 0.6829 - val_loss: 0.8633 - learning_rate: 5.0000e-04

Restoring model weights from the end of the best epoch: 27.



============================================================



============================================================

============================================================

============================================================

  Unfroze layers from index 3: 2/5 trainable

Epoch 1/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 530s 619ms/step - accuracy: 0.2226 - loss: 8.0492 - val_accuracy: 0.0659 - val_loss: 33.8210 - learning_rate: 1.0000e-05

Epoch 2/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 514s 618ms/step - accuracy: 0.3278 - loss: 1.9544 - val_accuracy: 0.0680 - val_loss: 23.8821 - learning_rate: 1.0000e-05

Epoch 3/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 521s 626ms/step - accuracy: 0.4084 - loss: 1.6009 - val_accuracy: 0.0944 - val_loss: 15.0651 - learning_rate: 1.0000e-05

Epoch 4/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.4989 - loss: 1.3403 - val_accuracy: 0.4863 - val_loss: 1.5419 - learning_rate: 1.0000e-05

Epoch 5/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 517s 621ms/step - accuracy: 0.5785 - loss: 1.1375 - val_accuracy: 0.6794 - val_loss: 0.8602 - learning_rate: 1.0000e-05

Epoch 6/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.6296 - loss: 0.9978 - val_accuracy: 0.7238 - val_loss: 0.7043 - learning_rate: 1.0000e-05

Epoch 7/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 512s 615ms/step - accuracy: 0.6718 - loss: 0.8882 - val_accuracy: 0.7811 - val_loss: 0.5720 - learning_rate: 1.0000e-05

Epoch 8/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.7015 - loss: 0.8086 - val_accuracy: 0.8057 - val_loss: 0.5145 - learning_rate: 1.0000e-05

Epoch 9/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 618ms/step - accuracy: 0.7303 - loss: 0.7397 - val_accuracy: 0.8390 - val_loss: 0.4451 - learning_rate: 1.0000e-05

Epoch 10/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.7483 - loss: 0.6877 - val_accuracy: 0.8576 - val_loss: 0.3895 - learning_rate: 1.0000e-05

Epoch 11/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 512s 616ms/step - accuracy: 0.7668 - loss: 0.6391 - val_accuracy: 0.8788 - val_loss: 0.3511 - learning_rate: 1.0000e-05

Epoch 12/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 511s 614ms/step - accuracy: 0.7868 - loss: 0.5932 - val_accuracy: 0.9046 - val_loss: 0.2964 - learning_rate: 1.0000e-05

Epoch 13/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 616ms/step - accuracy: 0.8002 - loss: 0.5562 - val_accuracy: 0.9060 - val_loss: 0.2848 - learning_rate: 1.0000e-05

Epoch 14/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 517s 621ms/step - accuracy: 0.8107 - loss: 0.5324 - val_accuracy: 0.9158 - val_loss: 0.2634 - learning_rate: 1.0000e-05

Epoch 15/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 514s 618ms/step - accuracy: 0.8114 - loss: 0.5225 - val_accuracy: 0.9265 - val_loss: 0.2450 - learning_rate: 1.0000e-05

Epoch 16/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 514s 618ms/step - accuracy: 0.8254 - loss: 0.4827 - val_accuracy: 0.9342 - val_loss: 0.2189 - learning_rate: 1.0000e-05

Epoch 17/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 619ms/step - accuracy: 0.8370 - loss: 0.4582 - val_accuracy: 0.9398 - val_loss: 0.2004 - learning_rate: 1.0000e-05

Epoch 18/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 619ms/step - accuracy: 0.8403 - loss: 0.4523 - val_accuracy: 0.9460 - val_loss: 0.1828 - learning_rate: 1.0000e-05

Epoch 19/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.8419 - loss: 0.4457 - val_accuracy: 0.9511 - val_loss: 0.1792 - learning_rate: 1.0000e-05

Epoch 20/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 616ms/step - accuracy: 0.8500 - loss: 0.4260 - val_accuracy: 0.9519 - val_loss: 0.1700 - learning_rate: 1.0000e-05

Epoch 21/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 619ms/step - accuracy: 0.8554 - loss: 0.4157 - val_accuracy: 0.9498 - val_loss: 0.1735 - learning_rate: 1.0000e-05

Epoch 22/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 617ms/step - accuracy: 0.8610 - loss: 0.3910 - val_accuracy: 0.9516 - val_loss: 0.1638 - learning_rate: 1.0000e-05

Epoch 23/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 511s 614ms/step - accuracy: 0.8627 - loss: 0.3956 - val_accuracy: 0.9569 - val_loss: 0.1535 - learning_rate: 1.0000e-05

Epoch 24/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 516s 620ms/step - accuracy: 0.8655 - loss: 0.3892 - val_accuracy: 0.9532 - val_loss: 0.1572 - learning_rate: 1.0000e-05

Epoch 25/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 619ms/step - accuracy: 0.8717 - loss: 0.3680 - val_accuracy: 0.9579 - val_loss: 0.1503 - learning_rate: 1.0000e-05

Epoch 26/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 616ms/step - accuracy: 0.8763 - loss: 0.3616 - val_accuracy: 0.9591 - val_loss: 0.1459 - learning_rate: 1.0000e-05

Epoch 27/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 512s 616ms/step - accuracy: 0.8776 - loss: 0.3544 - val_accuracy: 0.9586 - val_loss: 0.1474 - learning_rate: 1.0000e-05

Epoch 28/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 514s 617ms/step - accuracy: 0.8826 - loss: 0.3386 - val_accuracy: 0.9609 - val_loss: 0.1369 - learning_rate: 1.0000e-05

Epoch 29/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 513s 616ms/step - accuracy: 0.8876 - loss: 0.3253 - val_accuracy: 0.9628 - val_loss: 0.1325 - learning_rate: 1.0000e-05

Epoch 30/30

832/832 ━━━━━━━━━━━━━━━━━━━━ 515s 619ms/step - accuracy: 0.8873 - loss: 0.3217 - val_accuracy: 0.9632 - val_loss: 0.1285 - learning_rate: 1.0000e-05

Restoring model weights from the end of the best epoch: 30.

Save models

Step 12 — Phase 2 Evaluation & Final Comparison

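Evaluate both Phase 2 models on the test set with per-class F1 scores, also averaged by defect count (single/normal, double, triple, quadruple). Plot confusion matrices for the better of the two models (overall and per defect-count group), then produce a final summary table comparing all four models (two from Phase 1, two from Phase 2).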

from sklearn.metrics import confusion_matrix

# All 38 pattern names in class ID order
all_pattern_names = [id_to_pattern[i] for i in range(NUM_CLASSES)]


def evaluate_phase2(model, dataset, y_true, model_name):
    """Evaluate a Phase 2 multi-class model and print per-class and grouped metrics.

    Prints the per-class classification report, overall accuracy, macro and
    weighted F1, and the mean per-class F1 for each defect-count group (classes
    whose names contain 0, 1, 2 or 3 "+" separators).

    Args:
        model: Trained Keras classifier; ``model.predict`` must return one row
            of class scores per sample.
        dataset: Input pipeline passed to ``model.predict``. It must yield the
            samples in the same order as ``y_true`` (NOTE(review): confirm the
            dataset passed in is not shuffled).
        y_true: Integer class IDs, indexable into ``all_pattern_names``.
        model_name: Label used in the printed header.

    Returns:
        ``(y_pred, report, macro_f1, weighted_f1)``, where ``report`` is the
        ``classification_report`` output as a dict keyed by class name.

    Raises:
        ValueError: If the number of predictions differs from ``len(y_true)``.
    """
    y_true = np.asarray(y_true)
    y_pred = model.predict(dataset, verbose=0).argmax(axis=1)
    if y_pred.shape[0] != y_true.shape[0]:
        raise ValueError(
            f"{model_name}: got {y_pred.shape[0]} predictions for {y_true.shape[0]} labels; "
            "the dataset must yield exactly one sample per entry of y_true."
        )

    # Pin the label set to every class ID so `target_names` always lines up,
    # even when some class is absent from both y_true and y_pred.
    class_ids = np.arange(len(all_pattern_names))
    report_kwargs = {"labels": class_ids, "target_names": all_pattern_names, "zero_division": 0}

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"  {model_name} — Test Set Results ({len(all_pattern_names)} classes)")
    print(separator)

    report = classification_report(y_true, y_pred, output_dict=True, **report_kwargs)
    print(classification_report(y_true, y_pred, **report_kwargs))

    # Same label set as the report, so these agree with its macro/weighted rows.
    macro_f1 = f1_score(y_true, y_pred, labels=class_ids, average="macro", zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, labels=class_ids, average="weighted", zero_division=0)
    accuracy = np.mean(y_pred == y_true)
    print(f"Accuracy:    {accuracy:.4f}")
    print(f"Macro F1:    {macro_f1:.4f}")
    print(f"Weighted F1: {weighted_f1:.4f}")

    print("\nF1 by defect combination count:")
    # A class name with k "+" separators combines k + 1 defect patterns;
    # "Normal" and the single patterns have none.
    for n_plus, group_name in enumerate(["Single/Normal", "Double", "Triple", "Quadruple"]):
        group_f1s = [report[name]["f1-score"] for name in all_pattern_names if name.count("+") == n_plus]
        if group_f1s:
            print(f"  {group_name:<13} avg F1={np.mean(group_f1s):.4f}  (n={len(group_f1s)} classes)")

    return y_pred, report, macro_f1, weighted_f1


# Score both Phase 2 models on the same test pipeline and ground-truth class IDs
y_pred_p2_cnn, report_p2_cnn, mf1_p2_cnn, wf1_p2_cnn = evaluate_phase2(
    model=model_p2_cnn,
    dataset=test_ds_p2,
    y_true=y_class_test,
    model_name="Phase2 Custom CNN",
)
y_pred_p2_tl, report_p2_tl, mf1_p2_tl, wf1_p2_tl = evaluate_phase2(
    model=model_p2_tl,
    dataset=test_ds_p2,
    y_true=y_class_test,
    model_name="Phase2 MobileNetV2",
)

============================================================
  Phase2 Custom CNN — Test Set Results (38 classes)
============================================================
                              precision    recall  f1-score   support

                      Normal       1.00      1.00      1.00       150
                      Random       1.00      0.99      1.00       130
                     Scratch       0.99      1.00      0.99       150
                   Near_Full       0.96      1.00      0.98        23
                         Loc       0.96      1.00      0.98       150
                 Loc+Scratch       1.00      0.95      0.98       150
                   Edge_Ring       0.98      0.99      0.98       150
           Edge_Ring+Scratch       0.99      0.97      0.98       150
               Edge_Ring+Loc       0.92      0.97      0.94       150
       Edge_Ring+Loc+Scratch       0.97      0.91      0.94       150
                    Edge_Loc       0.99      0.98      0.98       150
            Edge_Loc+Scratch       0.97      0.99      0.98       150
                Edge_Loc+Loc       0.96      0.99      0.97       150
        Edge_Loc+Loc+Scratch       0.99      0.95      0.97       150
                       Donut       0.95      0.99      0.97       150
               Donut+Scratch       0.97      0.98      0.98       150
                   Donut+Loc       0.94      0.95      0.95       150
           Donut+Loc+Scratch       0.97      0.94      0.96       150
             Donut+Edge_Ring       0.97      0.98      0.98       150
     Donut+Edge_Ring+Scratch       0.95      0.97      0.96       150
         Donut+Edge_Ring+Loc       0.98      0.99      0.99       150
 Donut+Edge_Ring+Loc+Scratch       0.99      0.97      0.98       150
              Donut+Edge_Loc       0.98      0.93      0.96       150
      Donut+Edge_Loc+Scratch       0.94      0.95      0.95       150
          Donut+Edge_Loc+Loc       1.00      0.96      0.98       150
  Donut+Edge_Loc+Loc+Scratch       0.97      0.99      0.98       150
                      Center       1.00      0.99      1.00       150
              Center+Scratch       0.96      1.00      0.98       150
                  Center+Loc       0.97      0.99      0.98       150
          Center+Loc+Scratch       0.99      0.97      0.98       150
            Center+Edge_Ring       0.98      0.99      0.99       150
    Center+Edge_Ring+Scratch       0.97      0.97      0.97       150
        Center+Edge_Ring+Loc       0.94      0.98      0.96       150
Center+Edge_Ring+Loc+Scratch       0.97      0.95      0.96       150
             Center+Edge_Loc       0.99      0.99      0.99       150
     Center+Edge_Loc+Scratch       1.00      0.96      0.98       300
         Center+Edge_Loc+Loc       0.95      0.98      0.96       150
 Center+Edge_Loc+Loc+Scratch       0.99      0.95      0.97       150

                    accuracy                           0.97      5703
                   macro avg       0.97      0.97      0.97      5703
                weighted avg       0.97      0.97      0.97      5703

Accuracy:    0.9735
Macro F1:    0.9736
Weighted F1: 0.9735

F1 by defect combination count:
  Single/Normal avg F1=0.9866  (n=9 classes)
  Double       avg F1=0.9725  (n=13 classes)
  Triple       avg F1=0.9655  (n=12 classes)
  Quadruple    avg F1=0.9722  (n=4 classes)

============================================================
  Phase2 MobileNetV2 — Test Set Results (38 classes)
============================================================
                              precision    recall  f1-score   support

                      Normal       1.00      1.00      1.00       150
                      Random       1.00      0.99      1.00       130
                     Scratch       0.95      1.00      0.97       150
                   Near_Full       0.96      1.00      0.98        23
                         Loc       0.96      0.99      0.97       150
                 Loc+Scratch       0.97      0.95      0.96       150
                   Edge_Ring       0.96      1.00      0.98       150
           Edge_Ring+Scratch       0.94      1.00      0.97       150
               Edge_Ring+Loc       0.92      0.97      0.95       150
       Edge_Ring+Loc+Scratch       0.94      0.90      0.92       150
                    Edge_Loc       0.99      0.97      0.98       150
            Edge_Loc+Scratch       1.00      0.92      0.96       150
                Edge_Loc+Loc       0.97      0.99      0.98       150
        Edge_Loc+Loc+Scratch       1.00      0.89      0.94       150
                       Donut       0.96      1.00      0.98       150
               Donut+Scratch       0.96      0.99      0.98       150
                   Donut+Loc       0.97      0.98      0.97       150
           Donut+Loc+Scratch       0.95      0.97      0.96       150
             Donut+Edge_Ring       0.96      1.00      0.98       150
     Donut+Edge_Ring+Scratch       0.91      0.99      0.95       150
         Donut+Edge_Ring+Loc       0.96      1.00      0.98       150
 Donut+Edge_Ring+Loc+Scratch       0.97      0.96      0.96       150
              Donut+Edge_Loc       0.97      0.95      0.96       150
      Donut+Edge_Loc+Scratch       0.99      0.87      0.93       150
          Donut+Edge_Loc+Loc       1.00      0.96      0.98       150
  Donut+Edge_Loc+Loc+Scratch       0.99      0.91      0.95       150
                      Center       1.00      0.99      1.00       150
              Center+Scratch       0.93      1.00      0.96       150
                  Center+Loc       0.96      1.00      0.98       150
          Center+Loc+Scratch       0.98      0.97      0.98       150
            Center+Edge_Ring       0.96      1.00      0.98       150
    Center+Edge_Ring+Scratch       0.89      0.99      0.94       150
        Center+Edge_Ring+Loc       0.93      0.99      0.96       150
Center+Edge_Ring+Loc+Scratch       0.98      0.97      0.98       150
             Center+Edge_Loc       0.99      0.96      0.97       150
     Center+Edge_Loc+Scratch       0.99      0.91      0.95       300
         Center+Edge_Loc+Loc       0.97      0.92      0.95       150
 Center+Edge_Loc+Loc+Scratch       1.00      0.92      0.96       150

                    accuracy                           0.97      5703
                   macro avg       0.97      0.97      0.97      5703
                weighted avg       0.97      0.97      0.97      5703

Accuracy:    0.9660
Macro F1:    0.9667
Weighted F1: 0.9658

F1 by defect combination count:
  Single/Normal avg F1=0.9841  (n=9 classes)
  Double       avg F1=0.9690  (n=13 classes)
  Triple       avg F1=0.9527  (n=12 classes)
  Quadruple    avg F1=0.9624  (n=4 classes)
# Confusion matrix for the best Phase 2 model (chosen by test macro F1; ties go to the Custom CNN)
best_p2_name = "Custom CNN" if mf1_p2_cnn >= mf1_p2_tl else "MobileNetV2"
best_p2_pred = y_pred_p2_cnn if mf1_p2_cnn >= mf1_p2_tl else y_pred_p2_tl
print(f"Plotting confusion matrix for best Phase 2 model: {best_p2_name}")

cm = confusion_matrix(y_class_test, best_p2_pred, labels=np.arange(NUM_CLASSES))

# Convert to percentages of each true class (normalize by row). A class with no
# test samples gets an all-zero row instead of NaNs from a 0/0 division.
row_sums = cm.sum(axis=1, keepdims=True)
cm_percent = np.divide(cm, row_sums, out=np.zeros_like(cm, dtype=float), where=row_sums != 0) * 100

fig, ax = plt.subplots(figsize=(20, 18))
sns.heatmap(
    cm_percent,
    annot=True,
    fmt=".1f",
    cmap="Blues",
    xticklabels=all_pattern_names,
    yticklabels=all_pattern_names,
    ax=ax,
    linewidths=0.5,
    cbar_kws={"label": "Percentage (%)"},
)
ax.set_xlabel("Predicted", fontsize=12)
ax.set_ylabel("True", fontsize=12)
ax.set_title(f"Phase 2 — {best_p2_name} Confusion Matrix (38 classes, %)", fontsize=14)
plt.xticks(rotation=90, fontsize=7)
plt.yticks(rotation=0, fontsize=7)
plt.tight_layout()
plt.show()

# Split the best Phase 2 model's confusion matrix by how many defects each class combines
if mf1_p2_cnn >= mf1_p2_tl:
    best_p2_name, best_p2_pred = "Custom CNN", y_pred_p2_cnn
else:
    best_p2_name, best_p2_pred = "MobileNetV2", y_pred_p2_tl
print(f"Plotting split confusion matrices for best Phase 2 model: {best_p2_name}")

# Group title -> number of "+" separators in the names of the classes in that group
group_defect_counts = dict(
    zip(
        [
            "Base patterns (single/normal)",
            "Two-defect combinations",
            "Three-defect combinations",
            "Four-defect combinations",
        ],
        range(4),
    )
)


def plot_group_confusion_matrix(group_name, plus_count, *, y_true=None, y_pred=None, model_name=None):
    """Plot a row-normalised confusion matrix for one defect-combination group.

    A class belongs to the group when its pattern name in ``id_to_pattern``
    contains exactly ``plus_count`` '+' separators. Only samples whose TRUE
    class is in the group are used.

    Percentages are taken against each class's true support. Samples predicted
    as a class outside the group would be silently dropped by
    ``confusion_matrix(..., labels=group_ids)``. Here they are counted in an
    extra "Other (out-of-group)" column, so every row with support sums to 100%
    and the diagonal is the true per-class recall.

    Args:
        group_name: Human-readable group label used in the plot title.
        plus_count: Number of '+' characters that selects the group's classes.
        y_true: True class ids. Defaults to the global ``y_class_test``.
        y_pred: Predicted class ids aligned with ``y_true``. Defaults to the
            global ``best_p2_pred``.
        model_name: Model name for the title. Defaults to the global
            ``best_p2_name``.

    Returns:
        None. Draws and shows a figure. If no class matches ``plus_count``, a
        message is printed and nothing is plotted.
    """
    # Fall back to the notebook globals so the existing positional call still works.
    y_true = np.asarray(y_class_test if y_true is None else y_true)
    y_pred = np.asarray(best_p2_pred if y_pred is None else y_pred)
    model_name = best_p2_name if model_name is None else model_name

    group_ids = [cid for cid, name in id_to_pattern.items() if name.count("+") == plus_count]
    if not group_ids:
        # sklearn's confusion_matrix rejects an empty ``labels`` list, so stop here.
        print(f"No classes with {plus_count} '+' separator(s) for '{group_name}'; skipping.")
        return
    group_names = [id_to_pattern[cid] for cid in group_ids]

    # Use only samples whose true class is in this group
    mask = np.isin(y_true, group_ids)
    y_true_group = y_true[mask]
    y_pred_group = y_pred[mask]

    # In-group counts only: predictions not listed in ``labels`` are ignored here.
    cm = confusion_matrix(y_true_group, y_pred_group, labels=group_ids)
    # True support per class, including samples predicted outside the group.
    support = np.array([np.count_nonzero(y_true_group == cid) for cid in group_ids])
    out_of_group = support - cm.sum(axis=1)
    counts = np.column_stack([cm, out_of_group])

    # Normalise by true support. Classes with no samples get an all-zero row, not NaN.
    denom = support[:, np.newaxis]
    cm_percent = np.divide(counts, denom, out=np.zeros(counts.shape, dtype=float), where=denom != 0) * 100

    x_labels = [*group_names, "Other (out-of-group)"]
    fig_w = max(8, 0.6 * len(x_labels))
    fig_h = max(6, 0.6 * len(group_names))
    _, ax = plt.subplots(figsize=(fig_w, fig_h))
    sns.heatmap(
        cm_percent,
        annot=True,
        fmt=".1f",
        cmap="Blues",
        xticklabels=x_labels,
        yticklabels=group_names,
        ax=ax,
        linewidths=0.5,
        cbar_kws={"label": "Percentage (%)"},
    )
    ax.set_xlabel("Predicted", fontsize=11)
    ax.set_ylabel("True", fontsize=11)
    ax.set_title(f"Phase 2 — {model_name} Confusion Matrix ({group_name}, %)", fontsize=13)
    plt.xticks(rotation=90, fontsize=7)
    plt.yticks(rotation=0, fontsize=7)
    plt.tight_layout()
    plt.show()


# Draw one split confusion matrix per defect-combination group, in dict order.
for group_title, n_plus in group_defect_counts.items():
    plot_group_confusion_matrix(group_title, n_plus)
Plotting confusion matrix for best Phase 2 model: Custom CNN

Plotting split confusion matrices for best Phase 2 model: Custom CNN

# ── Final Summary Table: All 4 Models ──
# One (model, macro F1, weighted F1) row per model. Phase 1 scores are read from
# the classification-report dicts; Phase 2 scores come from the mf1_/wf1_ variables.
# Note: Phase 1 is the 8-label task and Phase 2 the 38-class task, so the
# highlighted maxima compare scores across different tasks.
summary_rows = [
    ("Phase 1 — Custom CNN (8-label)", report_cnn["macro avg"]["f1-score"], report_cnn["weighted avg"]["f1-score"]),
    ("Phase 1 — MobileNetV2 (8-label)", report_tl["macro avg"]["f1-score"], report_tl["weighted avg"]["f1-score"]),
    ("Phase 2 — Custom CNN (38-class)", mf1_p2_cnn, wf1_p2_cnn),
    ("Phase 2 — MobileNetV2 (38-class)", mf1_p2_tl, wf1_p2_tl),
]
summary_columns = ("Model", "Macro F1", "Weighted F1")
# Transpose rows into the column -> list mapping that DataFrame expects.
summary_data = {col: [row[idx] for row in summary_rows] for idx, col in enumerate(summary_columns)}

summary_df = pd.DataFrame(summary_data)
f1_columns = ["Macro F1", "Weighted F1"]
summary_styler = (
    summary_df.style
    .format(dict.fromkeys(f1_columns, "{:.4f}"))
    .highlight_max(subset=f1_columns, color="lightgreen")
    .set_caption("Final Model Comparison — All Phases")
)
display(summary_styler)
Table 1: Final Model Comparison — All Phases
  Model Macro F1 Weighted F1
0 Phase 1 — Custom CNN (8-label) 0.9779 0.9905
1 Phase 1 — MobileNetV2 (8-label) 0.8998 0.8847
2 Phase 2 — Custom CNN (38-class) 0.9736 0.9735
3 Phase 2 — MobileNetV2 (38-class) 0.9667 0.9658

Save Phase 2 Models

# Output directory for the Phase 2 (38-class) models; reused if it already exists.
MODELS_DIR = "phase2_models"
os.makedirs(MODELS_DIR, exist_ok=True)

# Write each Phase 2 model to <MODELS_DIR>/<name>.keras (CNN first, then MobileNetV2).
phase2_models = {"model_p2_cnn": model_p2_cnn, "model_p2_tl": model_p2_tl}
for model_stem, keras_model in phase2_models.items():
    keras_model.save(f"{MODELS_DIR}/{model_stem}.keras")

print(f"✓ Phase 2 models saved to '{MODELS_DIR}/':")
for model_stem in phase2_models:
    print(f"  - {model_stem}.keras")
✓ Phase 2 models saved to 'phase2_models/':
  - model_p2_cnn.keras
  - model_p2_tl.keras

Save Training Artifacts

# Directory that receives every trained model plus one pickle of data/config.
ARTIFACT_DIR = "training_artifacts"
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# Persist all four trained models as <ARTIFACT_DIR>/<name>.keras, in this order.
trained_models = {
    "model_cnn": model_cnn,
    "model_tl": model_tl,
    "model_p2_cnn": model_p2_cnn,
    "model_p2_tl": model_p2_tl,
}
for model_stem, keras_model in trained_models.items():
    keras_model.save(f"{ARTIFACT_DIR}/{model_stem}.keras")

# These are plain aliases, NOT copies. They keep the current in-memory models
# reachable if a later cell rebinds model_cnn etc. (e.g. when reloading from disk).
# Any in-place change to a model, such as further training, is still shared with its alias.
model_cnn2, model_tl2, model_p2_cnn2, model_p2_tl2 = model_cnn, model_tl, model_p2_cnn, model_p2_tl

# Everything except the models goes into a single pickle, keeping original Python types.
# NOTE(review): BinaryFocalLoss / focal_loss are stored by pickle as references
# (module + qualified name). If they are defined in this notebook, the same
# definitions must exist before unpickling in a fresh session — confirm.
artifacts = dict(
    X_train=X_train,
    X_val=X_val,
    X_test=X_test,
    y_train=y_train,
    y_val=y_val,
    y_test=y_test,
    y_class_train=y_class_train,
    y_class_val=y_class_val,
    y_class_test=y_class_test,
    label_mapping=label_mapping,
    pattern_to_id=pattern_to_id,
    id_to_pattern=id_to_pattern,
    class_weights_dict=class_weights_dict,
    pos_weights_tensor=pos_weights_tensor,
    history_cnn=history_cnn,
    history_tl=history_tl,
    hist_p2_cnn_a=hist_p2_cnn_a,
    hist_p2_cnn_b=hist_p2_cnn_b,
    hist_p2_tl_a=hist_p2_tl_a,
    hist_p2_tl_b=hist_p2_tl_b,
    BinaryFocalLoss=BinaryFocalLoss,
    focal_loss=focal_loss,
    THRESHOLD=THRESHOLD,
    NUM_CLASSES=NUM_CLASSES,
    BATCH_SIZE=BATCH_SIZE,
    EPOCHS=EPOCHS,
    TARGET_SIZE=TARGET_SIZE,
    AUTOTUNE=AUTOTUNE,
)

with open(f"{ARTIFACT_DIR}/artifacts.pkl", "wb") as f:
    pickle.dump(artifacts, f)

print(f"\n✓ All artifacts saved to '{ARTIFACT_DIR}/':")
for model_stem in trained_models:
    print(f"  - {model_stem}.keras")
print("  - artifacts.pkl")

✓ All artifacts saved to 'training_artifacts/':
  - model_cnn.keras
  - model_tl.keras
  - model_p2_cnn.keras
  - model_p2_tl.keras
  - artifacts.pkl

Load Training Artifacts

ARTIFACT_DIR = "training_artifacts"

# ── Models ──
# Reload the four .keras files written by the artifact-saving cell.
# NOTE(review): if a model was compiled with a custom loss (e.g. BinaryFocalLoss)
# that is not registered as Keras-serializable, load_model may need custom_objects — confirm.
print("Loading models...")
model_cnn, model_tl, model_p2_cnn, model_p2_tl = (
    tf.keras.models.load_model(f"{ARTIFACT_DIR}/{stem}.keras")
    for stem in ("model_cnn", "model_tl", "model_p2_cnn", "model_p2_tl")
)
print("✓ All models loaded")

# ── Data, mappings, histories and config ──
# The pickle is produced locally by this notebook; do not point this at an untrusted file.
print("Loading artifacts...")
with open(f"{ARTIFACT_DIR}/artifacts.pkl", "rb") as f:
    artifacts = pickle.load(f)

# Restore each entry under its original global name.
X_train, X_val, X_test = (artifacts[key] for key in ("X_train", "X_val", "X_test"))
y_train, y_val, y_test = (artifacts[key] for key in ("y_train", "y_val", "y_test"))
y_class_train, y_class_val, y_class_test = (
    artifacts[key] for key in ("y_class_train", "y_class_val", "y_class_test")
)
label_mapping, pattern_to_id, id_to_pattern = (
    artifacts[key] for key in ("label_mapping", "pattern_to_id", "id_to_pattern")
)
class_weights_dict, pos_weights_tensor = (artifacts[key] for key in ("class_weights_dict", "pos_weights_tensor"))
history_cnn, history_tl = (artifacts[key] for key in ("history_cnn", "history_tl"))
hist_p2_cnn_a, hist_p2_cnn_b, hist_p2_tl_a, hist_p2_tl_b = (
    artifacts[key] for key in ("hist_p2_cnn_a", "hist_p2_cnn_b", "hist_p2_tl_a", "hist_p2_tl_b")
)
BinaryFocalLoss, focal_loss = (artifacts[key] for key in ("BinaryFocalLoss", "focal_loss"))
THRESHOLD, NUM_CLASSES, BATCH_SIZE, EPOCHS, TARGET_SIZE, AUTOTUNE = (
    artifacts[key] for key in ("THRESHOLD", "NUM_CLASSES", "BATCH_SIZE", "EPOCHS", "TARGET_SIZE", "AUTOTUNE")
)
print("✓ All artifacts loaded")

# ── Sanity report ──
banner = "=" * 60
print("\n" + banner)
print("Verification:")
print(banner)
print(f"X_train shape: {X_train.shape}")
print(f"y_class_test shape: {y_class_test.shape}")
print(f"NUM_CLASSES: {NUM_CLASSES}")
loaded_model_names = ", ".join(m.name for m in (model_cnn, model_tl, model_p2_cnn, model_p2_tl))
print(f"Models loaded: {loaded_model_names}")

# Smoke test: one batch through the Phase 2 CNN.
# build_dataset is NOT restored from the pickle; it must already be defined in this session.
# NOTE(review): assumes each dataset element is an (inputs, labels) pair — confirm.
test_ds_p2 = build_dataset(X_test, y_class_test)
test_batch = next(iter(test_ds_p2))
batch_inputs = test_batch[0]
pred_cnn_p2 = model_p2_cnn.predict(batch_inputs, verbose=0)
print(f"✓ Model prediction successful (shape: {pred_cnn_p2.shape})")
print("\nAll artifacts loaded successfully! ✓")