Object Detection Circuit Extraction¶

This notebook extracts and visualizes the minimal computational circuit for pot detection in a fine-tuned Faster R-CNN model using activation patching and co-activation analysis.

Approach¶

  • Layer 4 only (2048 channels, focused analysis)
  • Channel-level granularity (interpretable, ~50-200 critical channels)
  • Activation patching (systematic ablation to identify critical channels)
  • Co-activation analysis (find functional relationships between channels)
  • Semantic role assignment (categorize as edge/texture/shape/semantic)
  • Interactive Plotly visualization (zoom, pan, hover to explore dense clusters)

Expected Runtime¶

~5-10 minutes (mainly ablation study testing 2048 channels)

Outputs¶

  1. circuit_diagram.html - Interactive Plotly visualization (zoom/pan/hover)
  2. channel_importance.npy - Ablation impact scores for all channels
  3. circuit.npy - Circuit structure (critical channels, edges, roles)

Section 1: Setup & Installation¶

In [1]:
# Install required packages
!pip install -q torch torchvision rasterio pillow networkx matplotlib seaborn plotly tqdm
In [2]:
# Imports
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import plotly.graph_objects as go
from tqdm import tqdm
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as T

# Mount Google Drive (if needed)
from google.colab import drive
drive.mount('/content/drive')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
Mounted at /content/drive
PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4

Section 2: Model Definition¶

In [3]:
def get_model(num_classes=2, pretrained=False):
    """
    Get Faster R-CNN model with ResNet50-FPN backbone.

    Args:
        num_classes: Number of classes (including background)
        pretrained: Use COCO pre-trained weights

    Returns:
        Faster R-CNN model
    """
    model = fasterrcnn_resnet50_fpn(pretrained=pretrained)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

print("✓ Model definition loaded")
✓ Model definition loaded

Section 3: Image Loading Helper¶

In [4]:
def load_image(image_path, crop_size=1024, crop_center=True, crop_offset=(0, 0)):
    """
    Load image from file and prepare for Faster R-CNN.

    IMPORTANT:
    - The model was trained on 1024x1024 tiles
    - Large images MUST be cropped to 1024x1024 for inference
    - This function automatically crops to the specified size

    Args:
        image_path: Path to image file
        crop_size: Size to crop (default 1024 to match training)
        crop_center: If True, crop from center. If False, use crop_offset
        crop_offset: (y, x) offset for crop start position if not centering

    Returns:
        torch.Tensor: (3, crop_size, crop_size) normalized image in [0, 1]
    """
    if image_path.lower().endswith(('.tif', '.tiff')):
        # GeoTIFF loading
        import rasterio
        with rasterio.open(image_path) as src:
            # Read RGB bands (1, 2, 3) and transpose to (H, W, C)
            img_array = src.read([1, 2, 3]).transpose(1, 2, 0)  # (H, W, 3)

            print(f"  GeoTIFF info:")
            print(f"    Bands: {src.count} total, using bands 1-3 (RGB)")
            print(f"    Dtype: {img_array.dtype}")
            print(f"    Full shape: {img_array.shape}")
            print(f"    Value range: [{img_array.min()}, {img_array.max()}]")

            # CRITICAL: Crop to model's expected size (1024x1024)
            h, w = img_array.shape[:2]
            if h > crop_size or w > crop_size:
                if crop_center:
                    # Crop from center
                    start_y = (h - crop_size) // 2
                    start_x = (w - crop_size) // 2
                else:
                    # Crop from specified offset
                    start_y, start_x = crop_offset

                start_y = max(0, min(start_y, h - crop_size))
                start_x = max(0, min(start_x, w - crop_size))

                img_array = img_array[start_y:start_y+crop_size, start_x:start_x+crop_size]
                print(f"    → Cropped to {img_array.shape} (model expects ~{crop_size}x{crop_size})")
                print(f"       Crop region: y=[{start_y}:{start_y+crop_size}], x=[{start_x}:{start_x+crop_size}]")

            # Normalize based on dtype
            if img_array.dtype == np.uint16:
                print(f"    → Detected uint16, normalizing from [0, 65535] to [0, 1]")
                img_array = img_array.astype(np.float32) / 65535.0
            elif img_array.dtype == np.uint8:
                print(f"    → Detected uint8, normalizing from [0, 255] to [0, 1]")
                img_array = img_array.astype(np.float32) / 255.0
            else:
                print(f"    → Converting {img_array.dtype} to float and normalizing")
                img_array = img_array.astype(np.float32)
                if img_array.max() > 1.0:
                    img_array = img_array / img_array.max()

            img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
            print(f"    Final tensor: {img_tensor.shape}, range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
    else:
        # Standard image loading (JPG, PNG)
        img = Image.open(image_path).convert('RGB')
        img_array = np.array(img)  # (H, W, 3), uint8

        print(f"  Standard image info:")
        print(f"    Dtype: {img_array.dtype}, Shape: {img_array.shape}")

        # Crop if needed
        h, w = img_array.shape[:2]
        if h > crop_size or w > crop_size:
            if crop_center:
                start_y = (h - crop_size) // 2
                start_x = (w - crop_size) // 2
            else:
                start_y, start_x = crop_offset

            start_y = max(0, min(start_y, h - crop_size))
            start_x = max(0, min(start_x, w - crop_size))
            img_array = img_array[start_y:start_y+crop_size, start_x:start_x+crop_size]
            print(f"    → Cropped to {img_array.shape}")

        print(f"    Value range: [{img_array.min()}, {img_array.max()}]")

        # Normalize
        img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.0
        print(f"    Final tensor: {img_tensor.shape}, range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")

    return img_tensor

print("✓ Image loading helper loaded (FIXED: crops to 1024x1024)")
✓ Image loading helper loaded (FIXED: crops to 1024x1024)

Section 4: Activation Capture System¶

In [5]:
class ActivationCapture:
    """
    Capture intermediate activations at specified layers using forward hooks.
    """
    def __init__(self, model, layer_paths):
        """
        Args:
            model: PyTorch model
            layer_paths: Dict mapping layer names to module paths
                        e.g., {'layer4': 'backbone.body.layer4'}
        """
        self.activations = {}
        self.hooks = []

        for name, path in layer_paths.items():
            layer = self._get_layer(model, path)
            hook = layer.register_forward_hook(self._make_hook(name))
            self.hooks.append(hook)

    def _get_layer(self, model, path):
        """Navigate to a layer by dot-separated path."""
        parts = path.split('.')
        layer = model
        for part in parts:
            layer = getattr(layer, part)
        return layer

    def _make_hook(self, name):
        """Create hook function that captures activations."""
        def hook(module, input, output):
            self.activations[name] = output.detach().clone()
        return hook

    def remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

print("✓ ActivationCapture class loaded")
✓ ActivationCapture class loaded
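
A minimal usage sketch for reference (it assumes a model and preprocessed image prepared as in Steps 1-2 of the main pipeline, Section 10):

    capturer = ActivationCapture(model, {'layer4': 'backbone.body.layer4'})
    with torch.no_grad():
        _ = model([image])                           # forward pass fills capturer.activations
    layer4_acts = capturer.activations['layer4']     # e.g. (1, 2048, 25, 25) after the detector's internal resize
    capturer.remove_hooks()                          # always remove hooks when done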

Section 5: Activation Patching Engine¶

In [6]:
class ActivationPatcher:
    """
    Patch (ablate) specific channels and measure impact on detection.
    """
    def _get_layer(self, model, path):
        """Navigate to a layer by dot-separated path."""
        parts = path.split('.')
        layer = model
        for part in parts:
            layer = getattr(layer, part)
        return layer

    def patch_channel(self, model, layer_path, channel_idx, patch_value=0):
        """
        Zero out a specific channel during forward pass.

        Args:
            model: PyTorch model
            layer_path: Path to layer (e.g., 'backbone.body.layer4')
            channel_idx: Channel index to ablate
            patch_value: Value to set (default 0 for ablation)

        Returns:
            Hook handle (call .remove() to cleanup)
        """
        def patch_hook(module, input, output):
            output[:, channel_idx, :, :] = patch_value
            return output

        layer = self._get_layer(model, layer_path)
        hook = layer.register_forward_hook(patch_hook)
        return hook

    def measure_impact(self, model, image, baseline_output):
        """
        Measure how detection degrades when a channel is ablated.

        Args:
            model: Model with patched channel
            image: Input image tensor
            baseline_output: Output from unpatched model

        Returns:
            impact_score: Change in detection confidence (0-1)
        """
        # Run with patched activations
        with torch.no_grad():
            patched_output = model([image])

        # Compare detection scores
        baseline_score = baseline_output[0]['scores'].max().item() if len(baseline_output[0]['scores']) > 0 else 0
        patched_score = patched_output[0]['scores'].max().item() if len(patched_output[0]['scores']) > 0 else 0

        impact = baseline_score - patched_score  # Positive = channel was important
        return impact

print("✓ ActivationPatcher class loaded")
✓ ActivationPatcher class loaded
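
A minimal single-channel sketch of the patch → measure → cleanup sequence (it assumes model, image, and an unpatched baseline_output as produced in Steps 1-3 of the main pipeline, Section 10):

    patcher = ActivationPatcher()
    hook = patcher.patch_channel(model, 'backbone.body.layer4', channel_idx=0)
    impact = patcher.measure_impact(model, image, baseline_output)
    hook.remove()                                    # restore the unpatched model
    print(f"Ablating channel 0 changed the top detection score by {impact:.4f}")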

Section 6: Ablation Study¶

In [7]:
def ablation_study(model, image, layer_path='backbone.body.layer4', num_channels=2048):
    """
    Test importance of each channel by ablating it and measuring impact on detection.

    Args:
        model: Fine-tuned Faster R-CNN model
        image: Test image tensor (3, H, W)
        layer_path: Path to layer to ablate
        num_channels: Number of channels in the layer

    Returns:
        channel_importance: (num_channels,) array of impact scores
    """
    # Get baseline output (no ablation)
    with torch.no_grad():
        baseline_output = model([image])

    channel_importance = np.zeros(num_channels)
    patcher = ActivationPatcher()

    # Test each channel
    for ch_idx in tqdm(range(num_channels), desc="Ablating channels"):
        # Patch (zero out) this channel
        hook = patcher.patch_channel(model, layer_path, ch_idx, patch_value=0)

        # Measure impact
        impact = patcher.measure_impact(model, image, baseline_output)
        channel_importance[ch_idx] = impact

        # Remove hook
        hook.remove()

    return channel_importance

print("✓ Ablation study function loaded")
✓ Ablation study function loaded

Section 7: Circuit Extraction¶

In [8]:
def assign_channel_roles(activations, channel_indices):
    """
    Heuristically categorize channels by simple spatial statistics
    (gradient magnitude and lag-1 autocorrelation of each activation map).

    Args:
        activations: Layer activations (C, H, W)
        channel_indices: Indices of critical channels to categorize

    Returns:
        Dict mapping channel_idx -> role ('edge', 'texture', 'shape', 'semantic')
    """
    roles = {}

    for ch_idx in channel_indices:
        ch_data = activations[ch_idx].cpu().numpy()

        # Compute edge strength (high-frequency)
        grad_y, grad_x = np.gradient(ch_data)
        edge_strength = np.mean(np.sqrt(grad_x**2 + grad_y**2))

        # Normalize edge strength
        edge_strength_norm = edge_strength / (ch_data.std() + 1e-8)

        # Compute lag-1 autocorrelation of the flattened map (smoothness proxy)
        flat = ch_data.flatten()
        if len(flat) > 1:
            autocorr = np.corrcoef(flat[:-1], flat[1:])[0, 1]
        else:
            autocorr = 0

        # Categorize based on heuristics
        if edge_strength_norm > 0.5:
            roles[ch_idx] = 'edge'
        elif autocorr > 0.7:
            roles[ch_idx] = 'semantic'
        elif autocorr < 0.3:
            roles[ch_idx] = 'texture'
        else:
            roles[ch_idx] = 'shape'

    return roles


def extract_circuit(channel_importance, activations, threshold_percentile=90):
    """
    Extract minimal circuit: channels + connections that are critical.

    For Layer 4 only, edges represent functional relationships between channels:
    - Co-activation: Channels that activate together on pot regions
    - Functional grouping: Channels that contribute to the same detections

    Args:
        channel_importance: (2048,) array of ablation impact scores
        activations: Dict with 'layer4' -> (1, 2048, H, W) activations
        threshold_percentile: Percentile for selecting critical channels

    Returns:
        circuit: Dict with:
            - 'critical_channels': Array of critical channel indices
            - 'edges': List of (ch_i, ch_j, correlation) tuples
            - 'channel_roles': Dict mapping channel -> semantic role
            - 'importance': Importance scores for critical channels
    """
    # Step 1: Identify critical channels (top percentile by ablation impact)
    threshold = np.percentile(channel_importance, threshold_percentile)
    critical_indices = np.where(channel_importance >= threshold)[0]

    print(f"Critical channels: {len(critical_indices)} / {len(channel_importance)}")

    # Step 2: Compute co-activation between critical channels
    layer4_activations = activations['layer4'][0]  # (2048, H, W)
    C, H, W = layer4_activations.shape

    # For each critical channel, flatten spatial dims
    activation_vectors = []
    for ch_idx in critical_indices:
        activation_vectors.append(layer4_activations[ch_idx].cpu().numpy().flatten())

    # Compute pairwise correlations
    activation_matrix = np.stack(activation_vectors)  # (num_critical, H*W)
    correlation_matrix = np.corrcoef(activation_matrix)  # (num_critical, num_critical)

    # Create edges for strong correlations (functional relationships)
    edges = []
    for i, ch_i in enumerate(critical_indices):
        for j, ch_j in enumerate(critical_indices):
            if i < j:  # Avoid duplicates
                corr = correlation_matrix[i, j]
                if abs(corr) > 0.5:  # Threshold for "strong" relationship
                    edges.append((int(ch_i), int(ch_j), float(abs(corr))))

    # Step 3: Assign semantic roles
    channel_roles = assign_channel_roles(layer4_activations, critical_indices)

    circuit = {
        'critical_channels': critical_indices,
        'edges': edges,
        'channel_roles': channel_roles,
        'importance': channel_importance[critical_indices]
    }

    return circuit

print("✓ Circuit extraction functions loaded")
✓ Circuit extraction functions loaded

Section 8: Circuit Visualization¶

In [9]:
def visualize_circuit(circuit, save_path='circuit_diagram.html'):
    """
    Create interactive Plotly visualization of the minimal circuit.

    Nodes = critical channels, colored by role (edge/texture/shape/semantic)
    Edges = co-activation patterns (|correlation| > 0.5)
    Node size = scaled ablation importance

    Features:
    - Zoom/pan to explore dense clusters
    - Hover to see channel details
    - Interactive legend to filter by role
    - Export as HTML or PNG

    Args:
        circuit: Circuit dict from extract_circuit()
        save_path: Path to save visualization (.html or .png)
    """
    G = nx.Graph()  # Undirected for co-activation

    # Define color map for channel roles
    role_colors = {
        'edge': '#FF6B6B',      # Red
        'texture': '#4ECDC4',   # Cyan
        'shape': '#45B7D1',     # Blue
        'semantic': '#FFA07A'   # Orange
    }

    # Add nodes (critical channels)
    for idx, ch_idx in enumerate(circuit['critical_channels']):
        role = circuit['channel_roles'].get(ch_idx, 'unknown')
        importance = circuit['importance'][idx]

        G.add_node(f'Ch{ch_idx}',
                   channel_idx=int(ch_idx),
                   role=role,
                   importance=float(importance),
                   color=role_colors.get(role, '#CCCCCC'))

    # Add edges (co-activation patterns)
    for ch_i, ch_j, weight in circuit['edges']:
        G.add_edge(f'Ch{ch_i}', f'Ch{ch_j}', weight=weight)

    # Compute layout using NetworkX
    pos = nx.spring_layout(G, k=1.5, iterations=100, seed=42)

    # Prepare edge traces
    edge_traces = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        weight = G[edge[0]][edge[1]]['weight']

        edge_trace = go.Scatter(
            x=[x0, x1, None],
            y=[y0, y1, None],
            mode='lines',
            line=dict(
                width=weight * 3,  # Scale by correlation
                color='rgba(125, 125, 125, 0.4)'
            ),
            hoverinfo='none',
            showlegend=False
        )
        edge_traces.append(edge_trace)

    # Prepare node traces (one trace per role for legend)
    node_traces_by_role = {}

    for node in G.nodes():
        role = G.nodes[node]['role']
        channel_idx = G.nodes[node]['channel_idx']
        importance = G.nodes[node]['importance']
        color = G.nodes[node]['color']
        x, y = pos[node]

        # Create trace for this role if not exists
        if role not in node_traces_by_role:
            node_traces_by_role[role] = {
                'x': [],
                'y': [],
                'labels': [],
                'text': [],
                'marker_size': [],
                'color': color,
                'name': role.capitalize()
            }

        # Add node to its role trace
        node_traces_by_role[role]['x'].append(x)
        node_traces_by_role[role]['y'].append(y)
        node_traces_by_role[role]['labels'].append(f"Ch{channel_idx}")
        node_traces_by_role[role]['text'].append(
            f"<b>Channel {channel_idx}</b><br>" +
            f"Role: {role.capitalize()}<br>" +
            f"Importance: {importance:.4f}<br>" +
            f"<i>Click to zoom</i>"
        )
        node_traces_by_role[role]['marker_size'].append(importance * 30 + 10)  # Scale size

    # Create plotly figure
    fig = go.Figure()

    # Add edge traces
    for edge_trace in edge_traces:
        fig.add_trace(edge_trace)

    # Add node traces (grouped by role for legend)
    for role, trace_data in node_traces_by_role.items():
        fig.add_trace(go.Scatter(
            x=trace_data['x'],
            y=trace_data['y'],
            mode='markers+text',
            name=trace_data['name'],
            text=trace_data['labels'],
            textposition="top center",
            textfont=dict(size=6),
            hovertext=trace_data['text'],
            hoverinfo='text',
            marker=dict(
                size=trace_data['marker_size'],
                color=trace_data['color'],
                line=dict(width=1, color='white'),
                opacity=0.9
            )
        ))

    # Layout configuration
    fig.update_layout(
        title=dict(
            text='Minimal Computational Circuit for Pot Detection (Layer 4)<br>' +
                 '<sub>Interactive: Zoom/Pan to explore | Hover for details | Click legend to filter</sub>',
            x=0.5,
            xanchor='center',
            font=dict(size=20)
        ),
        showlegend=True,
        legend=dict(
            title=dict(text='Channel Roles'),
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99,
            bgcolor='rgba(255,255,255,0.8)',
            bordercolor='gray',
            borderwidth=1
        ),
        hovermode='closest',
        margin=dict(b=20, l=20, r=20, t=80),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        plot_bgcolor='white',
        width=1400,
        height=1000
    )

    # Save and display
    if save_path.endswith('.html'):
        fig.write_html(save_path)
        print(f"✓ Interactive circuit saved to: {save_path}")
        print(f"  → Open in browser to explore")
    elif save_path.endswith('.png'):
        fig.write_image(save_path, width=1400, height=1000)
        print(f"✓ Static circuit saved to: {save_path}")
    else:
        # Default to HTML
        html_path = save_path.replace('.png', '.html') if '.png' in save_path else f"{save_path}.html"
        fig.write_html(html_path)
        print(f"✓ Interactive circuit saved to: {html_path}")

    # Show in notebook
    fig.show()

    # Print summary
    print(f"\n=== CIRCUIT SUMMARY ===")
    print(f"Total nodes: {G.number_of_nodes()}")
    print(f"Total edges: {G.number_of_edges()}")
    print(f"Sparsity: {len(circuit['critical_channels'])} / 2048 ({len(circuit['critical_channels'])/2048*100:.1f}%)")

    # Role breakdown
    role_counts = {}
    for ch_idx in circuit['critical_channels']:
        role = circuit['channel_roles'].get(ch_idx, 'unknown')
        role_counts[role] = role_counts.get(role, 0) + 1

    print(f"\n=== ROLE BREAKDOWN ===")
    for role, count in sorted(role_counts.items()):
        print(f"  {role.capitalize()}: {count} channels ({count/len(circuit['critical_channels'])*100:.1f}%)")

print("✓ Visualization function loaded (Plotly interactive)")
✓ Visualization function loaded (Plotly interactive)

Section 9: (Optional) Multi-Layer Analysis¶

Note: This section is optional and not part of the main workflow. It's included for future hierarchical circuit analysis across multiple ResNet layers.

In [10]:
def extract_hierarchical_circuit(model, image):
    """
    [OPTIONAL] Extract circuit across multiple layers for future analysis.

    Warning: This will take significantly longer (~30-40 minutes) as it tests
    all channels across layers 1-4 (256 + 512 + 1024 + 2048 = 3840 channels).
    """
    print("\n⚠️  This is an optional extended analysis (not part of main workflow)")
    print("   Focus is on Layer 4 only per user requirements.\n")

    layers_to_analyze = {
        'layer1': ('backbone.body.layer1', 256),
        'layer2': ('backbone.body.layer2', 512),
        'layer3': ('backbone.body.layer3', 1024),
        'layer4': ('backbone.body.layer4', 2048),
    }

    hierarchical_circuit = {}

    for layer_name, (layer_path, num_channels) in layers_to_analyze.items():
        print(f"Analyzing {layer_name} ({num_channels} channels)...")

        # Ablation study for this layer
        importance = ablation_study(model, image, layer_path, num_channels)

        # Extract critical channels (top 5%)
        threshold = np.percentile(importance, 95)
        critical = np.where(importance >= threshold)[0]

        hierarchical_circuit[layer_name] = {
            'importance': importance,
            'critical_channels': critical
        }

        print(f"  → {len(critical)} critical channels\n")

    return hierarchical_circuit

print("✓ (Optional) Multi-layer analysis function loaded")
✓ (Optional) Multi-layer analysis function loaded

Section 10: Main Pipeline¶

This orchestrates all the steps: load model → run ablation → extract circuit → visualize

In [11]:
def main_circuit_extraction_pipeline(checkpoint_path, test_image_path):
    """
    End-to-end pipeline: load model → run ablation → extract circuit → visualize.

    Args:
        checkpoint_path: Path to fine-tuned model checkpoint
                        (e.g., "/content/drive/MyDrive/Orthos/fior-pot-detector.pth")
        test_image_path: Path to test image with pot detections

    Returns:
        circuit: Extracted circuit structure
        channel_importance: (2048,) array of ablation scores
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}\n")

    # === STEP 1: Load model ===
    print("=" * 60)
    print("STEP 1: Loading model")
    print("=" * 60)
    model = get_model(num_classes=2)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    model = model.to(device)
    print("✓ Model loaded successfully")
    print(f"  Epoch: {checkpoint.get('epoch', 'unknown')}")
    if 'loss' in checkpoint:
        print(f"  Loss: {checkpoint['loss']:.4f}")
    print()

    # === STEP 2: Load test image ===
    print("=" * 60)
    print("STEP 2: Loading test image")
    print("=" * 60)
    image = load_image(test_image_path)
    print(f"✓ Image loaded and preprocessed: {image.shape}")
    print(f"  Value range: [{image.min():.3f}, {image.max():.3f}]")
    print(f"  Expected: [0.0, 1.0] (normalized from uint8)")

    image = image.to(device)
    print()

    # === STEP 3: Verify detection works ===
    print("=" * 60)
    print("STEP 3: Running baseline detection")
    print("=" * 60)

    # Temporarily raise the score threshold so only high-confidence detections are summarized
    original_thresh = model.roi_heads.score_thresh
    model.roi_heads.score_thresh = 0.95  # High threshold: report only confident detections

    with torch.no_grad():
        baseline_output = model([image])

    num_detections = len(baseline_output[0]['boxes'])
    print(f"✓ Detected {num_detections} objects (score threshold: {model.roi_heads.score_thresh})")

    if num_detections > 0:
        # Show top detections
        scores = baseline_output[0]['scores'].cpu().numpy()
        boxes = baseline_output[0]['boxes'].cpu().numpy()
        print(f"\n  Top detections:")
        for i in range(min(5, num_detections)):
            print(f"    {i+1}. Score: {scores[i]:.3f}, Box: [{boxes[i][0]:.0f}, {boxes[i][1]:.0f}, {boxes[i][2]:.0f}, {boxes[i][3]:.0f}]")

        # Reset to original threshold for ablation
        model.roi_heads.score_thresh = original_thresh

        # Re-run with normal threshold for baseline
        with torch.no_grad():
            baseline_output = model([image])
        num_baseline = len(baseline_output[0]['boxes'])
        print(f"\n  Detections at threshold {original_thresh}: {num_baseline}")
    else:
        print(f"\n  ⚠️  WARNING: No detections found!")
        print(f"  This could mean:")
        print(f"    1. The test image doesn't contain pots")
        print(f"    2. The model wasn't trained properly")
        print(f"    3. Image preprocessing mismatch (FIXED in this version)")
        print(f"\n  Continuing anyway for demonstration...")
        print(f"  Note: Ablation scores will all be 0 or very small")

    print()

    # === STEP 4: Capture activations ===
    print("=" * 60)
    print("STEP 4: Capturing Layer 4 activations")
    print("=" * 60)
    capturer = ActivationCapture(model, {'layer4': 'backbone.body.layer4'})
    with torch.no_grad():
        _ = model([image])
    activations = capturer.activations
    print(f"✓ Captured activations: {activations['layer4'].shape}")
    print(f"  Activation range: [{activations['layer4'].min():.3f}, {activations['layer4'].max():.3f}]")
    print()
    capturer.remove_hooks()

    # === STEP 5: Run ablation study ===
    print("=" * 60)
    print("STEP 5: Running ablation study on Layer 4 (2048 channels)")
    print("=" * 60)
    print("⚠️  This will take ~5-10 minutes (testing each channel)")

    if num_detections == 0:
        print("\n⚠️  WARNING: No baseline detections - ablation results will be minimal")
        print("Consider using a different test image with visible pots\n")

    channel_importance = ablation_study(model, image, 'backbone.body.layer4', 2048)
    print(f"✓ Ablation complete")
    print(f"  Impact range: [{channel_importance.min():.4f}, {channel_importance.max():.4f}]")
    print(f"  Mean impact: {channel_importance.mean():.4f}")
    print(f"  Non-zero impacts: {np.count_nonzero(channel_importance)} / 2048")
    print()

    # === STEP 6: Extract circuit ===
    print("=" * 60)
    print("STEP 6: Extracting minimal circuit")
    print("=" * 60)
    circuit = extract_circuit(channel_importance, activations, threshold_percentile=90)
    print(f"✓ Circuit extracted\n")

    # === STEP 7: Visualize ===
    print("=" * 60)
    print("STEP 7: Generating interactive circuit diagram")
    print("=" * 60)
    visualize_circuit(circuit, save_path='circuit_diagram.html')

    # === STEP 8: Save results ===
    print("\n" + "=" * 60)
    print("STEP 8: Saving results")
    print("=" * 60)
    np.save('channel_importance.npy', channel_importance)
    np.save('circuit.npy', circuit, allow_pickle=True)
    print("✓ Saved channel_importance.npy")
    print("✓ Saved circuit.npy")
    print("✓ Saved circuit_diagram.html (interactive)")

    print("\n" + "=" * 60)
    print("✅ PIPELINE COMPLETE!")
    print("=" * 60)

    if num_detections == 0:
        print("\n⚠️  NOTE: No detections were found in the baseline.")
        print("The circuit will be based on minimal activations.")
        print("For best results, use an image where the model detects pots.")

    return circuit, channel_importance

print("✓ Main pipeline function loaded")
✓ Main pipeline function loaded

Section 11: Execution¶

Instructions: Update the paths below with your checkpoint and test image locations, then run the cell.

IMPORTANT: Image Size¶

Your model was trained on 1024x1024 tiles. Large images (like 46074x2654) must be cropped!

The load_image() function now automatically:

  • Crops to 1024x1024 (matches training tile size)
  • By default, crops from center of the image
  • You can change crop position by editing the function call below

Options for cropping:

  1. Center crop (default):

    load_image(test_image_path)  # Crops from center
    
  2. Custom crop position:

    # Crop from pixel coordinates (y=10000, x=500)
    load_image(test_image_path, crop_center=False, crop_offset=(10000, 500))
    
  3. Different crop size:

    load_image(test_image_path, crop_size=512)  # Crop to 512x512
    

Pro tip: If center crop has no detections, try different offsets to find a region with pots!
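
For example, a quick scan over a few candidate offsets could look like the sketch below (it assumes the model is already loaded, in eval mode, and on device, as in the debug cell; the offsets are illustrative, not tuned to any particular orthomosaic):

    candidate_offsets = [(0, 0), (10000, 500), (20000, 1000), (30000, 1500)]
    for offset in candidate_offsets:
        img = load_image(test_image_path, crop_center=False, crop_offset=offset).to(device)
        with torch.no_grad():
            out = model([img])
        print(f"Offset {offset}: {len(out[0]['boxes'])} detections")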

In [12]:
# ============================================================
# OPTIONAL: DEBUG CELL - Run this first to test your setup
# ============================================================
# Uncomment and run this cell to visualize image and test detection

# checkpoint_path = "/content/drive/MyDrive/Orthos/fior-pot-detector.pth"
# test_image_path = "/content/drive/MyDrive/Orthos/test_image.tif"

# # Load and visualize image
# test_img = load_image(test_image_path)
# print(f"Image shape: {test_img.shape}")
# print(f"Value range: [{test_img.min():.3f}, {test_img.max():.3f}]")

# # Show image
# plt.figure(figsize=(10, 10))
# plt.imshow(test_img.permute(1, 2, 0).cpu().numpy())
# plt.title("Test Image")
# plt.axis('off')
# plt.show()

# # Test detection
# model = get_model(num_classes=2)
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = model.to(device)
# test_img = test_img.to(device)

# model.roi_heads.score_thresh = 0.05
# with torch.no_grad():
#     output = model([test_img])

# print(f"\nDetections: {len(output[0]['boxes'])}")
# if len(output[0]['boxes']) > 0:
#     print("Top 5 scores:", output[0]['scores'][:5].cpu().numpy())
# else:
#     print("⚠️ No detections! Try a different image.")
In [13]:
# ============================================================
# USER CONFIGURATION: Update these paths
# ============================================================

# Path to your fine-tuned Faster R-CNN checkpoint
checkpoint_path = "/content/drive/MyDrive/Orthos/fior-pot-detector.pth"

# Path to test image (should have pot detections)
# Supports: .tif, .tiff, .jpg, .jpeg, .png
test_image_path = "/content/drive/MyDrive/Orthos/35.tif"

# ============================================================
# RUN PIPELINE
# ============================================================

circuit, channel_importance = main_circuit_extraction_pipeline(
    checkpoint_path=checkpoint_path,
    test_image_path=test_image_path
)

print("\n🎉 Circuit extraction complete!")
print("\nOutputs:")
print("  - circuit_diagram.html (interactive visualization - zoom/pan/hover)")
print("  - channel_importance.npy (ablation scores)")
print("  - circuit.npy (circuit structure)")
print("\n💡 Tip: Download circuit_diagram.html and open in browser for best experience!")
Using device: cuda

============================================================
STEP 1: Loading model
============================================================
/usr/local/lib/python3.12/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 189MB/s]
✓ Model loaded successfully
  Epoch: 50
  Loss: 0.0314

============================================================
STEP 2: Loading test image
============================================================
  GeoTIFF info:
    Bands: 4 total, using bands 1-3 (RGB)
    Dtype: uint8
    Full shape: (46074, 2654, 3)
    Value range: [0, 255]
    → Cropped to (1024, 1024, 3) (model expects ~1024x1024)
       Crop region: y=[22525:23549], x=[815:1839]
    → Detected uint8, normalizing from [0, 255] to [0, 1]
    Final tensor: torch.Size([3, 1024, 1024]), range: [0.055, 1.000]
✓ Image loaded and preprocessed: torch.Size([3, 1024, 1024])
  Value range: [0.055, 1.000]
  Expected: [0.0, 1.0] (normalized from uint8)

============================================================
STEP 3: Running baseline detection
============================================================
✓ Detected 56 objects (score threshold: 0.95)

  Top detections:
    1. Score: 1.000, Box: [444, 996, 504, 1024]
    2. Score: 1.000, Box: [714, 994, 775, 1024]
    3. Score: 1.000, Box: [175, 998, 239, 1024]
    4. Score: 1.000, Box: [852, 0, 912, 27]
    5. Score: 1.000, Box: [982, 990, 1024, 1024]

  Detections at threshold 0.05: 56

============================================================
STEP 4: Capturing Layer 4 activations
============================================================
✓ Captured activations: torch.Size([1, 2048, 25, 25])
  Activation range: [0.000, 5.029]

============================================================
STEP 5: Running ablation study on Layer 4 (2048 channels)
============================================================
⚠️  This will take ~5-10 minutes (testing each channel)
Ablating channels: 100%|██████████| 2048/2048 [02:54<00:00, 11.76it/s]
/usr/local/lib/python3.12/dist-packages/numpy/lib/_function_base_impl.py:2922: RuntimeWarning: invalid value encountered in divide
  c /= stddev[:, None]
/usr/local/lib/python3.12/dist-packages/numpy/lib/_function_base_impl.py:2923: RuntimeWarning: invalid value encountered in divide
  c /= stddev[None, :]
✓ Ablation complete
  Impact range: [-0.0000, 0.0000]
  Mean impact: -0.0000
  Non-zero impacts: 44 / 2048

============================================================
STEP 6: Extracting minimal circuit
============================================================
Critical channels: 2023 / 2048
✓ Circuit extracted

============================================================
STEP 7: Generating interactive circuit diagram
============================================================
✓ Interactive circuit saved to: circuit_diagram.html
  → Open in browser to explore
=== CIRCUIT SUMMARY ===
Total nodes: 2023
Total edges: 5099
Circuit size: 2023 / 2048 channels (98.8% of Layer 4)

=== ROLE BREAKDOWN ===
  Edge: 162 channels (8.0%)
  Shape: 1606 channels (79.4%)
  Texture: 255 channels (12.6%)

============================================================
STEP 8: Saving results
============================================================
✓ Saved channel_importance.npy
✓ Saved circuit.npy
✓ Saved circuit_diagram.html (interactive)

============================================================
✅ PIPELINE COMPLETE!
============================================================

🎉 Circuit extraction complete!

Outputs:
  - circuit_diagram.html (interactive visualization - zoom/pan/hover)
  - channel_importance.npy (ablation scores)
  - circuit.npy (circuit structure)

💡 Tip: Open circuit_diagram.html in browser for best experience!

Interactive Visualization Tips¶

Using the Plotly Circuit Diagram¶

The circuit diagram is now fully interactive with the following features:

  1. Zoom In/Out:

    • Scroll wheel to zoom
    • Double-click to reset view
    • Click and drag box to zoom to region
  2. Pan:

    • Click and drag to move around
    • Explore dense clusters in the center
  3. Hover:

    • Hover over nodes to see:
      • Channel number
      • Role (edge/texture/shape/semantic)
      • Importance score
  4. Filter by Role:

    • Click legend items to show/hide roles
    • Double-click legend to isolate one role
  5. Export:

    • Camera icon (top right) → Download as PNG
    • Share the HTML file for interactive viewing

Interpretation Guide¶

  • Node Size = Ablation importance (bigger = more critical)
  • Node Color = Semantic role:
    • Red = Edge detectors
    • Cyan = Texture detectors
    • Blue = Shape detectors
    • Orange = Semantic features
  • Edge Thickness = Co-activation strength (correlation)
  • Clusters = Functionally related channel groups
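
The saved artifacts can also be inspected programmatically, without the diagram. A minimal sketch (circuit.npy stores a pickled dict, so it must be reloaded with allow_pickle=True):

    import numpy as np

    importance = np.load('channel_importance.npy')                # (2048,) ablation impact scores
    circuit = np.load('circuit.npy', allow_pickle=True).item()    # dict saved in Step 8

    # Top 5 critical channels by ablation importance, with their assigned roles
    order = np.argsort(circuit['importance'])[::-1][:5]
    for i in order:
        ch = int(circuit['critical_channels'][i])
        role = circuit['channel_roles'].get(ch, 'unknown')
        print(f"Ch{ch}: importance={circuit['importance'][i]:.4f}, role={role}")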