This is a wrapper for the CLIP image encoder provided in the Python package openai-clip.
The CLIP encoder takes a transformed tensor as input. However, if you want to use CLIP as a component in your networks, the input fed by your dataloader is likely to use a different transform from the one CLIP expects:
import torchvision.transforms as transforms

# the tensor-space part of CLIP's preprocessing; the full pipeline returned
# by clip.load also resizes and center-crops the PIL image first
CLIP_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])
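If you want to double-check these constants, printing the preprocessing pipeline that clip.load returns alongside the model lists every step, including the Resize and CenterCrop omitted above ('ViT-B/32' is just an example model):

import clip

# clip.load returns (model, preprocess); the preprocess Compose contains
# Resize, CenterCrop, ToTensor and Normalize with the constants above
_, preprocess = clip.load('ViT-B/32', device='cpu')
print(preprocess)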
Suppose your transform is:
# e.g. the widely used CIFAR-10 normalization statistics
your_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
So we need to write an inverse transform that turns the tensor back into pixel space and then applies the CLIP transform. Since transforms.Normalize(mean, std) maps x to (x - mean) / std, its inverse x * std + mean is itself a Normalize with mean -mean/std and std 1/std:
from PIL import Image

try:
    BICUBIC = transforms.InterpolationMode.BICUBIC
except AttributeError:  # older torchvision without InterpolationMode
    BICUBIC = Image.BICUBIC

def _transform(n_px, center=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    # 'center' is the per-channel mean of your normalization
    return transforms.Compose([
        # undo your normalization: Normalize(-mean/std, 1/std) maps (x - mean) / std back to x
        transforms.Normalize(mean=[-center[0] / std[0], -center[1] / std[1], -center[2] / std[2]],
                             std=[1 / std[0], 1 / std[1], 1 / std[2]]),
        # match CLIP's spatial preprocessing
        transforms.Resize(n_px, interpolation=BICUBIC),
        transforms.CenterCrop(n_px),
        # apply CLIP's normalization
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
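As a quick sanity check, un-normalizing with _transform and then normalizing directly with CLIP's statistics should agree. A minimal sketch (the batch is random and already square at the target size, so Resize and CenterCrop act as no-ops):

import torch

your_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
clip_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                      (0.26862954, 0.26130258, 0.27577711))

pixels = torch.rand(2, 3, 224, 224)                   # fake batch of raw pixels in [0, 1]
roundtrip = _transform(224)(your_normalize(pixels))   # your normalize -> inverse -> CLIP normalize
direct = clip_normalize(pixels)                       # CLIP normalize applied directly
print(torch.allclose(roundtrip, direct, atol=1e-5))   # expect True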
Here is the CLIP wrapper that takes a transformed tensor as input:
import clip
import torch
import torch.nn as nn

class clip_img_wrap(nn.Module):
    def __init__(self, clip_model='ViT-L/14', device='cpu',
                 center=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
        super().__init__()
        self.model, self.preprocess = clip.load(clip_model, device)
        self.name = '-'.join(clip_model.split('/'))
        self.device = device
        # dimension of the joint CLIP embedding space (e.g. 768 for ViT-L/14)
        self.dim = self.model.text_projection.shape[1]
        # maps tensors normalized with (center, std) back to pixel space,
        # then applies CLIP's own preprocessing
        self.inv_normalize = _transform(self.model.visual.input_resolution, center, std)

    def forward(self, image):
        image = self.inv_normalize(image)
        # the encoder is frozen: no gradients flow through CLIP
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return image_features.float()
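Finally, a minimal usage sketch (CIFAR-10 and the batch size here are arbitrary placeholders; CIFAR-10 simply matches the normalization statistics above):

import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', download=True, transform=your_transform)
loader = DataLoader(dataset, batch_size=8)

encoder = clip_img_wrap('ViT-L/14', device='cpu')
images, _ = next(iter(loader))
features = encoder(images)   # tensors from your dataloader go straight in
print(features.shape)        # torch.Size([8, 768]) for ViT-L/14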