
Extracting Features from Vision Model Backbones

Understanding what vision models learn requires looking inside them. Both SAM's ViT encoder and Faster R-CNN's ResNet50 backbone produce rich intermediate representations: 2048 channels for ResNet layer4 and 1024 channels for the transformer blocks of SAM's ViT-L encoder. This post covers the practical mechanics of extracting, processing, and visualizing these features from large geospatial imagery.

Why Extract Features?

Feature extraction serves multiple purposes. For interpretability, we want to understand which channels respond to which visual patterns. For downstream tasks, frozen features can train lightweight classifiers without expensive full-model fine-tuning. For quality analysis, comparing fine-tuned vs. pretrained features reveals what adaptation changes.

The key insight: intermediate features often contain more information than final outputs. A bounding box tells you where an object is; the features that produced it tell you how the model recognized it.

Tiled Extraction for Large Images

Drone orthomosaics often exceed 10,000 pixels per side, far too large to process in one pass. The solution mirrors tiled inference: tile the image, extract features from each tile, and stitch the results back together. With overlapping tiles, we average features in overlap regions to avoid boundary artifacts.

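import numpy as np
import torch
import torch.nn.functional as F

def extract_features_tiled(extract_single_tile, image, num_channels,
                           tile_size=1024, overlap=128, downsample=4):
    """Extract features with tile-based processing.

    extract_single_tile: callable mapping an (h, w, 3) tile to a (C, h', w')
    feature tensor, i.e. preprocessing plus a hooked forward pass.
    """
    h, w = image.shape[:2]
    step = tile_size - overlap

    # Initialize feature array (downsampled for memory); -(-a // b) is ceil(a / b)
    feat_h, feat_w = -(-h // downsample), -(-w // downsample)
    full_features = np.zeros((num_channels, feat_h, feat_w), dtype=np.float32)
    count_map = np.zeros((feat_h, feat_w), dtype=np.float32)

    for y in range(0, h, step):
        for x in range(0, w, step):
            tile = image[y:y + tile_size, x:x + tile_size]  # edge tiles may be smaller
            features = torch.as_tensor(extract_single_tile(tile)).float()

            # This tile's footprint in the downsampled grid
            y0, x0 = y // downsample, x // downsample
            y1 = -(-(y + tile.shape[0]) // downsample)
            x1 = -(-(x + tile.shape[1]) // downsample)

            # Resize features to the footprint, move to CPU, and accumulate
            features = F.interpolate(features[None], size=(y1 - y0, x1 - x0),
                                     mode='bilinear', align_corners=False)[0]
            full_features[:, y0:y1, x0:x1] += features.cpu().numpy()
            count_map[y0:y1, x0:x1] += 1

    # Average overlapping regions (every cell is covered by at least one tile)
    full_features /= np.maximum(count_map, 1)
    return full_features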
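Because the loop starts a tile every `step` pixels, the last row and column of tiles can be only a few pixels wide, which some models handle poorly. One common alternative is to shift the final tile back so it ends flush with the image border. The helper below is a hypothetical sketch of that idea (the name `tile_origins` is ours, not part of the original pipeline); it returns start offsets along one axis and keeps at least `overlap` pixels shared between neighboring tiles.

```python
def tile_origins(length, tile_size=1024, overlap=128):
    """Start offsets along one axis so full-size tiles cover [0, length),
    neighbors share at least `overlap` pixels, and no partial tile remains."""
    if length <= tile_size:
        return [0]  # a single (possibly smaller) tile covers the whole axis
    step = tile_size - overlap
    origins = list(range(0, length - tile_size, step))
    origins.append(length - tile_size)  # last tile ends exactly at the border
    return origins


print(tile_origins(2500))  # [0, 896, 1476]
```

Apply it to both axes (`for y in tile_origins(h): for x in tile_origins(w): ...`). Every tile then has the input size the model expects; the cost is that the last tile overlaps its neighbor by more than `overlap` pixels, so a little computation is repeated.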

Memory management matters. Full-resolution features for a 15,000x12,000 image with 2048 channels would require ~1.5 TB in float32 (15,000 x 12,000 pixels x 2048 channels x 4 bytes). Downsampling by 4x along each axis cuts this by 16x, to ~92 GB (or ~46 GB stored as float16): still large, but manageable. For analysis, this resolution captures the essential spatial structure.
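Storage estimates like these are easy to sanity-check. The sketch below assumes dense storage with 4 bytes per value (float32) unless told otherwise; `feature_memory_bytes` is a hypothetical helper. Downsampling by a factor d along each axis divides the total by d squared.

```python
def feature_memory_bytes(height, width, channels, downsample=1, bytes_per_value=4):
    """Bytes for a dense (channels, ceil(H/d), ceil(W/d)) feature array."""
    feat_h = -(-height // downsample)  # ceil division
    feat_w = -(-width // downsample)
    return feat_h * feat_w * channels * bytes_per_value


full = feature_memory_bytes(15_000, 12_000, 2048)                # float32, full resolution
ds4 = feature_memory_bytes(15_000, 12_000, 2048, downsample=4)   # float32, 4x per axis
ds4_half = feature_memory_bytes(15_000, 12_000, 2048, downsample=4, bytes_per_value=2)
print(f"{full / 1e12:.2f} TB, {ds4 / 1e9:.1f} GB, {ds4_half / 1e9:.1f} GB")
# 1.47 TB, 92.2 GB, 46.1 GB
```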

Forward Hooks for Feature Capture

PyTorch's forward hooks let us intercept activations at any layer without modifying the model. Register a hook, run inference, and the hook captures the output tensor before it flows to the next layer.

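import torch

class FeatureExtractor:
    def __init__(self, model, layer_path):
        self.model = model.eval()
        self.features = {}

        # Navigate to target layer, e.g. 'backbone.body.layer4'
        layer = model.get_submodule(layer_path)

        # Register hook; it runs right after the layer computes its output
        def hook(module, inputs, output):
            self.features['output'] = output.detach()

        self.handle = layer.register_forward_hook(hook)

    def extract(self, image_tensor):
        # For torchvision detection models, pass a list of (3, H, W) tensors
        self.features = {}
        with torch.no_grad():
            _ = self.model(image_tensor)
        return self.features['output']

    def remove(self):
        # Detach the hook so later forward passes are not captured
        self.handle.remove()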

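For Faster R-CNN, target backbone.body.layer4 for the deepest ResNet features (2048 channels). For SAM, hook the last transformer block of the image encoder (image_encoder.blocks.23 in the 24-block ViT-L variant) for 1024-channel ViT embeddings. These blocks return channels-last tensors of shape (B, H, W, C), and the image_encoder output itself has only 256 channels because its neck projects the ViT embeddings down. Different layers capture different abstraction levels: earlier layers see edges and textures, later layers see semantic content.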

Identifying Selective Channels

Not all channels are equally relevant. Some respond strongly to target objects (pots), others to background. Computing channel selectivity requires detection results to define positive and negative regions.

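import numpy as np

def compute_channel_selectivity(features, detection_boxes, image_shape):
    """Identify channels that discriminate objects from background.

    features: (C, H, W) array; detection_boxes: (x1, y1, x2, y2) boxes in
    original-image pixels; image_shape: (height, width) of that image.
    """
    C, H, W = features.shape
    scale_y, scale_x = H / image_shape[0], W / image_shape[1]

    # Create mask from detections
    pot_mask = np.zeros((H, W), dtype=bool)
    for x1, y1, x2, y2 in detection_boxes:
        # Scale box to feature resolution (floor the start, ceil the end)
        r0 = max(int(np.floor(y1 * scale_y)), 0)
        c0 = max(int(np.floor(x1 * scale_x)), 0)
        r1 = int(np.ceil(y2 * scale_y))
        c1 = int(np.ceil(x2 * scale_x))
        pot_mask[r0:r1, c0:c1] = True

    background_mask = ~pot_mask

    # Selectivity per channel: mean activation on pots minus mean on background
    flat = features.reshape(C, -1)
    pot_mean = flat[:, pot_mask.ravel()].mean(axis=1)
    bg_mean = flat[:, background_mask.ravel()].mean(axis=1)
    return pot_mean - bg_mean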

Positive selectivity means the channel activates more on pots; negative means it prefers background. In our experiments, about 15% of ResNet layer4 channels show strong selectivity (|score| > 0.1).
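A raw mean difference depends on each channel's activation scale, so a fixed threshold such as 0.1 tends to favor channels with large activations. One scale-free alternative, offered here as a sketch rather than the method behind the numbers above, divides the difference by the pooled standard deviation (an effect size in the style of Cohen's d). `standardized_selectivity` is a hypothetical name, and it takes a precomputed pixel mask instead of boxes.

```python
import numpy as np

def standardized_selectivity(features, pot_mask, eps=1e-8):
    """Per-channel effect size: (mean on pots - mean on background) / pooled std.

    features: (C, H, W) array; pot_mask: (H, W) boolean mask of object pixels
    (both groups need at least two pixels).
    """
    C = features.shape[0]
    flat = features.reshape(C, -1)
    mask = pot_mask.ravel()
    pos, neg = flat[:, mask], flat[:, ~mask]

    # Pooled variance: each group's sample variance weighted by its degrees of freedom
    n_pos, n_neg = pos.shape[1], neg.shape[1]
    pooled_var = ((n_pos - 1) * pos.var(axis=1, ddof=1)
                  + (n_neg - 1) * neg.var(axis=1, ddof=1)) / (n_pos + n_neg - 2)
    return (pos.mean(axis=1) - neg.mean(axis=1)) / (np.sqrt(pooled_var) + eps)


# Synthetic check: channel 0 responds to an 8x8 "object", the others are pure noise
rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 32, 32))
mask = np.zeros((32, 32), dtype=bool)
mask[8:16, 8:16] = True
feats[0][mask] += 2.0
scores = standardized_selectivity(feats, mask)
```

Scores in this form are in units of standard deviations, so one threshold means the same thing for every channel.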

PCA Visualization

With thousands of channels, dimensionality reduction helps visualization. PCA projects the high-dimensional feature space to three components, which we map to RGB for display.

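import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

def normalize_percentile(x, low=1, high=99):
    """Clip to the [low, high] percentile range and rescale to [0, 1]."""
    lo, hi = np.percentile(x, [low, high])
    return np.clip((x - lo) / (hi - lo + 1e-8), 0, 1)

def visualize_features_pca(features, n_components=3, max_samples=200_000,
                           chunk=500_000, seed=0):
    """Project features to RGB for visualization."""
    C, H, W = features.shape

    # Reshape: (C, H, W) -> (H*W, C)
    X = features.reshape(C, -1).T

    # Subsample pixels for PCA fitting (memory); cap at the number of pixels
    rng = np.random.default_rng(seed)
    sample_idx = rng.choice(X.shape[0], min(max_samples, X.shape[0]), replace=False)

    # Fit scaler and PCA on the sample
    scaler = StandardScaler()
    pca = PCA(n_components=n_components)
    pca.fit(scaler.fit_transform(X[sample_idx]))

    # Transform all pixels in chunks to avoid one huge (H*W, C) copy
    X_pca = np.concatenate([pca.transform(scaler.transform(X[i:i + chunk]))
                            for i in range(0, X.shape[0], chunk)])
    pca_image = X_pca.reshape(H, W, n_components)

    # Normalize each component to [0, 1] for display
    for c in range(n_components):
        pca_image[..., c] = normalize_percentile(pca_image[..., c])

    return pca_image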

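The resulting RGB image shows feature similarity spatially. Pixels with similar colors have similar projections onto the top three components, which usually indicates similar high-dimensional representations (two different feature vectors can still share a projection). For well-trained models, objects of the same type cluster together visually.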

Comparing SAM vs. Faster R-CNN Features

The two architectures produce qualitatively different features. ResNet features (2048 channels) show strong spatial correlation—nearby pixels have similar representations due to convolutional locality. ViT features (1024 channels) exhibit more semantic grouping; similar objects share representations even when distant.

In PCA visualizations, ResNet features form smooth gradients across space. ViT features show more discrete clustering by object type. Both are valid representations; they capture different aspects of the visual structure.
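One way to put a number on the "smooth gradients" versus "discrete clusters" distinction is the average cosine similarity between each feature vector and its right and lower neighbors. This metric is an illustrative assumption, not one used in the comparison above; values near 1 indicate spatially smooth feature maps.

```python
import numpy as np

def neighbor_cosine_similarity(features, eps=1e-8):
    """Mean cosine similarity between horizontally and vertically adjacent
    feature vectors of a (C, H, W) array."""
    # Unit-normalize each pixel's C-dimensional feature vector
    norms = np.linalg.norm(features, axis=0, keepdims=True)
    unit = features / (norms + eps)
    horiz = (unit[:, :, :-1] * unit[:, :, 1:]).sum(axis=0)  # (H, W-1) dot products
    vert = (unit[:, :-1, :] * unit[:, 1:, :]).sum(axis=0)   # (H-1, W) dot products
    return float(np.concatenate([horiz.ravel(), vert.ravel()]).mean())
```

ResNet layer4 activations pass through a final ReLU and are non-negative, so even unrelated ResNet vectors score above zero. Compare the two backbones at a matching grid resolution, or center the features first, before reading much into the absolute values.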

Variance explained by the top 3 PCs is typically higher for ResNet (~50%) than ViT (~30%), reflecting the more distributed nature of transformer representations.
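The variance fraction can also be computed without scikit-learn. This NumPy-only sketch mirrors the standardize-then-PCA steps of the visualization function: the eigenvalues of the channel covariance matrix are the principal-component variances. The function name and the 50,000-pixel sample size are illustrative assumptions.

```python
import numpy as np

def top_k_explained_variance(features, k=3, max_samples=50_000, seed=0):
    """Fraction of variance captured by the top-k principal components of
    standardized per-pixel features from a (C, H, W) array."""
    C = features.shape[0]
    X = features.reshape(C, -1).T
    rng = np.random.default_rng(seed)
    idx = rng.choice(X.shape[0], min(max_samples, X.shape[0]), replace=False)
    X = X[idx].astype(np.float64)

    # Standardize each channel (as StandardScaler does); constant channels stay at 0
    std = X.std(axis=0)
    X = (X - X.mean(axis=0)) / np.where(std > 0, std, 1.0)

    # Eigenvalues of the channel covariance are the principal-component variances
    cov = X.T @ X / (X.shape[0] - 1)
    eigvals = np.clip(np.linalg.eigvalsh(cov)[::-1], 0, None)  # descending, clip round-off
    return float(eigvals[:k].sum() / eigvals.sum())
```

Working with the C x C covariance keeps the cost dominated by one matrix product, which is cheaper than a full SVD of the sampled pixel matrix when there are far more pixels than channels.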

Fine-Tuned vs. Pretrained

Comparing features before and after fine-tuning reveals what adaptation changes. For Faster R-CNN trained on pot detection, the most selective channels show the largest differences from pretrained COCO weights.

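import numpy as np

def channel_difference(pretrained_features, finetuned_features, channel_idx):
    """Compare the same channel across models. Both inputs are (C, H, W) features
    extracted from the same image; channel indices line up because fine-tuning
    starts from the pretrained weights."""
    ch_pretrained = pretrained_features[channel_idx]
    ch_finetuned = finetuned_features[channel_idx]

    diff = np.abs(ch_finetuned - ch_pretrained)
    print(f"Mean absolute difference: {diff.mean():.4f}")
    return diff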

Changes concentrate in channels that become pot-selective. Background-preferring channels and general-purpose feature detectors (edges, textures) remain relatively stable.
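To check this pattern across every channel rather than one at a time, a sketch like the following ranks channels by mean absolute change and measures how many of the most-changed channels are also among the most selective. The helper name, the `top_k` cutoff, and the overlap fraction are illustrative choices, not necessarily the analysis behind the statement above.

```python
import numpy as np

def rank_changed_channels(pretrained, finetuned, selectivity, top_k=50):
    """Rank channels by mean absolute change after fine-tuning.

    pretrained, finetuned: (C, H, W) features of the same image from the two models;
    selectivity: (C,) scores, e.g. from compute_channel_selectivity.
    Returns per-channel change, channel indices sorted by change (largest first),
    and the fraction of the top_k most-changed channels that are also among the
    top_k most selective by |selectivity|.
    """
    C = pretrained.shape[0]
    change = np.abs(finetuned - pretrained).reshape(C, -1).mean(axis=1)
    ranked = np.argsort(change)[::-1]
    most_selective = set(np.argsort(np.abs(selectivity))[::-1][:top_k].tolist())
    overlap = len(set(ranked[:top_k].tolist()) & most_selective) / top_k
    return change, ranked, overlap
```

For reference, if the two rankings were unrelated, the expected overlap fraction would be about top_k / C, so values well above that support the pattern described here.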

Saving Features as GeoTIFFs

For GIS integration, we can save features as multi-band GeoTIFFs that preserve georeferencing. Each band represents one channel; the spatial extent matches the source imagery.

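import rasterio
from affine import Affine

# features: (C, feat_h, feat_w) at 1/downsample of the source resolution; source_crs and
# source_transform come from the source raster (src.crs, src.transform)
_, height, width = features.shape

# Scale the pixel size by the downsampling factor so the raster keeps the source's ground extent
transform = source_transform * Affine.scale(downsample)

# Save top selective channels
with rasterio.open(output_path, 'w',
                   driver='GTiff',
                   height=height, width=width,
                   count=len(top_channels),
                   dtype='float32',
                   crs=source_crs,
                   transform=transform) as dst:
    for i, ch_idx in enumerate(top_channels):
        dst.write(features[ch_idx].astype('float32'), i + 1)
        dst.set_band_description(i + 1, f'Channel_{ch_idx}')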

These feature rasters can be loaded in QGIS for spatial analysis, overlaid with detection results, or used as inputs to downstream classifiers.

Practical Considerations

Memory: Full-resolution features are huge. Always downsample for storage. Downsampling by 4x along each axis preserves enough structure for most analyses while reducing size by 16x.

Overlap averaging: Tiles must overlap to avoid boundary artifacts. Average features in overlap regions rather than taking max or keeping one version.

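GPU memory: Process one tile at a time under torch.no_grad() and move features to the CPU immediately, so no GPU tensors from earlier tiles stay referenced. torch.cuda.empty_cache() only returns unused cached blocks to the driver; it can help when other processes share the GPU, but it cannot free tensors you still hold.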

Coordinate systems: Track the relationship between feature resolution, tile resolution, and original image resolution carefully. Off-by-one errors compound across tiles.
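This bookkeeping can be kept in a few small helpers. The sketch assumes cell centers sit at half-cell offsets and that `downsample` is the integer factor used for the stored feature grid; all three function names are hypothetical. One extra caveat: torchvision's Faster R-CNN resizes inputs internally (min_size=800, max_size=1333 by default), so a 1024-pixel tile reaches layer4 as an 800-pixel image and yields 25 cells per side, an effective stride of about 41 pixels rather than 32.

```python
def feature_to_image(row, col, downsample=4):
    """Center of stored feature cell (row, col) in original-image pixel
    coordinates, treating pixel (0, 0) as the square [0, 1) x [0, 1)."""
    return ((row + 0.5) * downsample, (col + 0.5) * downsample)


def image_to_feature(y, x, downsample=4):
    """Stored feature cell that contains image point (y, x)."""
    return (int(y // downsample), int(x // downsample))


def tile_to_image(row, col, tile_y, tile_x, stride):
    """Center of cell (row, col) in one tile's native feature map, in
    original-image pixels, given the tile's top-left corner (tile_y, tile_x)
    and the effective stride of that map relative to the tile."""
    return (tile_y + (row + 0.5) * stride, tile_x + (col + 0.5) * stride)


print(feature_to_image(10, 7))       # (42.0, 30.0)
print(image_to_feature(42.0, 30.0))  # (10, 7)
```

Round-tripping a cell through both functions and getting the same indices back is a cheap test to run whenever tile size, overlap, or downsampling changes.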

The extracted features form the foundation for deeper interpretability analysis—sparse probing, circuit extraction, and mechanistic understanding of what these models actually learn.

Related

SAM vs Faster R-CNN: A Practical Comparison

Comparing the two architectures for aerial object detection—speed, accuracy, and when to use each.

Sparse Linear Probing for Efficient Detection

Using L1-regularized probes to find minimal feature subsets sufficient for pot detection.
