mirror of
https://github.com/OpenBMB/MiniCPM-V.git
synced 2026-02-04 17:59:18 +08:00
First commit
This commit is contained in:
1
omnilmm/model/__init__.py
Normal file
1
omnilmm/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .omnilmm import OmniLMMForCausalLM
|
||||
457
omnilmm/model/omnilmm.py
Normal file
457
omnilmm/model/omnilmm.py
Normal file
@@ -0,0 +1,457 @@
|
||||
|
||||
import gc
|
||||
import math
|
||||
import timm
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers import MistralForCausalLM, MistralModel, MistralConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
|
||||
from omnilmm.model.utils import build_transform
|
||||
from omnilmm.model.resampler import Resampler
|
||||
|
||||
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
||||
DEFAULT_IM_START_TOKEN = "<im_start>"
|
||||
DEFAULT_IM_END_TOKEN = "<im_end>"
|
||||
|
||||
|
||||
class OmniLMMConfig(MistralConfig):
    # Configuration for OmniLMM: inherits every Mistral field and only
    # registers the "omnilmm" model_type used for AutoConfig/AutoModel lookup.
    model_type = "omnilmm"
||||
class Identity(torch.nn.Identity):
    """An ``nn.Identity`` that tolerates extra keyword arguments.

    Used to neutralize timm ViT sub-modules (attention pool, final block)
    whose call sites may pass keyword arguments that the stock
    ``nn.Identity.forward`` signature would reject.
    """

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        # Ignore any keyword arguments and hand back the input untouched.
        return super().forward(input)
||||
def create_vision_module(config):
    """Build the (vision_tower, resampler) pair for OmniLMM.

    Instantiates an EVA02 ViT backbone via timm (weights are NOT loaded
    here) and a perceiver-style Resampler mapping the ViT token sequence
    to ``config.num_query`` query embeddings of width ``config.hidden_size``.
    """
    # dynamic_img_size / dynamic_img_pad let the ViT accept inputs whose
    # resolution differs from the 224px pretraining size.
    vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
                                     pretrained=False,
                                     num_classes=0,
                                     dynamic_img_size=True,
                                     dynamic_img_pad=True)

    if isinstance(vision_tower, timm.models.VisionTransformer):
        # Drop the attention-pooling head; per-token features are needed.
        if vision_tower.attn_pool is not None:
            vision_tower.attn_pool = Identity()

    # use 2nd last layer's output: neutralize the final transformer block.
    vision_tower.blocks[-1] = Identity()

    embed_dim = config.hidden_size
    resampler = Resampler(
        grid_size=int(math.sqrt(config.num_query)),  # num_query is assumed to be a perfect square — TODO confirm
        embed_dim=embed_dim,
        num_heads=embed_dim // 128,
        kv_dim=vision_tower.embed_dim,
    )
    return vision_tower, resampler
