
CLIP image encoder wrapper

This is a wrapper for the CLIP image encoder provided in the Python package openai-clip.
The CLIP encoder takes a transformed tensor as input. However, if you want to use CLIP as a component in your networks, the input fed by your dataloader is likely to be normalized differently from what CLIP's transform produces:

import torchvision.transforms as transforms

# The tensor half of CLIP's preprocessing (the full pipeline returned by
# clip.load also resizes and center-crops the image first):
CLIP_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
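
If you want to double-check these constants, the preprocess object returned by clip.load prints the full pipeline. A quick sketch (ViT-B/32 is just an example model name):

import clip

_, preprocess = clip.load('ViT-B/32', device='cpu')
print(preprocess)  # Compose(Resize, CenterCrop, ..., ToTensor, Normalize(...))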

Suppose your transform is:

your_transform = transforms.Compose([
        transforms.ToTensor(),
        # e.g. the commonly used CIFAR-10 channel statistics
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

So, we need to write an inverse transform that turns the tensor back into pixel space and then applies the CLIP transform for CLIP to process. Since Normalize computes (x - mean) / std, it can be undone by a second Normalize with mean' = -mean/std and std' = 1/std:

from PIL import Image

try:
    BICUBIC = transforms.InterpolationMode.BICUBIC
except AttributeError:  # older torchvision without InterpolationMode
    BICUBIC = Image.BICUBIC


def _transform(n_px, center=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    return transforms.Compose([
        # undo the dataloader normalization: x * std + center
        transforms.Normalize(mean=[-center[0] / std[0], -center[1] / std[1], -center[2] / std[2]],
                             std=[1 / std[0], 1 / std[1], 1 / std[2]]),
        # then apply CLIP's own preprocessing to the recovered pixel values
        transforms.Resize(n_px, interpolation=BICUBIC),
        transforms.CenterCrop(n_px),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
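
As a quick sanity check, de-normalizing and then re-normalizing with CLIP's statistics should match applying CLIP's normalization to the raw pixels directly. A minimal sketch, assuming 224x224 inputs so that Resize and CenterCrop are effectively no-ops:

import torch

x = torch.rand(2, 3, 224, 224)  # fake pixel-space images in [0, 1]

your_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
clip_norm = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711))

via_inverse = _transform(224)(your_norm(x))  # de-normalize, then CLIP preprocess
direct = clip_norm(x)                        # CLIP normalization on raw pixels

print(torch.allclose(via_inverse, direct, atol=1e-5))  # should print True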

Here is the CLIP wrapper that takes a transformed tensor as input:

import clip
import torch
import torch.nn as nn

class clip_img_wrap(nn.Module):
    def __init__(self, clip_model='ViT-L/14', device='cpu', center=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
        super().__init__()

        self.model, self.preprocess = clip.load(clip_model, device)
        self.name = '-'.join(clip_model.split('/'))
        self.device = device
        self.dim = self.model.text_projection.shape[1]  # width of CLIP's joint embedding space
        # maps dataloader-normalized tensors back to pixel space, then applies CLIP's preprocessing
        self.inv_normalize = _transform(self.model.visual.input_resolution, center, std)

    def forward(self, image):
        # the encoder is frozen: no gradients flow through CLIP
        image = self.inv_normalize(image)
        with torch.no_grad():
            image_features = self.model.encode_image(image)

        return image_features.float()
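
A minimal usage sketch (the random batch below is only a stand-in for a real normalized batch from your dataloader; 768 is the embedding width of ViT-L/14):

wrapper = clip_img_wrap('ViT-L/14', device='cpu')

batch = torch.randn(4, 3, 32, 32)  # stand-in for a CIFAR-style normalized batch
feats = wrapper(batch)
print(feats.shape)  # torch.Size([4, 768])
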
This post is licensed under CC BY 4.0 by the author.