Object Detection Circuit Extraction¶
This notebook extracts and visualizes the minimal computational circuit for pot detection in a fine-tuned Faster R-CNN model using activation patching and co-activation analysis.
Approach¶
- Layer 4 only (2048 channels, focused analysis)
- Channel-level granularity (interpretable, ~50-200 critical channels)
- Activation patching (systematic ablation to identify critical channels)
- Co-activation analysis (find functional relationships between channels)
- Semantic role assignment (categorize as edge/texture/shape/semantic)
- Interactive Plotly visualization (zoom, pan, hover to explore dense clusters)
Expected Runtime¶
~5-10 minutes (mainly ablation study testing 2048 channels)
Outputs¶
circuit_diagram.html - Interactive Plotly visualization (zoom/pan/hover)
channel_importance.npy - Ablation impact scores for all channels
circuit.npy - Circuit structure (critical channels, edges, roles)
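These artifacts can be reloaded in a later session. A minimal sketch, assuming the default file names written by Step 8 of the pipeline:

import numpy as np
channel_importance = np.load('channel_importance.npy')        # (2048,) ablation impact scores
circuit = np.load('circuit.npy', allow_pickle=True).item()    # dict: critical_channels, edges, channel_roles, importance
print(channel_importance.shape, list(circuit.keys()))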
Section 1: Setup & Installation¶
# Install required packages
!pip install -q torch torchvision rasterio pillow networkx matplotlib seaborn plotly tqdm
# Imports
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import plotly.graph_objects as go
from tqdm import tqdm
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import torchvision.transforms as T
# Mount Google Drive (if needed)
from google.colab import drive
drive.mount('/content/drive')
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
Mounted at /content/drive
PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4
Section 2: Model Definition¶
def get_model(num_classes=2, pretrained=False):
"""
Get Faster R-CNN model with ResNet50-FPN backbone.
Args:
num_classes: Number of classes (including background)
pretrained: Use COCO pre-trained weights
Returns:
Faster R-CNN model
"""
model = fasterrcnn_resnet50_fpn(pretrained=pretrained)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
return model
print("✓ Model definition loaded")
✓ Model definition loaded
Section 3: Image Loading Helper¶
def load_image(image_path, crop_size=1024, crop_center=True, crop_offset=(0, 0)):
"""
Load image from file and prepare for Faster R-CNN.
IMPORTANT:
- The model was trained on 1024x1024 tiles
- Large images MUST be cropped to 1024x1024 for inference
- This function automatically crops to the specified size
Args:
image_path: Path to image file
crop_size: Size to crop (default 1024 to match training)
crop_center: If True, crop from center. If False, use crop_offset
crop_offset: (y, x) offset for crop start position if not centering
Returns:
torch.Tensor: (3, crop_size, crop_size) normalized image in [0, 1]
"""
if image_path.lower().endswith(('.tif', '.tiff')):
# GeoTIFF loading
import rasterio
with rasterio.open(image_path) as src:
# Read RGB bands (1, 2, 3) and transpose to (H, W, C)
img_array = src.read([1, 2, 3]).transpose(1, 2, 0) # (H, W, 3)
print(f" GeoTIFF info:")
print(f" Bands: {src.count} total, using bands 1-3 (RGB)")
print(f" Dtype: {img_array.dtype}")
print(f" Full shape: {img_array.shape}")
print(f" Value range: [{img_array.min()}, {img_array.max()}]")
# CRITICAL: Crop to model's expected size (1024x1024)
h, w = img_array.shape[:2]
if h > crop_size or w > crop_size:
if crop_center:
# Crop from center
start_y = (h - crop_size) // 2
start_x = (w - crop_size) // 2
else:
# Crop from specified offset
start_y, start_x = crop_offset
start_y = max(0, min(start_y, h - crop_size))
start_x = max(0, min(start_x, w - crop_size))
img_array = img_array[start_y:start_y+crop_size, start_x:start_x+crop_size]
print(f" → Cropped to {img_array.shape} (model expects ~{crop_size}x{crop_size})")
print(f" Crop region: y=[{start_y}:{start_y+crop_size}], x=[{start_x}:{start_x+crop_size}]")
# Normalize based on dtype
if img_array.dtype == np.uint16:
print(f" → Detected uint16, normalizing from [0, 65535] to [0, 1]")
img_array = img_array.astype(np.float32) / 65535.0
elif img_array.dtype == np.uint8:
print(f" → Detected uint8, normalizing from [0, 255] to [0, 1]")
img_array = img_array.astype(np.float32) / 255.0
else:
print(f" → Converting {img_array.dtype} to float and normalizing")
img_array = img_array.astype(np.float32)
if img_array.max() > 1.0:
img_array = img_array / img_array.max()
img_tensor = torch.from_numpy(img_array).permute(2, 0, 1)
print(f" Final tensor: {img_tensor.shape}, range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
else:
# Standard image loading (JPG, PNG)
img = Image.open(image_path).convert('RGB')
img_array = np.array(img) # (H, W, 3), uint8
print(f" Standard image info:")
print(f" Dtype: {img_array.dtype}, Shape: {img_array.shape}")
# Crop if needed
h, w = img_array.shape[:2]
if h > crop_size or w > crop_size:
if crop_center:
start_y = (h - crop_size) // 2
start_x = (w - crop_size) // 2
else:
start_y, start_x = crop_offset
start_y = max(0, min(start_y, h - crop_size))
start_x = max(0, min(start_x, w - crop_size))
img_array = img_array[start_y:start_y+crop_size, start_x:start_x+crop_size]
print(f" → Cropped to {img_array.shape}")
print(f" Value range: [{img_array.min()}, {img_array.max()}]")
# Normalize
img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.0
print(f" Final tensor: {img_tensor.shape}, range: [{img_tensor.min():.3f}, {img_tensor.max():.3f}]")
return img_tensor
print("✓ Image loading helper loaded (FIXED: crops to 1024x1024)")
✓ Image loading helper loaded (FIXED: crops to 1024x1024)
Section 4: Activation Capture System¶
class ActivationCapture:
"""
Capture intermediate activations at specified layers using forward hooks.
"""
def __init__(self, model, layer_paths):
"""
Args:
model: PyTorch model
layer_paths: Dict mapping layer names to module paths
e.g., {'layer4': 'backbone.body.layer4'}
"""
self.activations = {}
self.hooks = []
for name, path in layer_paths.items():
layer = self._get_layer(model, path)
hook = layer.register_forward_hook(self._make_hook(name))
self.hooks.append(hook)
def _get_layer(self, model, path):
"""Navigate to a layer by dot-separated path."""
parts = path.split('.')
layer = model
for part in parts:
layer = getattr(layer, part)
return layer
def _make_hook(self, name):
"""Create hook function that captures activations."""
def hook(module, input, output):
self.activations[name] = output.detach().clone()
return hook
def remove_hooks(self):
"""Remove all registered hooks."""
for hook in self.hooks:
hook.remove()
self.hooks = []
print("✓ ActivationCapture class loaded")
✓ ActivationCapture class loaded
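Typical usage, condensed from Step 4 of the main pipeline (a sketch assuming model and image are already loaded and on the same device):

capturer = ActivationCapture(model, {'layer4': 'backbone.body.layer4'})
with torch.no_grad():
    _ = model([image])                        # forward pass fills capturer.activations
layer4_acts = capturer.activations['layer4']  # (1, 2048, H, W) feature map
capturer.remove_hooks()                       # always remove hooks when done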
Section 5: Activation Patching Engine¶
class ActivationPatcher:
"""
Patch (ablate) specific channels and measure impact on detection.
"""
def _get_layer(self, model, path):
"""Navigate to a layer by dot-separated path."""
parts = path.split('.')
layer = model
for part in parts:
layer = getattr(layer, part)
return layer
def patch_channel(self, model, layer_path, channel_idx, patch_value=0):
"""
Zero out a specific channel during forward pass.
Args:
model: PyTorch model
layer_path: Path to layer (e.g., 'backbone.body.layer4')
channel_idx: Channel index to ablate
patch_value: Value to set (default 0 for ablation)
Returns:
Hook handle (call .remove() to cleanup)
"""
def patch_hook(module, input, output):
output[:, channel_idx, :, :] = patch_value
return output
layer = self._get_layer(model, layer_path)
hook = layer.register_forward_hook(patch_hook)
return hook
def measure_impact(self, model, image, baseline_output):
"""
Measure how detection degrades when a channel is ablated.
Args:
model: Model with patched channel
image: Input image tensor
baseline_output: Output from unpatched model
Returns:
impact_score: Change in detection confidence (0-1)
"""
# Run with patched activations
with torch.no_grad():
patched_output = model([image])
# Compare detection scores
baseline_score = baseline_output[0]['scores'].max().item() if len(baseline_output[0]['scores']) > 0 else 0
patched_score = patched_output[0]['scores'].max().item() if len(patched_output[0]['scores']) > 0 else 0
impact = baseline_score - patched_score # Positive = channel was important
return impact
print("✓ ActivationPatcher class loaded")
✓ ActivationPatcher class loaded
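One iteration of the ablation loop in Section 6 looks like this (a sketch; model, image, and baseline_output are assumed to exist and the model is in eval mode):

patcher = ActivationPatcher()
hook = patcher.patch_channel(model, 'backbone.body.layer4', channel_idx=0, patch_value=0)
impact = patcher.measure_impact(model, image, baseline_output)  # drop in max detection confidence
hook.remove()                                                   # restore the unpatched model
print(f"Impact of ablating channel 0: {impact:.4f}")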
Section 6: Ablation Study¶
def ablation_study(model, image, layer_path='backbone.body.layer4', num_channels=2048):
"""
Test importance of each channel by ablating it and measuring impact on detection.
Args:
model: Fine-tuned Faster R-CNN model
image: Test image tensor (3, H, W)
layer_path: Path to layer to ablate
num_channels: Number of channels in the layer
Returns:
channel_importance: (num_channels,) array of impact scores
"""
# Get baseline output (no ablation)
with torch.no_grad():
baseline_output = model([image])
channel_importance = np.zeros(num_channels)
patcher = ActivationPatcher()
# Test each channel
for ch_idx in tqdm(range(num_channels), desc="Ablating channels"):
# Patch (zero out) this channel
hook = patcher.patch_channel(model, layer_path, ch_idx, patch_value=0)
# Measure impact
impact = patcher.measure_impact(model, image, baseline_output)
channel_importance[ch_idx] = impact
# Remove hook
hook.remove()
return channel_importance
print("✓ Ablation study function loaded")
✓ Ablation study function loaded
Section 7: Circuit Extraction¶
def assign_channel_roles(activations, channel_indices):
"""
Use spatial frequency analysis to categorize channels.
Args:
activations: Layer activations (C, H, W)
channel_indices: Indices of critical channels to categorize
Returns:
Dict mapping channel_idx -> role ('edge', 'texture', 'shape', 'semantic')
"""
roles = {}
for ch_idx in channel_indices:
ch_data = activations[ch_idx].cpu().numpy()
# Compute edge strength (high-frequency)
grad_y, grad_x = np.gradient(ch_data)
edge_strength = np.mean(np.sqrt(grad_x**2 + grad_y**2))
# Normalize edge strength
edge_strength_norm = edge_strength / (ch_data.std() + 1e-8)
# Compute spatial autocorrelation (localization)
flat = ch_data.flatten()
if len(flat) > 1:
autocorr = np.corrcoef(flat[:-1], flat[1:])[0, 1]
else:
autocorr = 0
# Categorize based on heuristics
if edge_strength_norm > 0.5:
roles[ch_idx] = 'edge'
elif autocorr > 0.7:
roles[ch_idx] = 'semantic'
elif autocorr < 0.3:
roles[ch_idx] = 'texture'
else:
roles[ch_idx] = 'shape'
return roles
def extract_circuit(channel_importance, activations, threshold_percentile=90):
"""
Extract minimal circuit: channels + connections that are critical.
For Layer 4 only, edges represent functional relationships between channels:
- Co-activation: Channels that activate together on pot regions
- Functional grouping: Channels that contribute to the same detections
Args:
channel_importance: (2048,) array of ablation impact scores
activations: Dict with 'layer4' -> (1, 2048, H, W) activations
threshold_percentile: Percentile for selecting critical channels
Returns:
circuit: Dict with:
- 'critical_channels': Array of critical channel indices
- 'edges': List of (ch_i, ch_j, correlation) tuples
- 'channel_roles': Dict mapping channel -> semantic role
- 'importance': Importance scores for critical channels
"""
# Step 1: Identify critical channels (top percentile by ablation impact)
threshold = np.percentile(channel_importance, threshold_percentile)
critical_indices = np.where(channel_importance >= threshold)[0]
print(f"Critical channels: {len(critical_indices)} / {len(channel_importance)}")
# Step 2: Compute co-activation between critical channels
layer4_activations = activations['layer4'][0] # (2048, H, W)
C, H, W = layer4_activations.shape
# For each critical channel, flatten spatial dims
activation_vectors = []
for ch_idx in critical_indices:
activation_vectors.append(layer4_activations[ch_idx].cpu().numpy().flatten())
# Compute pairwise correlations
activation_matrix = np.stack(activation_vectors) # (num_critical, H*W)
correlation_matrix = np.corrcoef(activation_matrix) # (num_critical, num_critical)
# Create edges for strong correlations (functional relationships)
edges = []
for i, ch_i in enumerate(critical_indices):
for j, ch_j in enumerate(critical_indices):
if i < j: # Avoid duplicates
corr = correlation_matrix[i, j]
if abs(corr) > 0.5: # Threshold for "strong" relationship
edges.append((int(ch_i), int(ch_j), float(abs(corr))))
# Step 3: Assign semantic roles
channel_roles = assign_channel_roles(layer4_activations, critical_indices)
circuit = {
'critical_channels': critical_indices,
'edges': edges,
'channel_roles': channel_roles,
'importance': channel_importance[critical_indices]
}
return circuit
print("✓ Circuit extraction functions loaded")
✓ Circuit extraction functions loaded
Section 8: Circuit Visualization¶
def visualize_circuit(circuit, save_path='circuit_diagram.html'):
"""
Create interactive Plotly visualization of the minimal circuit.
Nodes = critical channels, colored by role (edge/texture/shape/semantic)
Edges = co-activation patterns (correlation > 0.5)
Node size = proportional to ablation importance
Features:
- Zoom/pan to explore dense clusters
- Hover to see channel details
- Interactive legend to filter by role
- Export as HTML or PNG
Args:
circuit: Circuit dict from extract_circuit()
save_path: Path to save visualization (.html or .png)
"""
G = nx.Graph() # Undirected for co-activation
# Define color map for channel roles
role_colors = {
'edge': '#FF6B6B', # Red
'texture': '#4ECDC4', # Cyan
'shape': '#45B7D1', # Blue
'semantic': '#FFA07A' # Orange
}
# Add nodes (critical channels)
for idx, ch_idx in enumerate(circuit['critical_channels']):
role = circuit['channel_roles'].get(ch_idx, 'unknown')
importance = circuit['importance'][idx]
G.add_node(f'Ch{ch_idx}',
channel_idx=int(ch_idx),
role=role,
importance=float(importance),
color=role_colors.get(role, '#CCCCCC'))
# Add edges (co-activation patterns)
for ch_i, ch_j, weight in circuit['edges']:
G.add_edge(f'Ch{ch_i}', f'Ch{ch_j}', weight=weight)
# Compute layout using NetworkX
pos = nx.spring_layout(G, k=1.5, iterations=100, seed=42)
# Prepare edge traces
edge_traces = []
for edge in G.edges():
x0, y0 = pos[edge[0]]
x1, y1 = pos[edge[1]]
weight = G[edge[0]][edge[1]]['weight']
edge_trace = go.Scatter(
x=[x0, x1, None],
y=[y0, y1, None],
mode='lines',
line=dict(
width=weight * 3, # Scale by correlation
color='rgba(125, 125, 125, 0.4)'
),
hoverinfo='none',
showlegend=False
)
edge_traces.append(edge_trace)
# Prepare node traces (one trace per role for legend)
node_traces_by_role = {}
for node in G.nodes():
role = G.nodes[node]['role']
channel_idx = G.nodes[node]['channel_idx']
importance = G.nodes[node]['importance']
color = G.nodes[node]['color']
x, y = pos[node]
# Create trace for this role if not exists
if role not in node_traces_by_role:
node_traces_by_role[role] = {
'x': [],
'y': [],
'label': [],
'text': [],
'marker_size': [],
'color': color,
'name': role.capitalize()
}
# Add node to its role trace
node_traces_by_role[role]['x'].append(x)
node_traces_by_role[role]['y'].append(y)
node_traces_by_role[role]['label'].append(f"Ch{channel_idx}")
node_traces_by_role[role]['text'].append(
f"<b>Channel {channel_idx}</b><br>" +
f"Role: {role.capitalize()}<br>" +
f"Importance: {importance:.4f}<br>" +
f"<i>Click to zoom</i>"
)
node_traces_by_role[role]['marker_size'].append(importance * 30 + 10) # Scale size
# Create plotly figure
fig = go.Figure()
# Add edge traces
for edge_trace in edge_traces:
fig.add_trace(edge_trace)
# Add node traces (grouped by role for legend)
for role, trace_data in node_traces_by_role.items():
fig.add_trace(go.Scatter(
x=trace_data['x'],
y=trace_data['y'],
mode='markers+text',
name=trace_data['name'],
text=[f"Ch{circuit['critical_channels'][i]}" for i in range(len(trace_data['x']))],
textposition="top center",
textfont=dict(size=6),
hovertext=trace_data['text'],
hoverinfo='text',
marker=dict(
size=trace_data['marker_size'],
color=trace_data['color'],
line=dict(width=1, color='white'),
opacity=0.9
)
))
# Layout configuration
fig.update_layout(
title=dict(
text='Minimal Computational Circuit for Pot Detection (Layer 4)<br>' +
'<sub>Interactive: Zoom/Pan to explore | Hover for details | Click legend to filter</sub>',
x=0.5,
xanchor='center',
font=dict(size=20)
),
showlegend=True,
legend=dict(
title=dict(text='Channel Roles'),
yanchor="top",
y=0.99,
xanchor="right",
x=0.99,
bgcolor='rgba(255,255,255,0.8)',
bordercolor='gray',
borderwidth=1
),
hovermode='closest',
margin=dict(b=20, l=20, r=20, t=80),
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
plot_bgcolor='white',
width=1400,
height=1000
)
# Save and display
if save_path.endswith('.html'):
fig.write_html(save_path)
print(f"✓ Interactive circuit saved to: {save_path}")
print(f" → Open in browser to explore")
elif save_path.endswith('.png'):
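# Note: static PNG export via write_image requires the kaleido package (not in the install cell above)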
fig.write_image(save_path, width=1400, height=1000)
print(f"✓ Static circuit saved to: {save_path}")
else:
# Default to HTML
html_path = save_path.replace('.png', '.html') if '.png' in save_path else f"{save_path}.html"
fig.write_html(html_path)
print(f"✓ Interactive circuit saved to: {html_path}")
# Show in notebook
fig.show()
# Print summary
print(f"\n=== CIRCUIT SUMMARY ===")
print(f"Total nodes: {G.number_of_nodes()}")
print(f"Total edges: {G.number_of_edges()}")
print(f"Sparsity: {len(circuit['critical_channels'])} / 2048 ({len(circuit['critical_channels'])/2048*100:.1f}%)")
# Role breakdown
role_counts = {}
for ch_idx in circuit['critical_channels']:
role = circuit['channel_roles'].get(ch_idx, 'unknown')
role_counts[role] = role_counts.get(role, 0) + 1
print(f"\n=== ROLE BREAKDOWN ===")
for role, count in sorted(role_counts.items()):
print(f" {role.capitalize()}: {count} channels ({count/len(circuit['critical_channels'])*100:.1f}%)")
print("✓ Visualization function loaded (Plotly interactive)")
✓ Visualization function loaded (Plotly interactive)
Section 9: (Optional) Multi-Layer Analysis¶
Note: This section is optional and not part of the main workflow. It's included for future hierarchical circuit analysis across multiple ResNet layers.
def extract_hierarchical_circuit(model, image):
"""
[OPTIONAL] Extract circuit across multiple layers for future analysis.
Warning: This will take significantly longer (~30-40 minutes) as it tests
all channels across layers 1-4 (256 + 512 + 1024 + 2048 = 3840 channels).
"""
print("\n⚠️ This is an optional extended analysis (not part of main workflow)")
print(" Focus is on Layer 4 only per user requirements.\n")
layers_to_analyze = {
'layer1': ('backbone.body.layer1', 256),
'layer2': ('backbone.body.layer2', 512),
'layer3': ('backbone.body.layer3', 1024),
'layer4': ('backbone.body.layer4', 2048),
}
hierarchical_circuit = {}
for layer_name, (layer_path, num_channels) in layers_to_analyze.items():
print(f"Analyzing {layer_name} ({num_channels} channels)...")
# Ablation study for this layer
importance = ablation_study(model, image, layer_path, num_channels)
# Extract critical channels (top 5%)
threshold = np.percentile(importance, 95)
critical = np.where(importance >= threshold)[0]
hierarchical_circuit[layer_name] = {
'importance': importance,
'critical_channels': critical
}
print(f" → {len(critical)} critical channels\n")
return hierarchical_circuit
print("✓ (Optional) Multi-layer analysis function loaded")
✓ (Optional) Multi-layer analysis function loaded
Section 10: Main Pipeline¶
This orchestrates all the steps: load model → run ablation → extract circuit → visualize
def main_circuit_extraction_pipeline(checkpoint_path, test_image_path):
"""
End-to-end pipeline: load model → run ablation → extract circuit → visualize.
Args:
checkpoint_path: Path to fine-tuned model checkpoint
(e.g., "/content/drive/MyDrive/Orthos/fior-pot-detector.pth")
test_image_path: Path to test image with pot detections
Returns:
circuit: Extracted circuit structure
channel_importance: (2048,) array of ablation scores
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}\n")
# === STEP 1: Load model ===
print("=" * 60)
print("STEP 1: Loading model")
print("=" * 60)
model = get_model(num_classes=2)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = model.to(device)
print("✓ Model loaded successfully")
print(f" Epoch: {checkpoint.get('epoch', 'unknown')}")
if 'loss' in checkpoint:
print(f" Loss: {checkpoint['loss']:.4f}")
print()
# === STEP 2: Load test image ===
print("=" * 60)
print("STEP 2: Loading test image")
print("=" * 60)
image = load_image(test_image_path)
print(f"✓ Image loaded and preprocessed: {image.shape}")
print(f" Value range: [{image.min():.3f}, {image.max():.3f}]")
print(f" Expected: [0.0, 1.0] (normalized from uint8)")
image = image.to(device)
print()
# === STEP 3: Verify detection works ===
print("=" * 60)
print("STEP 3: Running baseline detection")
print("=" * 60)
# Temporarily raise the score threshold to report only high-confidence detections
original_thresh = model.roi_heads.score_thresh
model.roi_heads.score_thresh = 0.95  # High threshold: keep only confident detections
with torch.no_grad():
baseline_output = model([image])
num_detections = len(baseline_output[0]['boxes'])
print(f"✓ Detected {num_detections} objects (score threshold: {model.roi_heads.score_thresh})")
if num_detections > 0:
# Show top detections
scores = baseline_output[0]['scores'].cpu().numpy()
boxes = baseline_output[0]['boxes'].cpu().numpy()
print(f"\n Top detections:")
for i in range(min(5, num_detections)):
print(f" {i+1}. Score: {scores[i]:.3f}, Box: [{boxes[i][0]:.0f}, {boxes[i][1]:.0f}, {boxes[i][2]:.0f}, {boxes[i][3]:.0f}]")
# Reset to original threshold for ablation
model.roi_heads.score_thresh = original_thresh
# Re-run with normal threshold for baseline
with torch.no_grad():
baseline_output = model([image])
num_baseline = len(baseline_output[0]['boxes'])
print(f"\n Detections at threshold {original_thresh}: {num_baseline}")
else:
print(f"\n ⚠️ WARNING: No detections found!")
print(f" This could mean:")
print(f" 1. The test image doesn't contain pots")
print(f" 2. The model wasn't trained properly")
print(f" 3. Image preprocessing mismatch (FIXED in this version)")
print(f"\n Continuing anyway for demonstration...")
print(f" Note: Ablation scores will all be 0 or very small")
print()
# === STEP 4: Capture activations ===
print("=" * 60)
print("STEP 4: Capturing Layer 4 activations")
print("=" * 60)
capturer = ActivationCapture(model, {'layer4': 'backbone.body.layer4'})
with torch.no_grad():
_ = model([image])
activations = capturer.activations
print(f"✓ Captured activations: {activations['layer4'].shape}")
print(f" Activation range: [{activations['layer4'].min():.3f}, {activations['layer4'].max():.3f}]")
print()
capturer.remove_hooks()
# === STEP 5: Run ablation study ===
print("=" * 60)
print("STEP 5: Running ablation study on Layer 4 (2048 channels)")
print("=" * 60)
print("⚠️ This will take ~5-10 minutes (testing each channel)")
if num_detections == 0:
print("\n⚠️ WARNING: No baseline detections - ablation results will be minimal")
print("Consider using a different test image with visible pots\n")
channel_importance = ablation_study(model, image, 'backbone.body.layer4', 2048)
print(f"✓ Ablation complete")
print(f" Impact range: [{channel_importance.min():.4f}, {channel_importance.max():.4f}]")
print(f" Mean impact: {channel_importance.mean():.4f}")
print(f" Non-zero impacts: {np.count_nonzero(channel_importance)} / 2048")
print()
# === STEP 6: Extract circuit ===
print("=" * 60)
print("STEP 6: Extracting minimal circuit")
print("=" * 60)
circuit = extract_circuit(channel_importance, activations, threshold_percentile=90)
print(f"✓ Circuit extracted\n")
# === STEP 7: Visualize ===
print("=" * 60)
print("STEP 7: Generating interactive circuit diagram")
print("=" * 60)
visualize_circuit(circuit, save_path='circuit_diagram.html')
# === STEP 8: Save results ===
print("\n" + "=" * 60)
print("STEP 8: Saving results")
print("=" * 60)
np.save('channel_importance.npy', channel_importance)
np.save('circuit.npy', circuit, allow_pickle=True)
print("✓ Saved channel_importance.npy")
print("✓ Saved circuit.npy")
print("✓ Saved circuit_diagram.html (interactive)")
print("\n" + "=" * 60)
print("✅ PIPELINE COMPLETE!")
print("=" * 60)
if num_detections == 0:
print("\n⚠️ NOTE: No detections were found in the baseline.")
print("The circuit will be based on minimal activations.")
print("For best results, use an image where the model detects pots.")
return circuit, channel_importance
print("✓ Main pipeline function loaded")
✓ Main pipeline function loaded
Section 11: Execution¶
Instructions: Update the paths below with your checkpoint and test image locations, then run the cell.
IMPORTANT: Image Size¶
Your model was trained on 1024x1024 tiles. Large images (like 46074x2654) must be cropped!
The load_image() function now automatically:
- Crops to 1024x1024 (matches training tile size)
- By default, crops from center of the image
- You can change crop position by editing the function call below
Options for cropping:
Center crop (default):
load_image(test_image_path) # Crops from center
Custom crop position:
# Crop from pixel coordinates (y=10000, x=500)
load_image(test_image_path, crop_center=False, crop_offset=(10000, 500))
Different crop size:
load_image(test_image_path, crop_size=512) # Crop to 512x512
Pro tip: If center crop has no detections, try different offsets to find a region with pots!
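One way to do that is to scan a few candidate offsets and keep whichever yields detections. A rough sketch (the offsets are illustrative; model, device, and test_image_path are assumed to be set up as in the cells below):

candidate_offsets = [(0, 0), (10000, 500), (20000, 1000), (30000, 1500)]  # (y, x) positions to try
for offset in candidate_offsets:
    img = load_image(test_image_path, crop_center=False, crop_offset=offset).to(device)
    with torch.no_grad():
        out = model([img])
    print(f"offset {offset}: {len(out[0]['boxes'])} detections")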
# ============================================================
# OPTIONAL: DEBUG CELL - Run this first to test your setup
# ============================================================
# Uncomment and run this cell to visualize image and test detection
# checkpoint_path = "/content/drive/MyDrive/Orthos/fior-pot-detector.pth"
# test_image_path = "/content/drive/MyDrive/Orthos/test_image.tif"
# # Load and visualize image
# test_img = load_image(test_image_path)
# print(f"Image shape: {test_img.shape}")
# print(f"Value range: [{test_img.min():.3f}, {test_img.max():.3f}]")
# # Show image
# plt.figure(figsize=(10, 10))
# plt.imshow(test_img.permute(1, 2, 0).cpu().numpy())
# plt.title("Test Image")
# plt.axis('off')
# plt.show()
# # Test detection
# model = get_model(num_classes=2)
# checkpoint = torch.load(checkpoint_path)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = model.to(device)
# test_img = test_img.to(device)
# model.roi_heads.score_thresh = 0.05
# with torch.no_grad():
# output = model([test_img])
# print(f"\nDetections: {len(output[0]['boxes'])}")
# if len(output[0]['boxes']) > 0:
# print("Top 5 scores:", output[0]['scores'][:5].cpu().numpy())
# else:
# print("⚠️ No detections! Try a different image.")
# ============================================================
# USER CONFIGURATION: Update these paths
# ============================================================
# Path to your fine-tuned Faster R-CNN checkpoint
checkpoint_path = "/content/drive/MyDrive/Orthos/fior-pot-detector.pth"
# Path to test image (should have pot detections)
# Supports: .tif, .tiff, .jpg, .jpeg, .png
test_image_path = "/content/drive/MyDrive/Orthos/35.tif"
# ============================================================
# RUN PIPELINE
# ============================================================
circuit, channel_importance = main_circuit_extraction_pipeline(
checkpoint_path=checkpoint_path,
test_image_path=test_image_path
)
print("\n🎉 Circuit extraction complete!")
print("\nOutputs:")
print(" - circuit_diagram.html (interactive visualization - zoom/pan/hover)")
print(" - channel_importance.npy (ablation scores)")
print(" - circuit.npy (circuit structure)")
print("\n💡 Tip: Download circuit_diagram.html and open in browser for best experience!")
Using device: cuda

============================================================
STEP 1: Loading model
============================================================
/usr/local/lib/python3.12/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 189MB/s]
✓ Model loaded successfully
Epoch: 50
Loss: 0.0314
============================================================
STEP 2: Loading test image
============================================================
GeoTIFF info:
Bands: 4 total, using bands 1-3 (RGB)
Dtype: uint8
Full shape: (46074, 2654, 3)
Value range: [0, 255]
→ Cropped to (1024, 1024, 3) (model expects ~1024x1024)
Crop region: y=[22525:23549], x=[815:1839]
→ Detected uint8, normalizing from [0, 255] to [0, 1]
Final tensor: torch.Size([3, 1024, 1024]), range: [0.055, 1.000]
✓ Image loaded and preprocessed: torch.Size([3, 1024, 1024])
Value range: [0.055, 1.000]
Expected: [0.0, 1.0] (normalized from uint8)
============================================================
STEP 3: Running baseline detection
============================================================
✓ Detected 56 objects (score threshold: 0.95)
Top detections:
1. Score: 1.000, Box: [444, 996, 504, 1024]
2. Score: 1.000, Box: [714, 994, 775, 1024]
3. Score: 1.000, Box: [175, 998, 239, 1024]
4. Score: 1.000, Box: [852, 0, 912, 27]
5. Score: 1.000, Box: [982, 990, 1024, 1024]
Detections at threshold 0.05: 56
============================================================
STEP 4: Capturing Layer 4 activations
============================================================
✓ Captured activations: torch.Size([1, 2048, 25, 25])
Activation range: [0.000, 5.029]
============================================================
STEP 5: Running ablation study on Layer 4 (2048 channels)
============================================================
⚠️ This will take ~5-10 minutes (testing each channel)
Ablating channels: 100%|██████████| 2048/2048 [02:54<00:00, 11.76it/s]
/usr/local/lib/python3.12/dist-packages/numpy/lib/_function_base_impl.py:2922: RuntimeWarning: invalid value encountered in divide
  c /= stddev[:, None]
/usr/local/lib/python3.12/dist-packages/numpy/lib/_function_base_impl.py:2923: RuntimeWarning: invalid value encountered in divide
  c /= stddev[None, :]
✓ Ablation complete
  Impact range: [-0.0000, 0.0000]
  Mean impact: -0.0000
  Non-zero impacts: 44 / 2048

============================================================
STEP 6: Extracting minimal circuit
============================================================
Critical channels: 2023 / 2048
✓ Circuit extracted

============================================================
STEP 7: Generating interactive circuit diagram
============================================================
✓ Interactive circuit saved to: circuit_diagram.html
  → Open in browser to explore
=== CIRCUIT SUMMARY ===
Total nodes: 2023
Total edges: 5099
Sparsity: 2023 / 2048 (98.8%)

=== ROLE BREAKDOWN ===
  Edge: 162 channels (8.0%)
  Shape: 1606 channels (79.4%)
  Texture: 255 channels (12.6%)

============================================================
STEP 8: Saving results
============================================================
✓ Saved channel_importance.npy
✓ Saved circuit.npy
✓ Saved circuit_diagram.html (interactive)

============================================================
✅ PIPELINE COMPLETE!
============================================================

🎉 Circuit extraction complete!

Outputs:
  - circuit_diagram.html (interactive visualization - zoom/pan/hover)
  - channel_importance.npy (ablation scores)
  - circuit.npy (circuit structure)

💡 Tip: Open circuit_diagram.html in browser for best experience!
Interactive Visualization Tips¶
Using the Plotly Circuit Diagram¶
The circuit diagram is now fully interactive with the following features:
Zoom In/Out:
- Scroll wheel to zoom
- Double-click to reset view
- Click and drag box to zoom to region
Pan:
- Click and drag to move around
- Explore dense clusters in the center
Hover:
- Hover over nodes to see:
  - Channel number
  - Role (edge/texture/shape/semantic)
  - Importance score
Filter by Role:
- Click legend items to show/hide roles
- Double-click legend to isolate one role
Export:
- Camera icon (top right) → Download as PNG
- Share the HTML file for interactive viewing
Interpretation Guide¶
- Node Size = Ablation importance (bigger = more critical)
- Node Color = Semantic role:
  - Red = Edge detectors
  - Cyan = Texture detectors
  - Blue = Shape detectors
  - Orange = Semantic features
- Edge Thickness = Co-activation strength (correlation)
- Clusters = Functionally related channel groups
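To connect this visual encoding back to the underlying data, the saved circuit can be queried directly. A small sketch, assuming circuit.npy was produced by the pipeline above:

import numpy as np
circuit = np.load('circuit.npy', allow_pickle=True).item()
order = np.argsort(circuit['importance'])[::-1]              # highest ablation impact first
for rank in order[:10]:
    ch = circuit['critical_channels'][rank]
    role = circuit['channel_roles'].get(ch, 'unknown')
    print(f"Ch{ch}: role={role}, importance={circuit['importance'][rank]:.4f}")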