||||
class OmniLMMModel(MistralModel):
    # Mistral backbone extended with an EVA02 vision tower + resampler.
    config_class = OmniLMMConfig

    def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
        """Build the multimodal backbone.

        NOTE(review): ``mm_vision_tower`` / ``mm_hidden_size`` are accepted
        for interface compatibility but unused here; the vision module is
        constructed from fields on ``config`` instead.
        """
        super(OmniLMMModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            vision_tower, resampler = create_vision_module(config)

            # Real vision weights are loaded later (from_pretrained or
            # initialize_vision_modules), so skip them here.
            print(__file__, 'skip loading vision tower weights')

            # HACK: for FSDP — hiding the module in a plain Python list keeps
            # its parameters out of the module tree (frozen-CLIP case).
            self.vision_tower = [vision_tower]
            self.resampler = resampler
            if tune_clip:
                # Trainable CLIP: expose the tower as a proper submodule.
                self.vision_tower = self.vision_tower[0]

        # Placeholder attribute namespace; fields such as im_patch_token /
        # im_start_token are attached later by initialize_vision_tokenizer.
        self.vision_config = lambda x: None
||||
    def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
        """Attach or refresh the vision modules and build image transforms.

        Records the vision settings on ``self.config``, creates the vision
        tower + resampler if they do not exist yet (loading EVA02 weights
        from a hard-coded local checkpoint), and constructs train/eval
        image transforms.

        Returns a dict with ``image_processor`` (train, eval transforms),
        ``image_token_len`` (= num_query) and ``vision_config``.
        """
        self.config.mm_vision_tower = vision_tower
        self.config.use_mm_proj = True
        self.config.num_query = num_query
        self.config.image_size = image_size

        if not hasattr(self, 'vision_tower'):
            vision_tower, resampler = create_vision_module(self.config)
            # NOTE(review): hard-coded absolute checkpoint path — only valid
            # in the original training environment.
            state_dict = torch.load(
                '/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
            vision_tower.load_state_dict(state_dict, strict=False)
            del state_dict
            gc.collect()
        else:
            # Unwrap the FSDP list-hiding hack if it is in effect.
            if isinstance(self.vision_tower, list):
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            resampler = self.resampler
        # Expose as a submodule when tuning CLIP, otherwise hide in a list.
        self.vision_tower = vision_tower if tune_clip else [vision_tower]
        self.resampler = resampler

        train_img_transform = build_transform(
            is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
        eval_img_transform = build_transform(
            is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')

        return dict(
            image_processor=(train_img_transform, eval_img_transform),
            image_token_len=num_query,
            vision_config=self.vision_config
        )
||||
|
||||
def get_vision_embedding(self, pixel_values):
|
||||
if isinstance(self.vision_tower, list):
|
||||
vision_tower = self.vision_tower[0] # HACK: for FSDP
|
||||
else:
|
||||
vision_tower = self.vision_tower
|
||||
|
||||
dtype = vision_tower.pos_embed.data.dtype
|
||||
vision_embedding = vision_tower.forward_features(
|
||||
pixel_values.type(dtype))
|
||||
if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
|
||||
vision_embedding = vision_embedding[:,
|
||||
vision_tower.num_prefix_tokens:]
|
||||
res = self.resampler(vision_embedding)
|
||||
return res
|
||||
|
||||
def get_vllm_embedding(self, data):
|
||||
|
||||
if 'vision_hidden_states' not in data:
|
||||
pixel_values_list = data['pixel_values']
|
||||
vision_hidden_states = []
|
||||
for pixel_values in pixel_values_list:
|
||||
if len(pixel_values) > 0:
|
||||
vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
|
||||
else:
|
||||
vision_hidden_states.append([])
|
||||
else:
|
||||
vision_hidden_states = data['vision_hidden_states']
|
||||
|
||||
#vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
|
||||
inputs_embeds = self.embed_tokens(data['input_ids'])
|
||||
vision_hidden_states = [i.type(inputs_embeds.dtype)
|
||||
if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
|
||||
]
|
||||
|
||||
|
||||
# HACK: replace back original embeddings for LLaVA pretraining
|
||||
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
|
||||
|
||||
new_input_embeds = []
|
||||
cur_image_idx = 0
|
||||
for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
|
||||
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
|
||||
# multimodal LLM, but the current sample is not multimodal
|
||||
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
|
||||
new_input_embeds.append(cur_input_embeds)
|
||||
continue
|
||||
|
||||
if self.vision_config.use_im_start_end:
|
||||
cur_image_features = vision_hidden_states[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
|
||||
raise ValueError(
|
||||
"The number of image start tokens and image end tokens should be the same.")
|
||||
image_start_tokens = torch.where(
|
||||
cur_input_ids == self.vision_config.im_start_token)[0]
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = vision_hidden_states[cur_image_idx].to(
|
||||
device=cur_input_embeds.device)
|
||||
num_patches = cur_image_features.shape[0]
|
||||
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
|
||||
raise ValueError(
|
||||
"The image end token should follow the image start token.")
|
||||
if orig_embeds_params is not None:
|
||||
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
|
||||
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
|
||||
else:
|
||||
cur_new_input_embeds = torch.cat(
|
||||
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
new_input_embeds.append(cur_new_input_embeds)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
inputs_embeds = torch.stack(new_input_embeds, dim=0)
|
||||
|
||||
return inputs_embeds, vision_hidden_states
|
||||
|
||||
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Mistral forward pass with image features spliced into the token
        embeddings at every <im_start> ... <im_end> span.

        When ``images`` is provided (list of per-image tensors or a batched
        tensor), each image is encoded via ``get_vision_embedding`` and its
        features replace the patch-token placeholder positions; the modified
        ``inputs_embeds`` (with ``input_ids=None``) is then passed to the
        standard MistralModel forward.
        """
        # HACK: replace back original embeddings for LLaVA pretraining
        orig_embeds_params = getattr(self, 'orig_embeds_params', None)

        if inputs_embeds is None and past_key_values is None:
            inputs_embeds = self.embed_tokens(input_ids)

        vision_tower = getattr(self, 'vision_tower', None)
        # Splice vision features only on the prefill step (seq len > 1) or
        # during training; single-token decode steps reuse the KV cache.
        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:

            if type(images) is list:
                # Variable-sized images: encode one at a time.
                image_features = []
                for image in images:
                    image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
                        0]
                    image_features.append(image_forward_out)
            else:
                image_features = self.get_vision_embedding(images)

            # Zero tensor keeps a vision-shaped (no-op) grad path for
            # text-only samples in a multimodal batch.
            dummy_image_features = torch.zeros(
                self.config.num_query,
                self.config.hidden_size,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype)

            new_input_embeds = []
            cur_image_idx = 0
            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
                if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                    # multimodal LLM, but the current sample is not multimodal
                    cur_input_embeds = cur_input_embeds + \
                        (0. * dummy_image_features).sum()
                    new_input_embeds.append(cur_input_embeds)
                    continue

                if self.vision_config.use_im_start_end:
                    cur_image_features = image_features[cur_image_idx]
                    num_patches = cur_image_features.shape[0]
                    # Sanity check: every <im_start> needs a matching <im_end>.
                    if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                        raise ValueError(
                            "The number of image start tokens and image end tokens should be the same.")
                    image_start_tokens = torch.where(
                        cur_input_ids == self.vision_config.im_start_token)[0]
                    for image_start_token_pos in image_start_tokens:
                        cur_image_features = image_features[cur_image_idx].to(
                            device=cur_input_embeds.device)
                        num_patches = cur_image_features.shape[0]
                        # <im_end> must sit exactly num_patches after <im_start>.
                        if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                            raise ValueError(
                                "The image end token should follow the image start token.")
                        if orig_embeds_params is not None:
                            # Freeze (detach) everything except the
                            # <im_start>/<im_end> embeddings.
                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
                                                              cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                        else:
                            cur_new_input_embeds = torch.cat(
                                (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                        cur_image_idx += 1
                    new_input_embeds.append(cur_new_input_embeds)
                else:
                    raise NotImplementedError
            inputs_embeds = torch.stack(new_input_embeds, dim=0)
            # Embeddings are now authoritative; drop the ids.
            input_ids = None

        return super(OmniLMMModel, self).forward(
            input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )
|
||||
|
||||
|
||||
class OmniLMMForCausalLM(MistralForCausalLM):
    # Causal-LM head on top of the multimodal OmniLMMModel backbone.
    config_class = OmniLMMConfig

    def __init__(self, config, mm_vision_tower=None, tune_clip=True):
        # Deliberately skip MistralForCausalLM.__init__ (which would build a
        # plain MistralModel) and install the multimodal backbone instead.
        super(MistralForCausalLM, self).__init__(config)
        self.model = OmniLMMModel(
            config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)

        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
|
||||
|
||||
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Causal-LM forward: multimodal backbone + LM head.

        When ``labels`` is given, computes the standard shifted
        cross-entropy loss (token t predicts token t+1).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            images=images,
            **kwargs
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model/pipeline parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
# TODO could be removed for generate_vllm()
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
||||
):
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"images": kwargs.get("images", None),
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
    def generate_vllm(
        self,
        input_ids: torch.LongTensor = None,
        images: Optional[torch.FloatTensor] = None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        **kwargs
    ):
        """Generate with vision features pre-spliced into the prompt embeddings.

        Encodes ``images`` (unless ``vision_hidden_states`` is already
        provided), merges them into the prompt embeddings via
        ``self.model.get_vllm_embedding``, then delegates to
        ``generate(inputs_embeds=...)``. Optionally also returns the vision
        hidden states so callers can reuse them for follow-up turns.
        """
        model_inputs = {'input_ids': input_ids}
        if vision_hidden_states is None:
            model_inputs['pixel_values'] = images
        else:
            model_inputs['vision_hidden_states'] = vision_hidden_states

        with torch.inference_mode():
            inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)

            result = self.generate(
                inputs_embeds=inputs_embeds,
                **kwargs
            )

        if return_vision_hidden_states:
            return result, vision_hidden_states

        return result
|
||||
|
||||
|
||||
def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
|
||||
tune_mm_mlp_adapter=False):
|
||||
self.model.vision_config.use_im_start_end = mm_use_im_start_end
|
||||
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if mm_use_im_start_end:
|
||||
num_new_tokens = tokenizer.add_tokens(
|
||||
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = self.get_input_embeddings().weight.data
|
||||
output_embeddings = self.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||
|
||||
# for new sft data
|
||||
num_new_tokens = tokenizer.add_tokens(
|
||||
['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
|
||||
self.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = self.get_input_embeddings().weight.data
|
||||
output_embeddings = self.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
||||
dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
||||
|
||||
if tune_mm_mlp_adapter:
|
||||
self.model.orig_embeds_params = [
|
||||
self.get_input_embeddings().weight.data.clone().to(device=device)]
|
||||
for p in self.get_input_embeddings().parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.get_output_embeddings().parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
|
||||
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
|
||||
print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}', flush=True)
|
||||
# exit()
|
||||
|
||||
|
||||
# Make "omnilmm" resolvable through the HF Auto* factories
# (AutoConfig.from_pretrained / AutoModelForCausalLM.from_pretrained).
AutoConfig.register("omnilmm", OmniLMMConfig)
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
|
||||
171
omnilmm/model/resampler.py
Normal file
171
omnilmm/model/resampler.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
import requests
|
||||
from io import BytesIO
|
||||
from functools import partial
|
||||
from PIL import Image
|
||||
from typing import Callable, Optional, Sequence, Tuple, List, Union
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
|
||||
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square absolute position embedding to a target token count.

    abs_pos: (L, C) tensor with L a perfect square; tgt_size: M, assumed a
    perfect square. Returns an (M, C) tensor; the input is returned
    untouched when the grid sides already match.
    """
    src_side = int(math.sqrt(abs_pos.size(0)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # (L, C) -> (1, C, side, side) so spatial interpolation can be applied.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # Back to (M, C) in the original dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
|
||||
|
||||
|
||||
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid with w first, matching the MAE reference implementation.
    grid = np.stack(np.meshgrid(axis, axis), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # Prepend an all-zero row for the class token.
        pos_embed = np.concatenate(
            [np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate 1D sin-cos embeddings of the H and W coordinate grids."""
    assert embed_dim % 2 == 0

    # Half of the channels encode rows, the other half columns.
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # Standard transformer frequency schedule: 1 / 10000^(2i/D).
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    freqs = 1. / 10000 ** omega  # (D/2,)

    # Outer product of positions and frequencies.
    angles = np.einsum('m,d->md', pos.reshape(-1), freqs)  # (M, D/2)

    # First half sines, second half cosines.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layers by
    (grid_size**2) learnable queries and 2d sincos pos_emb
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """

    def __init__(
        self,
        grid_size,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Fixed (non-trainable) 2D sin-cos positional table for the queries.
        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(
                embed_dim, grid_size)).float()
        ).requires_grad_(False)

        # Learnable query vectors, one per output token.
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        # Project visual features to embed_dim only when widths differ.
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        # Final linear output projection, stored as a raw weight matrix.
        self.proj = nn.Parameter(
            (embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # trunc-normal for Linear weights, zeros/ones for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """Cross-attend the learned queries to visual tokens.

        x: (B, L, kv_dim) visual token sequence — TODO confirm layout.
        Returns (B, num_queries, embed_dim).
        """
        # Key-side positional table resized to the actual token count; the
        # query side always uses the native grid-size table (self.pos_embed).
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        # nn.MultiheadAttention expects sequence-first: (L, B, C).
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]  # batch size
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        # Back to batch-first: (B, num_queries, C).
        x = out.permute(1, 0, 2)

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        # (num_queries, C) -> (num_queries, N, C): one copy per batch item.
        return query.unsqueeze(1).repeat(1, N, 1)
|
||||
555
omnilmm/model/utils.py
Normal file
555
omnilmm/model/utils.py
Normal file
@@ -0,0 +1,555 @@
|
||||
from torchvision import transforms
|
||||
from timm.data.transforms import RandomResizedCropAndInterpolation
|
||||
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from transformers import AutoConfig
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
import pickle
|
||||
import base64
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from transformers import AutoConfig, StoppingCriteria
|
||||
|
||||
try:
|
||||
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
except ImportError:
|
||||
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
|
||||
def auto_upgrade(config):
    """Interactively upgrade an old LLaVA v0 checkpoint directory in place.

    NOTE(review): this function is defined twice in this module; the later
    definition shadows this one at import time.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and cfg.model_type != 'llava':
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input(
            "Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # Patch the config class itself so save_pretrained emits "llava".
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
|
||||
|
||||
|
||||
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword string appears in the decoded output.

    NOTE(review): only the first sequence of the batch is inspected, and the
    very first __call__ merely records the prompt length (no keyword check
    on that step).
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords    # strings whose appearance stops generation
        self.tokenizer = tokenizer
        self.start_len = None       # prompt length, captured lazily
        self.input_ids = input_ids  # prompt ids, used only for their length

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            # Decode only the generated continuation of the first sequence.
            outputs = self.tokenizer.batch_decode(
                output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False
|
||||
|
||||
|
||||
def auto_upgrade(config):
    """Interactively upgrade an old LLaVA v0 checkpoint directory in place.

    NOTE(review): exact duplicate of the auto_upgrade defined earlier in this
    module; this second definition is the one that survives at import time.
    """
    cfg = AutoConfig.from_pretrained(config)
    if 'llava' in config and cfg.model_type != 'llava':
        print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
        confirm = input(
            "Please confirm that you want to upgrade the checkpoint. [Y/N]")
        if confirm.lower() in ["y", "yes"]:
            print("Upgrading checkpoint...")
            assert len(cfg.architectures) == 1
            # Patch the config class itself so save_pretrained emits "llava".
            setattr(cfg.__class__, "model_type", "llava")
            cfg.architectures[0] = 'LlavaLlamaForCausalLM'
            cfg.save_pretrained(config)
            print("Checkpoint upgraded.")
        else:
            print("Checkpoint upgrade aborted.")
            exit(1)
|
||||
|
||||
# aug functions
|
||||
|
||||
|
||||
def identity_func(img):
    """No-op augmentation: return the image unchanged."""
    return img
|
||||
|
||||
|
||||
def autocontrast_func(img, cutoff=0):
    '''
    same output as PIL.ImageOps.autocontrast
    '''
    n_bins = 256

    def tune_channel(ch):
        # Optionally ignore `cutoff` percent of the darkest/brightest pixels
        # before stretching the remaining range to [0, 255].
        n = ch.size
        cut = cutoff * n // 100
        if cut == 0:
            high, low = ch.max(), ch.min()
        else:
            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
            low = np.argwhere(np.cumsum(hist) > cut)
            low = 0 if low.shape[0] == 0 else low[0]
            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
        if high <= low:
            # Degenerate range: identity lookup table.
            table = np.arange(n_bins)
        else:
            # Linear stretch mapping [low, high] onto [0, 255].
            scale = (n_bins - 1) / (high - low)
            table = np.arange(n_bins) * scale - low * scale
            table[table < 0] = 0
            table[table > n_bins - 1] = n_bins - 1
        table = table.clip(0, 255).astype(np.uint8)
        return table[ch]

    # Process each color channel independently, as PIL does.
    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
|
||||
|
||||
|
||||
def equalize_func(img):
    '''
    same output as PIL.ImageOps.equalize
    PIL's implementation is different from cv2.equalize
    '''
    n_bins = 256

    def tune_channel(ch):
        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
        non_zero_hist = hist[hist != 0].reshape(-1)
        # PIL-style step: total count (excluding the top bin) over 255.
        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
        if step == 0:
            # Too few pixels to equalize: leave the channel unchanged.
            return ch
        # Cumulative-histogram lookup table, offset by step/2 as in PIL.
        n = np.empty_like(hist)
        n[0] = step // 2
        n[1:] = hist[:-1]
        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
        return table[ch]

    # Equalize each color channel independently.
    channels = [tune_channel(ch) for ch in cv2.split(img)]
    out = cv2.merge(channels)
    return out
|
||||
|
||||
|
||||
def rotate_func(img, degree, fill=(0, 0, 0)):
    '''
    like PIL, rotate by degree, not radians
    '''
    H, W = img.shape[0], img.shape[1]
    center = W / 2, H / 2
    # 2x3 affine matrix for rotation about the image center, scale 1.
    M = cv2.getRotationMatrix2D(center, degree, 1)
    # Out-of-frame regions are filled with the `fill` color.
    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
    return out
|
||||
|
||||
|
||||
def solarize_func(img, thresh=128):
    '''
    Invert every pixel value at or above `thresh`, like PIL.ImageOps.solarize.
    (The upstream snippet's docstring said "posterize"; the lookup table
    implements solarization.)
    '''
    levels = np.arange(256)
    # Values below the threshold pass through; the rest are inverted.
    lut = np.where(levels < thresh, levels, 255 - levels)
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
|
||||
|
||||
|
||||
def color_func(img, factor):
    '''
    same output as PIL.ImageEnhance.Color
    '''
    # Blend toward the grayscale (luminance) image in BGR channel order.
    # The matrix below fuses "convert to gray" and "blend by factor" into a
    # single per-pixel matmul: factor=1 keeps the image, factor=0 desaturates.
    base = np.float32([
        [0.886, -0.114, -0.114],
        [-0.587, 0.413, -0.587],
        [-0.299, -0.299, 0.701]])
    offset = np.float32([[0.114], [0.587], [0.299]])
    blend_matrix = base * factor + offset
    return np.matmul(img, blend_matrix).clip(0, 255).astype(np.uint8)
|
||||
|
||||
|
||||
def contrast_func(img, factor):
    """
    Same output as PIL.ImageEnhance.Contrast: scale every value's distance
    from the image's luminance mean by `factor`.
    """
    # BGR luminance weights; per-channel means combined into one scalar.
    luma_weights = np.array([0.114, 0.587, 0.299])
    mean = np.sum(np.mean(img, axis=(0, 1)) * luma_weights)
    # 256-entry LUT applied via fancy indexing.
    lut = np.array([(v - mean) * factor + mean for v in range(256)])
    lut = lut.clip(0, 255).astype(np.uint8)
    return lut[img]
|
||||
|
||||
|
||||
def brightness_func(img, factor):
    '''
    Same output as PIL.ImageEnhance.Brightness: multiply every channel value
    by `factor`, clipped to [0, 255].

    (The original docstring incorrectly said "Contrast".)
    '''
    # 256-entry lookup table applied via fancy indexing.
    table = (np.arange(256, dtype=np.float32) *
             factor).clip(0, 255).astype(np.uint8)
    out = table[img]
    return out
|
||||
|
||||
|
||||
def sharpness_func(img, factor):
    '''
    Same output as PIL.ImageEnhance.Sharpness in the image interior; only
    the 4 one-pixel boundaries differ from PIL (they are handled there by
    leaving the border untouched, which this replicates for the blend path).
    '''
    # PIL's smoothing kernel: 3x3 of ones with 5 at the center, sum = 13.
    kernel = np.ones((3, 3), dtype=np.float32)
    kernel[1][1] = 5
    kernel /= 13
    degenerate = cv2.filter2D(img, -1, kernel)
    if factor == 0.0:
        out = degenerate
    elif factor == 1.0:
        out = img
    else:
        # Linear interpolation between the smoothed image (factor = 0) and
        # the original (factor = 1), applied to the interior only.
        out = img.astype(np.float32)
        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
        out[1:-1, 1:-1, :] = degenerate + factor * \
            (out[1:-1, 1:-1, :] - degenerate)
        out = out.astype(np.uint8)
    return out
|
||||
|
||||
|
||||
def shear_x_func(img, factor, fill=(0, 0, 0)):
    # Horizontal shear (x' = x + factor * y) via an affine warp; uncovered
    # areas are filled with `fill`.
    h, w = img.shape[0], img.shape[1]
    shear = np.float32([[1, factor, 0], [0, 1, 0]])
    warped = cv2.warpAffine(img, shear, (w, h), borderValue=fill,
                            flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
||||
|
||||
|
||||
def translate_x_func(img, offset, fill=(0, 0, 0)):
    '''
    Same output as PIL.Image.transform: shift the image left by `offset`
    pixels, filling the uncovered strip with `fill`.
    '''
    h, w = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, -offset], [0, 1, 0]])
    warped = cv2.warpAffine(img, shift, (w, h), borderValue=fill,
                            flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
||||
|
||||
|
||||
def translate_y_func(img, offset, fill=(0, 0, 0)):
    '''
    Same output as PIL.Image.transform: shift the image up by `offset`
    pixels, filling the uncovered strip with `fill`.
    '''
    h, w = img.shape[0], img.shape[1]
    shift = np.float32([[1, 0, 0], [0, 1, -offset]])
    warped = cv2.warpAffine(img, shift, (w, h), borderValue=fill,
                            flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
||||
|
||||
|
||||
def posterize_func(img, bits):
    '''
    Same output as PIL.ImageOps.posterize: keep only the top `bits` bits of
    each channel value.

    The mask is reduced to the uint8 range *before* the np.uint8 cast:
    NumPy >= 2.0 raises OverflowError for out-of-range scalar conversions
    (255 << 4 == 4080), whereas older NumPy silently wrapped mod 256.
    The masked value is identical to the old wrapped one.
    '''
    mask = np.uint8((255 << (8 - bits)) & 0xFF)
    out = np.bitwise_and(img, mask)
    return out
|
||||
|
||||
|
||||
def shear_y_func(img, factor, fill=(0, 0, 0)):
    # Vertical shear (y' = y + factor * x) via an affine warp; uncovered
    # areas are filled with `fill`.
    h, w = img.shape[0], img.shape[1]
    shear = np.float32([[1, 0, 0], [factor, 1, 0]])
    warped = cv2.warpAffine(img, shear, (w, h), borderValue=fill,
                            flags=cv2.INTER_LINEAR)
    return warped.astype(np.uint8)
|
||||
|
||||
|
||||
def cutout_func(img, pad_size, replace=(0, 0, 0)):
    # Fill a square of side ~pad_size with `replace`, centered at a
    # uniformly random position and clipped to the image bounds.
    replace = np.array(replace, dtype=np.uint8)
    h, w = img.shape[0], img.shape[1]
    rand_y, rand_x = np.random.random(2)
    half = pad_size // 2
    center_y, center_x = int(rand_y * h), int(rand_x * w)
    top, bottom = max(center_y - half, 0), min(center_y + half, h)
    left, right = max(center_x - half, 0), min(center_x + half, w)
    out = img.copy()
    out[top:bottom, left:right, :] = replace
    return out
|
||||
|
||||
|
||||
# level to args
|
||||
def enhance_level_to_args(MAX_LEVEL):
    # Map level in [0, MAX_LEVEL] linearly onto an enhancement factor in
    # [0.1, 1.9] (1.0 would be identity for the PIL-style enhance ops).
    def level_to_args(level):
        factor = (level / MAX_LEVEL) * 1.8 + 0.1
        return (factor,)
    return level_to_args
|
||||
|
||||
|
||||
def shear_level_to_args(MAX_LEVEL, replace_value):
    # Map level onto a shear factor in [0, 0.3], with a 50% random sign flip.
    def level_to_args(level):
        shear = (level / MAX_LEVEL) * 0.3
        if np.random.random() > 0.5:
            shear = -shear
        return (shear, replace_value)

    return level_to_args
|
||||
|
||||
|
||||
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
    # Map level onto a pixel offset in [0, translate_const], with a 50%
    # random sign flip.
    def level_to_args(level):
        offset = (level / MAX_LEVEL) * float(translate_const)
        if np.random.random() > 0.5:
            offset = -offset
        return (offset, replace_value)

    return level_to_args
|
||||
|
||||
|
||||
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
    # Map level onto a cutout pad size in [0, cutout_const] (integer).
    def level_to_args(level):
        pad = int((level / MAX_LEVEL) * cutout_const)
        return (pad, replace_value)

    return level_to_args
|
||||
|
||||
|
||||
def solarize_level_to_args(MAX_LEVEL):
    # Map level onto a solarize threshold in [0, 256]; 256 disables the op.
    def level_to_args(level):
        thresh = int((level / MAX_LEVEL) * 256)
        return (thresh, )
    return level_to_args
|
||||
|
||||
|
||||
def none_level_to_args(level):
    # For ops that take no extra arguments (Identity/AutoContrast/Equalize).
    return ()
|
||||
|
||||
|
||||
def posterize_level_to_args(MAX_LEVEL):
    # Map level onto the number of bits to keep, in [0, 4].
    def level_to_args(level):
        bits = int((level / MAX_LEVEL) * 4)
        return (bits, )
    return level_to_args
|
||||
|
||||
|
||||
def rotate_level_to_args(MAX_LEVEL, replace_value):
    # Map level onto a rotation angle in [0, 30] degrees, with a 50%
    # random sign flip.
    def level_to_args(level):
        degree = (level / MAX_LEVEL) * 30
        if np.random.random() < 0.5:
            degree = -degree
        return (degree, replace_value)

    return level_to_args
|
||||
|
||||
|
||||
# Op name -> augmentation callable. Each callable takes (img, *args) where
# img is an HWC uint8 ndarray and *args come from the matching entry in
# `arg_dict` below. Key order is significant only insofar as it mirrors
# `arg_dict`; lookups are by name.
func_dict = {
    'Identity': identity_func,
    'AutoContrast': autocontrast_func,
    'Equalize': equalize_func,
    'Rotate': rotate_func,
    'Solarize': solarize_func,
    'Color': color_func,
    'Contrast': contrast_func,
    'Brightness': brightness_func,
    'Sharpness': sharpness_func,
    'ShearX': shear_x_func,
    'TranslateX': translate_x_func,
    'TranslateY': translate_y_func,
    'Posterize': posterize_func,
    'ShearY': shear_y_func,
}
|
||||
|
||||
# Global RandAugment magnitude configuration.
translate_const = 10          # max translation offset in pixels at MAX_LEVEL
MAX_LEVEL = 10                # magnitude levels run from 0 to this value
replace_value = (128, 128, 128)  # mid-gray fill for uncovered areas

# Op name -> callable mapping a magnitude `level` in [0, MAX_LEVEL] to the
# positional args (beyond the image) for the same-named entry in `func_dict`.
# NOTE: key order determines the default sampling pool order in
# RandomAugment (list(arg_dict.keys())), so keep it stable.
arg_dict = {
    'Identity': none_level_to_args,
    'AutoContrast': none_level_to_args,
    'Equalize': none_level_to_args,
    'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
    'Solarize': solarize_level_to_args(MAX_LEVEL),
    'Color': enhance_level_to_args(MAX_LEVEL),
    'Contrast': enhance_level_to_args(MAX_LEVEL),
    'Brightness': enhance_level_to_args(MAX_LEVEL),
    'Sharpness': enhance_level_to_args(MAX_LEVEL),
    'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
    'TranslateX': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'TranslateY': translate_level_to_args(
        translate_const, MAX_LEVEL, replace_value
    ),
    'Posterize': posterize_level_to_args(MAX_LEVEL),
    'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
|
||||
|
||||
|
||||
class RandomAugment(object):
    """RandAugment-style policy: sample N ops from `augs` and apply each
    with probability 0.5 at magnitude M.

    Args:
        N: number of ops sampled per image (with replacement).
        M: magnitude level passed to each op's level-to-args mapper.
        isPIL: if True, inputs are PIL images and are converted to ndarray.
        augs: op names to sample from; defaults to all keys of `arg_dict`.
    """

    def __init__(self, N=2, M=10, isPIL=False, augs=None):
        # `augs=None` replaces the mutable-default `augs=[]` anti-pattern.
        # Behavior is unchanged: the old default was falsy and never mutated,
        # so both fall through to the full `arg_dict` key list below.
        self.N = N
        self.M = M
        self.isPIL = isPIL
        if augs:
            self.augs = augs
        else:
            self.augs = list(arg_dict.keys())

    def get_random_ops(self):
        # Sample N op names; each op is later applied with probability 0.5.
        sampled_ops = np.random.choice(self.augs, self.N)
        return [(op, 0.5, self.M) for op in sampled_ops]

    def __call__(self, img):
        if self.isPIL:
            img = np.array(img)
        ops = self.get_random_ops()
        for name, prob, level in ops:
            # Independent coin flip per sampled op.
            if np.random.random() > prob:
                continue
            args = arg_dict[name](level)
            img = func_dict[name](img, *args)
        return img
|
||||
|
||||
|
||||
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
    """Build the torchvision preprocessing pipeline for images.

    Args:
        is_train: training pipeline (random-resized crop + optional
            RandomAugment) vs. eval pipeline (plain resize).
        randaug: allow RandomAugment during training; additionally gated
            by the TRAIN_DO_AUG environment variable below.
        input_size: output square size in pixels.
        interpolation: unused here — crop/resize interpolation is
            hard-coded to bicubic below.
        std_mode: which (mean, std) normalization constants to use.

    Returns:
        A transforms.Compose mapping an image to a normalized tensor.

    Raises:
        NotImplementedError: for an unknown std_mode.
    """
    if std_mode == 'IMAGENET_INCEPTION':
        mean = IMAGENET_INCEPTION_MEAN
        std = IMAGENET_INCEPTION_STD
    elif std_mode == 'OPENAI_CLIP':
        mean = OPENAI_CLIP_MEAN
        std = OPENAI_CLIP_STD
    else:
        raise NotImplementedError

    if is_train:
        # Default scale 0.9999 makes the "random" crop essentially a
        # full-image crop unless TRAIN_CROP_SCALE overrides it.
        crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
        t = [
            RandomResizedCropAndInterpolation(
                input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
            # transforms.RandomHorizontalFlip(),
        ]
        # RandAugment runs only when explicitly enabled via env var.
        if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
            print(f'@@@@@ Do random aug during training', flush=True)
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        else:
            print(f'@@@@@ Skip random aug during training', flush=True)
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
        t = transforms.Compose(t)
    else:
        # Eval: deterministic resize (no crop), then normalize.
        t = transforms.Compose([
            transforms.Resize((input_size, input_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
        ])

    return t
|
||||
|
||||
|
||||
def img2b64(img_path):
    """Read an image file and return it re-encoded as a base64 string.

    The image is round-tripped through PIL (decoded, then re-saved in its
    original format) rather than read raw, preserving the original
    behavior. Raises if PIL cannot determine the format (img.format is
    None for images not loaded from a file).
    """
    # Context manager closes the file handle opened by Image.open even on
    # error — the original leaked it.
    with Image.open(img_path) as img:
        img_buffer = BytesIO()
        img.save(img_buffer, format=img.format)
    byte_data = img_buffer.getvalue()
    base64_str = base64.b64encode(byte_data)  # bytes
    base64_str = base64_str.decode("utf-8")  # str
    return base64_str
|
||||
|
||||
|
||||
def str2b64(str):
    # Encode UTF-8 text as a base64 ASCII string. NOTE: the parameter name
    # shadows the builtin `str`; kept for interface compatibility.
    raw = str.encode('utf-8')
    return base64.b64encode(raw).decode('utf-8')
|
||||
|
||||
|
||||
def b642str(b64):
    # Inverse of str2b64: decode a base64 ASCII string back to UTF-8 text.
    raw = base64.b64decode(b64)
    return raw.decode('utf-8')
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """True iff torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_world_size():
    """Number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
||||
|
||||
|
||||
def get_rank():
    """Rank of the current process, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
||||
|
||||
|
||||
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank

    NOTE(review): the multi-rank path pickles to a CUDA byte tensor, so it
    requires a CUDA device and an initialized process group; single-process
    runs short-circuit before touching CUDA.
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank (payload sizes differ per rank)
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # trim each rank's padding back to its true size, then unpickle
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list
|
||||
|
||||
|
||||
def mean(lst):
    # Arithmetic mean; raises ZeroDivisionError on an empty sequence,
    # matching the original behavior.
    total = sum(lst)
    return total / len(lst)
|
||||
|
||||
|
||||
def stop_gradient_by_name(name: str):
    """Build an ``nn.Module.apply``-compatible callback that freezes the
    attribute stored under ``name`` on every module that has it."""
    _missing = object()

    def apply_fn(module):
        # getattr-with-sentinel is equivalent to the hasattr/getattr pair.
        target = getattr(module, name, _missing)
        if target is not _missing:
            target.requires_grad_(False)

    return apply_fn
|
||||
Reference in New Issue
Block a user