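# Page metadata captured with this file (apparently from the Hugging Face file viewer):
# hugpv's picture
# initial commit via hf
# 8e5930e verified
# raw
# history blame contribute delete
# No virus
# 39.7 kB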
import timm
import os
from typing import Any
from pytorch_lightning.utilities.types import LRSchedulerTypeUnion
import torch as t
from torch import nn
import numpy as np
import transformers
import pytorch_lightning as plight
import torchmetrics
import einops as eo
from loss_functions import coral_loss, corn_loss, corn_label_from_logits, macro_soft_f1
# Allow PyTorch to use faster, lower-precision internal arithmetic for float32 matmuls.
t.set_float32_matmul_precision("medium")

# Module-wide switches. `try_using_torch_compile` gates the optional t.compile(...) wrapping of sub-modules.
global_settings = {"try_using_torch_compile": False}
class EnsembleModel(plight.LightningModule):
    """Ensemble of already-built member models whose outputs are combined into one prediction.

    Members come in two groups: ``models_without_norm_df`` receive the first element of each input
    pair and ``models_with_norm_df`` receive the second (see ``forward``). Outputs are either combined
    by a learned ``nn.Linear`` (``combiner``) or averaged. Loss, metric and prediction helpers are
    borrowed from the first member of ``models_with_norm_df`` (presumably ``LitModel`` instances —
    they must provide ``model_step``, ``_get_loss``, ``_get_preds_reals`` and ``num_classes``).

    Args:
        models_without_norm_df: iterable of members fed the first element of each input pair.
        models_with_norm_df: iterable of members fed the second element; must be non-empty.
        learning_rate: Adam learning rate used by ``configure_optimizers``.
        use_simple_average: if True, average member outputs instead of learning a linear combiner.
    """

    def __init__(self, models_without_norm_df, models_with_norm_df, learning_rate=0.0002, use_simple_average=False):
        super().__init__()
        self.models_without_norm = nn.ModuleList(list(models_without_norm_df))
        self.models_with_norm = nn.ModuleList(list(models_with_norm_df))
        self.learning_rate = learning_rate
        self.use_simple_average = use_simple_average
        if not self.use_simple_average:
            num_classes = self.models_with_norm[0].num_classes
            num_members = len(self.models_with_norm) + len(self.models_without_norm)
            # Input is every member's output concatenated along the last dimension.
            self.combiner = nn.Linear(num_classes * num_members, num_classes)

    def forward(self, x):
        """Run every member and combine their outputs.

        Args:
            x: pair ``(x_unnormed, x_normed)``; each element is a batch accepted by the members'
                ``model_step`` whose last element is the target.

        Returns:
            ``(outputs, y)``. ``outputs["out_avg"]`` is the combined prediction.
            ``outputs["out_unnormed"]`` / ``outputs["out_normed"]`` hold the per-group member outputs:
            concatenated tensors when the combiner is used, lists of tensors otherwise.
            ``y`` is ``x_unnormed[-1]``.
        """
        x_unnormed, x_normed = x
        if not self.use_simple_average:
            out_unnormed = t.cat([model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm], dim=-1)
            out_normed = t.cat([model.model_step(x_normed, 0)[0] for model in self.models_with_norm], dim=-1)
            out_avg = self.combiner(t.cat((out_unnormed, out_normed), dim=-1))
        else:
            out_unnormed = [model.model_step(x_unnormed, 0)[0] for model in self.models_without_norm]
            out_normed = [model.model_step(x_normed, 0)[0] for model in self.models_with_norm]
            # Plain arithmetic mean over all members. No extra rescaling, so the result stays on the
            # members' own output scale (an additional "/ 2" would halve every prediction).
            out_avg = t.stack(out_unnormed + out_normed, dim=-1).mean(-1)
        return {"out_avg": out_avg, "out_unnormed": out_unnormed, "out_normed": out_normed}, x_unnormed[-1]

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss of the combined prediction."""
        out, y = self(batch)
        # The first normalised-stream member's loss is reused; batch[0] is passed as its batch argument.
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log multiclass accuracy (in percent) and the validation loss of the combined prediction."""
        out, y = self(batch)
        preds, y_onecold, ignore_index_val = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        acc = torchmetrics.functional.accuracy(
            preds,
            y_onecold.to(t.long),
            ignore_index=ignore_index_val,
            num_classes=self.models_with_norm[0].num_classes,
            task="multiclass",
        )
        self.log("acc", acc * 100, prog_bar=True, sync_dist=True)
        loss = self.models_with_norm[0]._get_loss(out["out_avg"], y, batch[0])
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        """Return ``(preds, outputs_dict, y_onecold)`` for one batch."""
        out, y = self(batch)
        preds, y_onecold, _ = self.models_with_norm[0]._get_preds_reals(out["out_avg"], y)
        return preds, out, y_onecold

    def configure_optimizers(self):
        """Adam over all parameters (members and combiner) with ``learning_rate``."""
        return t.optim.Adam(self.parameters(), lr=self.learning_rate)
class TimmHeadReplace(nn.Module):
    """Replacement ``head`` for timm backbones: optional spatial pooling followed by flattening.

    Args:
        pooling: ``None`` (flatten only), ``"AdaptiveAvgPool2d"`` or ``"AdaptiveMaxPool2d"``. Any other
            name is stored, but no pooling layer is created here, so ``self.pooling_layer`` must be
            supplied elsewhere before ``forward`` is called.
        in_channels: accepted for call compatibility; unused.
        pooling_output_dimension: output size handed to the adaptive pooling layer.
        all_identity: if True, ignore everything else and pass inputs through unchanged.
    """

    def __init__(self, pooling=None, in_channels=512, pooling_output_dimension=1, all_identity=False) -> None:
        super().__init__()
        if all_identity:
            self.head = nn.Identity()
            self.pooling = None
            return
        self.pooling = pooling
        if pooling is not None:
            self.pooling_output_dimension = pooling_output_dimension
            if pooling == "AdaptiveAvgPool2d":
                self.pooling_layer = nn.AdaptiveAvgPool2d(pooling_output_dimension)
            elif pooling == "AdaptiveMaxPool2d":
                self.pooling_layer = nn.AdaptiveMaxPool2d(pooling_output_dimension)
        self.head = nn.Flatten()

    def forward(self, x, pre_logits=False):
        """Pool (if configured) and apply ``head``; ``pre_logits`` is accepted for timm compatibility."""
        if self.pooling is None:
            return self.head(x)
        if self.pooling == "stack_avg_max_attn":
            # pooling_layer is expected to be an iterable of layers whose outputs are concatenated.
            pooled = t.cat([pool(x) for pool in self.pooling_layer], dim=-1)
        else:
            pooled = self.pooling_layer(x)
        return self.head(pooled)
class CVModel(nn.Module):
    """Pretrained timm image backbone turned into a per-position sequence predictor.

    The backbone's features are mapped by a ``Flatten -> Linear`` head (stored as
    ``cv_model.classifier``) to ``max_seq_length`` values. Each value is projected to ``out_shape``
    outputs by ``out_project`` and passed through ``final_activation``.

    Args:
        modelname: timm model name; loaded with ``pretrained=True`` and ``num_classes=0``.
        in_shape: full shape of a dummy input, including the leading batch dimension, used once to
            measure the backbone's feature size (``shape[1]`` of its output).
        num_classes: number of classes (``out_shape`` unless ``loss_func == "OrdinalRegLoss"``).
        loss_func: loss name; ``"OrdinalRegLoss"`` selects a single output per position.
        last_activation: ``"Softmax"``, ``"Sigmoid"``, ``"LogSigmoid"`` or ``"Identity"``.
        input_padding_val: accepted for interface compatibility; unused.
        char_dims: stored as ``self.char_dims``.
        max_seq_length: output width of the projection head.

    Raises:
        NotImplementedError: for an unknown ``last_activation``.
    """

    def __init__(
        self,
        modelname,
        in_shape,
        num_classes,
        loss_func,
        last_activation: str,
        input_padding_val=10,
        char_dims=2,
        max_seq_length=1000,
    ) -> None:
        super().__init__()
        self.modelname = modelname
        self.loss_func = loss_func
        self.in_shape = in_shape
        self.char_dims = char_dims
        self.x_shape = in_shape
        self.last_activation = last_activation
        self.max_seq_length = max_seq_length
        self.num_classes = num_classes
        self.out_shape = 1 if self.loss_func == "OrdinalRegLoss" else num_classes

        self.cv_model = timm.create_model(modelname, pretrained=True, num_classes=0)
        self.cv_model.classifier = nn.Identity()

        # Some backbones call `self.classifier` inside their own forward(); others never touch that
        # attribute. Record which case applies (a hook on the placeholder fires only if it is called),
        # so forward() applies the projection head exactly once either way.
        classifier_calls = []

        def _record_classifier_call(module, inputs, output):
            classifier_calls.append(True)  # returns None, so the output is left unchanged

        hook = self.cv_model.classifier.register_forward_hook(_record_classifier_call)
        try:
            with t.inference_mode():
                test_out = self.cv_model(t.ones(self.in_shape, dtype=t.float32))
        finally:
            hook.remove()
        self._classifier_called_by_backbone = bool(classifier_calls)

        self.cv_model_out_dim = test_out.shape[1]
        self.cv_model.classifier = nn.Sequential(nn.Flatten(), nn.Linear(self.cv_model_out_dim, self.max_seq_length))
        if self.out_shape == 1:
            self.logit_norm = nn.Identity()
            self.out_project = nn.Identity()
        else:
            # logit_norm is not applied in forward(); it stays registered because it owns parameters
            # and removing it would change the state_dict keys.
            self.logit_norm = nn.LayerNorm(self.max_seq_length)
            self.out_project = nn.Linear(1, self.out_shape)

        if last_activation == "Softmax":
            self.final_activation = nn.Softmax(dim=-1)
        elif last_activation == "Sigmoid":
            self.final_activation = nn.Sigmoid()
        elif last_activation == "LogSigmoid":
            self.final_activation = nn.LogSigmoid()
        elif last_activation == "Identity":
            self.final_activation = nn.Identity()
        else:
            raise NotImplementedError(f"{last_activation} not implemented")

    def forward(self, x):
        """Map an image batch (or a list whose first element is one) to ``(batch, max_seq_length, out_shape)``."""
        if isinstance(x, list):
            x = x[0]
        x = self.cv_model(x)
        # Only apply the projection head here if the backbone did not already call it internally.
        # getattr default keeps whole-object pickles created before this flag existed working.
        if not getattr(self, "_classifier_called_by_backbone", False):
            x = self.cv_model.classifier(x)
        x = x.unsqueeze(-1)
        x = self.out_project(x)
        return self.final_activation(x)
class LitModel(plight.LightningModule):
    """Lightning module for per-position sequence prediction.

    The main input stream is projected (``project``) to ``hidden_dim``, or to ``hidden_dim // 2`` when a
    character-level second stream is concatenated. That second stream is built by ``_create_char_model``
    (``dense``/``bert``/``resnet``). The combined input runs through the backbone built by
    ``_create_main_seq_model`` (BERT or ``CVModel``) and an MLP head (``linear``) producing ``out_shape``
    outputs. The forward pass itself lives in the module-level ``forward``/``prep_model_input`` functions.
    """

    def __init__(
        self,
        in_shape: tuple,
        hidden_dim: int,
        num_attention_heads: int,
        num_layers: int,
        loss_func: str,
        learning_rate: float,
        weight_decay: float,
        cfg: dict,
        use_lr_warmup: bool,
        use_reduce_on_plateau: bool,
        track_gradient_histogram=False,
        register_forw_hook=False,
        char_dims=2,
    ) -> None:
        """Build every sub-network from the arguments and ``cfg``.

        Args:
            in_shape: shape of the main input stream. ``in_shape[-1]`` is the input-projection width
                and BERT vocab size; the whole tuple is forwarded to ``CVModel``.
            hidden_dim: backbone hidden size.
            num_attention_heads: attention heads of the BERT backbone.
            num_layers: layers of the BERT backbone (and of the optional char-level BERT).
            loss_func: one of the names handled by ``_get_loss``.
            learning_rate: base learning rate.
            weight_decay: AdamW weight decay for every parameter group.
            cfg: configuration dict; defaults for optional keys are written into it in place.
            use_lr_warmup: enable the manual warmup in ``optimizer_step``.
            use_reduce_on_plateau: use ``ReduceLROnPlateau`` on ``val_loss``.
            track_gradient_histogram: log gradient histograms every 200 global steps.
            register_forw_hook: log activation histograms through forward hooks.
            char_dims: feature size of each character-coordinate entry.

        Raises:
            NotImplementedError: for an unknown ``cfg["last_activation"]`` or model type.
        """
        super().__init__()
        # Optional keys: defaults are written into the caller's dict as well.
        cfg.setdefault("only_use_2nd_input_stream", False)
        cfg.setdefault("gamma_step_size", 5)
        cfg.setdefault("gamma_step_factor", 0.5)
        self.save_hyperparameters(
            dict(
                in_shape=in_shape,
                hidden_dim=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_layers=num_layers,
                loss_func=loss_func,
                learning_rate=learning_rate,
                cfg=cfg,
                x_shape=in_shape,
                num_classes=cfg["num_classes"],
                use_lr_warmup=use_lr_warmup,
                num_warmup_steps=cfg["num_warmup_steps"],
                use_reduce_on_plateau=use_reduce_on_plateau,
                weight_decay=weight_decay,
                track_gradient_histogram=track_gradient_histogram,
                register_forw_hook=register_forw_hook,
                char_dims=char_dims,
                remove_timm_classifier_head_pooling=cfg["remove_timm_classifier_head_pooling"],
                change_pooling_for_timm_head_to=cfg["change_pooling_for_timm_head_to"],
                chars_conv_pooling_out_dim=cfg["chars_conv_pooling_out_dim"],
            )
        )
        self.model_to_use = cfg["model_to_use"]
        self.num_classes = cfg["num_classes"]
        self.x_shape = in_shape
        self.in_shape = in_shape
        self.hidden_dim = hidden_dim
        self.num_attention_heads = num_attention_heads
        self.num_layers = num_layers
        self.use_lr_warmup = use_lr_warmup
        self.num_warmup_steps = cfg["num_warmup_steps"]
        self.warmup_exponent = cfg["warmup_exponent"]
        self.use_reduce_on_plateau = use_reduce_on_plateau
        self.loss_func = loss_func
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.using_one_hot_targets = cfg["one_hot_y"]
        self.track_gradient_histogram = track_gradient_histogram
        self.register_forw_hook = register_forw_hook
        if self.loss_func == "OrdinalRegLoss":
            self.ord_reg_loss_max = cfg["ord_reg_loss_max"]
            self.ord_reg_loss_min = cfg["ord_reg_loss_min"]
        self.num_lin_layers = cfg["num_lin_layers"]
        self.linear_activation = cfg["linear_activation"]
        self.last_activation = cfg["last_activation"]
        self.max_seq_length = cfg["manual_max_sequence_for_model"]
        self.use_char_embed_info = cfg["use_embedded_char_pos_info"]
        self.method_chars_into_model = cfg["method_chars_into_model"]
        self.source_for_pretrained_cv_model = cfg["source_for_pretrained_cv_model"]
        self.method_to_include_char_positions = cfg["method_to_include_char_positions"]
        self.char_dims = char_dims
        self.char_sequence_length = cfg["max_len_chars_list"] if self.use_char_embed_info else 0
        self.chars_conv_lr_reduction_factor = cfg["chars_conv_lr_reduction_factor"]
        if self.use_char_embed_info:
            self.chars_bert_reduction_factor = cfg["chars_bert_reduction_factor"]
        self.use_in_projection_bias = cfg["use_in_projection_bias"]
        self.add_layer_norm_to_in_projection = cfg["add_layer_norm_to_in_projection"]
        self.hidden_dropout_prob = cfg["hidden_dropout_prob"]
        self.layer_norm_after_in_projection = cfg["layer_norm_after_in_projection"]
        self.input_padding_val = cfg["input_padding_val"]
        self.cv_char_modelname = cfg["cv_char_modelname"]
        self.char_plot_shape = cfg["char_plot_shape"]
        self.remove_timm_classifier_head_pooling = cfg["remove_timm_classifier_head_pooling"]
        self.change_pooling_for_timm_head_to = cfg["change_pooling_for_timm_head_to"]
        self.chars_conv_pooling_out_dim = cfg["chars_conv_pooling_out_dim"]
        self.add_layer_norm_to_char_mlp = cfg["add_layer_norm_to_char_mlp"]
        self.profile_torch_run = cfg.get("profile_torch_run", False)
        self.out_shape = 1 if self.loss_func == "OrdinalRegLoss" else cfg["num_classes"]

        only_second_stream = self.hparams.cfg["only_use_2nd_input_stream"]

        # Input projection for the main stream (absent when only the 2nd stream is used).
        if not only_second_stream:
            chars_concat = self.use_char_embed_info and self.method_to_include_char_positions == "concat"
            if chars_concat and self.method_chars_into_model in ("dense", "resnet"):
                # Half of hidden_dim is reserved for the concatenated char-stream features.
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim // 2, bias=self.use_in_projection_bias)
            elif chars_concat and self.method_chars_into_model == "bert":
                self.hidden_dim_chars = self.hidden_dim // 2
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim_chars, bias=self.use_in_projection_bias)
            elif self.model_to_use == "cv_only_model":
                self.project = nn.Identity()
            else:
                self.project = nn.Linear(self.x_shape[-1], self.hidden_dim, bias=self.use_in_projection_bias)
        # Only a Linear projection can be rebuilt as Linear + LayerNorm. The Identity used for
        # cv_only_model and a missing projection (only_use_2nd_input_stream) have no in/out_features.
        if self.add_layer_norm_to_in_projection and isinstance(getattr(self, "project", None), nn.Linear):
            self.project = nn.Sequential(
                nn.Linear(self.project.in_features, self.project.out_features, bias=self.use_in_projection_bias),
                nn.LayerNorm(self.project.out_features),
            )
        if hasattr(self, "project"):
            self.project = self._maybe_compile(self.project)

        if self.use_char_embed_info:
            self._create_char_model()
        if self.layer_norm_after_in_projection:
            norm_dim = self.hidden_dim // 2 if only_second_stream else self.hidden_dim
            self.layer_norm_in = self._maybe_compile(nn.LayerNorm(norm_dim))
        self._create_main_seq_model(cfg)
        # Hooks are registered before the output head exists, so `linear` is not hooked.
        if register_forw_hook:
            self.register_hooks()

        # Output MLP head: (num_lin_layers - 1) hidden Linear+activation pairs, then the output Linear.
        linear_in_dim = self.hidden_dim // 2 if only_second_stream else self.hidden_dim
        if self.num_lin_layers == 1:
            self.linear = nn.Linear(linear_in_dim, self.out_shape)
        else:
            hidden_layers = []
            for _ in range(self.num_lin_layers - 1):
                hidden_layers += [nn.Linear(linear_in_dim, linear_in_dim), getattr(nn, self.linear_activation)()]
            self.linear = nn.Sequential(*hidden_layers, nn.Linear(linear_in_dim, self.out_shape))
        self.linear = self._maybe_compile(self.linear)

        if self.last_activation == "Softmax":
            self.final_activation = nn.Softmax(dim=-1)
        elif self.last_activation == "Sigmoid":
            self.final_activation = nn.Sigmoid()
        elif self.last_activation == "Identity":
            self.final_activation = nn.Identity()
        else:
            raise NotImplementedError(f"{self.last_activation} not implemented")

        if self.profile_torch_run:
            self.profilerr = t.profiler.profile(
                schedule=t.profiler.schedule(wait=1, warmup=10, active=10, repeat=1),
                on_trace_ready=t.profiler.tensorboard_trace_handler("tblogs"),
                with_stack=True,
                record_shapes=True,
                profile_memory=False,
            )

    @staticmethod
    def _maybe_compile(module):
        """Return ``t.compile(module)`` when compilation is enabled on a POSIX host, else ``module``."""
        if "posix" in os.name and global_settings["try_using_torch_compile"]:
            return t.compile(module)
        return module

    def _create_main_seq_model(self, cfg):
        """Create ``self.bert_model``: a fresh ``transformers.BertModel`` or a ``CVModel``.

        Raises:
            NotImplementedError: for any other ``model_to_use``.
        """
        hidden_dim = self.hidden_dim // 2 if self.hparams.cfg["only_use_2nd_input_stream"] else self.hidden_dim
        if self.model_to_use == "BERT":
            self.bert_config = transformers.BertConfig(
                vocab_size=self.x_shape[-1],
                hidden_size=hidden_dim,
                num_hidden_layers=self.num_layers,
                intermediate_size=hidden_dim,
                num_attention_heads=self.num_attention_heads,
                max_position_embeddings=self.max_seq_length,
            )
            self.bert_model = transformers.BertModel(self.bert_config)
        elif self.model_to_use == "cv_only_model":
            self.bert_model = CVModel(
                modelname=cfg["cv_modelname"],
                in_shape=self.in_shape,
                num_classes=cfg["num_classes"],
                loss_func=cfg["loss_function"],
                last_activation=cfg["last_activation"],
                input_padding_val=cfg["input_padding_val"],
                char_dims=self.char_dims,
                max_seq_length=cfg["manual_max_sequence_for_model"],
            )
        else:
            raise NotImplementedError(f"{self.model_to_use} not implemented")
        self.bert_model = self._maybe_compile(self.bert_model)
        return 0

    def _default_hidden_dim_chars(self):
        """Char-BERT width: ``hidden_dim // chars_bert_reduction_factor``, or ``hidden_dim`` if that is <= 1."""
        reduced = self.hidden_dim // self.chars_bert_reduction_factor
        return reduced if reduced > 1 else self.hidden_dim

    def _create_char_model(self):
        """Create the character-stream encoder selected by ``method_chars_into_model``."""
        if self.method_chars_into_model == "dense":
            self.chars_project_0 = self._maybe_compile(
                nn.Linear(self.char_dims, 1, bias=self.use_in_projection_bias)
            )
            chars_out_dim = (
                self.hidden_dim // 2 if self.method_to_include_char_positions == "concat" else self.hidden_dim
            )
            self.chars_project_1 = self._maybe_compile(
                nn.Linear(self.char_sequence_length, chars_out_dim, bias=self.use_in_projection_bias)
            )
        elif self.method_chars_into_model != "resnet":
            # hidden_dim_chars is only preset in __init__ for bert + concat; derive it before it is
            # needed here (otherwise e.g. bert with added char positions raises AttributeError).
            if not hasattr(self, "hidden_dim_chars"):
                self.hidden_dim_chars = self._default_hidden_dim_chars()
            self.chars_project = self._maybe_compile(
                nn.Linear(self.char_dims, self.hidden_dim_chars, bias=self.use_in_projection_bias)
            )

        if self.method_chars_into_model == "bert":
            if not hasattr(self, "hidden_dim_chars"):
                self.hidden_dim_chars = self._default_hidden_dim_chars()
            # Keep the per-head width of the main model.
            self.num_attention_heads_chars = self.hidden_dim_chars // (self.hidden_dim // self.num_attention_heads)
            self.chars_bert_config = transformers.BertConfig(
                vocab_size=self.x_shape[-1],
                hidden_size=self.hidden_dim_chars,
                num_hidden_layers=self.num_layers,
                intermediate_size=self.hidden_dim_chars,
                num_attention_heads=self.num_attention_heads_chars,
                max_position_embeddings=self.char_sequence_length + 1,  # +1 for the prepended CLS slot
                num_labels=1,
            )
            self.chars_bert = self._maybe_compile(transformers.BertForSequenceClassification(self.chars_bert_config))
            self.chars_project_class_output = self._maybe_compile(
                nn.Linear(1, self.hidden_dim_chars, bias=self.use_in_projection_bias)
            )
        elif self.method_chars_into_model == "resnet":
            if self.source_for_pretrained_cv_model == "timm":
                self.chars_conv = timm.create_model(
                    self.cv_char_modelname,
                    pretrained=True,
                    num_classes=0,  # remove classifier nn.Linear
                )
                if self.remove_timm_classifier_head_pooling:
                    self.chars_conv.head = TimmHeadReplace(all_identity=True)
                with t.inference_mode():
                    test_out = self.chars_conv(
                        t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
                    )
                # An un-pooled feature map (N, C, H, W) gets the configured pooling + flatten head.
                if test_out.ndim > 3:
                    self.chars_conv.head = TimmHeadReplace(
                        self.change_pooling_for_timm_head_to,
                        test_out.shape[1],
                    )
            elif self.source_for_pretrained_cv_model == "huggingface":
                self.chars_conv = transformers.AutoModelForImageClassification.from_pretrained(self.cv_char_modelname)
            elif self.source_for_pretrained_cv_model == "torch_hub":
                self.chars_conv = t.hub.load(*self.cv_char_modelname.split(","))

            # Strip whichever classifier attribute the backbone exposes.
            if hasattr(self.chars_conv, "classifier"):
                self.chars_conv.classifier = nn.Identity()
            elif hasattr(self.chars_conv, "cls_classifier"):
                self.chars_conv.cls_classifier = nn.Identity()
            elif hasattr(self.chars_conv, "fc"):
                self.chars_conv.fc = nn.Identity()
            if hasattr(self.chars_conv, "distillation_classifier"):
                self.chars_conv.distillation_classifier = nn.Identity()

            with t.inference_mode():
                test_out = self.chars_conv(
                    t.ones((1, 3, self.char_plot_shape[0], self.char_plot_shape[1]), dtype=t.float32)
                )
            if hasattr(test_out, "last_hidden_state"):
                self.chars_conv_out_dim = test_out.last_hidden_state.shape[1]
            elif hasattr(test_out, "logits"):
                self.chars_conv_out_dim = test_out.logits.shape[1]
            elif isinstance(test_out, list):
                self.chars_conv_out_dim = test_out[0].shape[1]
            else:
                self.chars_conv_out_dim = test_out.shape[1]

            char_lin_layers = [nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)]
            if self.add_layer_norm_to_char_mlp:
                char_lin_layers.append(nn.LayerNorm(self.hidden_dim // 2))
            self.chars_classifier = nn.Sequential(*char_lin_layers)
            if hasattr(self.chars_conv, "distillation_classifier"):
                self.chars_conv.distillation_classifier = nn.Sequential(
                    nn.Flatten(), nn.Linear(self.chars_conv_out_dim, self.hidden_dim // 2)
                )
            self.chars_classifier = self._maybe_compile(self.chars_classifier)
            self.chars_conv = self._maybe_compile(self.chars_conv)
        return 0

    def register_hooks(self):
        """Attach a forward hook to every sub-module that logs its output as a histogram."""

        def add_to_tb(layer):
            def hook(model, input, output):
                if not hasattr(output, "detach"):
                    return
                for logger in self.loggers:
                    if hasattr(logger.experiment, "add_histogram"):
                        logger.experiment.add_histogram(
                            tag=f"{layer}_{str(list(output.shape))}",
                            values=output.detach(),
                            global_step=self.trainer.global_step,
                        )

            return hook

        for layer_id, layer in self.named_modules():
            layer.register_forward_hook(add_to_tb(f"act_{layer_id}"))

    def on_after_backward(self) -> None:
        """Every 200 global steps, log gradient histograms of all parameters (if enabled)."""
        if self.track_gradient_histogram and self.trainer.global_step % 200 == 0:
            for logger in self.loggers:
                if not hasattr(logger.experiment, "add_histogram"):
                    continue
                for layer_id, layer in self.named_modules():
                    for idx2, p in enumerate(layer.parameters()):
                        grad_val = p.grad
                        if grad_val is None:
                            continue
                        grad_name = f"grad_{idx2}_{layer_id}_{str(list(p.grad.shape))}"
                        logger.experiment.add_histogram(
                            tag=grad_name, values=grad_val, global_step=self.trainer.global_step
                        )
        return super().on_after_backward()

    def _fold_in_seq_dim(self, out, y):
        """Merge batch and sequence dims: ``out`` (b, s, c) -> (b*s, c); ``y`` (b, s[, c]) likewise."""
        _, seq_len, _ = out.shape
        out = eo.rearrange(out, "b s c -> (b s) c", s=seq_len)
        if y is None:
            return out, None
        if y.ndim > 2:
            y = eo.rearrange(y, "b s c -> (b s) c", s=seq_len)
        else:
            y = eo.rearrange(y, "b s -> (b s)", s=seq_len)
        return out, y

    def _get_loss(self, out, y, batch):
        """Compute the configured loss; ``batch[-2]`` is used as the attention mask.

        Raises:
            ValueError: for an unknown ``loss_func``.
        """
        attention_mask = batch[-2]
        if self.loss_func == "BCELoss":
            if self.last_activation == "Identity":
                loss = t.nn.functional.binary_cross_entropy_with_logits(out, y, reduction="none")
            else:
                loss = t.nn.functional.binary_cross_entropy(out, y, reduction="none")
            # Zero the loss at masked positions. A scalar works for any batch/sequence size
            # (a replacement shaped like loss[1, 1, :] needs at least 2 of each).
            loss[~attention_mask.bool()] = 0.0
            loss = loss.mean()
        elif self.loss_func == "CrossEntropyLoss":
            if out.ndim > 2:
                out, y = self._fold_in_seq_dim(out, y)
            loss = t.nn.functional.cross_entropy(out, y, reduction="mean", ignore_index=-100)
        elif self.loss_func == "OrdinalRegLoss":
            loss = t.nn.functional.mse_loss(out, y, reduction="none")
            loss = loss[attention_mask.bool()].sum() * 10.0 / attention_mask.sum()
        elif self.loss_func == "macro_soft_f1":
            loss = macro_soft_f1(y, out, reduction="mean")
        elif self.loss_func == "coral_loss":
            loss = coral_loss(out, y)
        elif self.loss_func == "corn_loss":
            out, y = self._fold_in_seq_dim(out, y)
            loss = corn_loss(out, y.squeeze(), self.out_shape)
        else:
            raise ValueError("Loss Function not reckognized")
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (and advance the profiler if enabled)."""
        if self.profile_torch_run:
            self.profilerr.step()
        out, y = self.model_step(batch, batch_idx)
        loss = self._get_loss(out, y, batch)
        self.log("train_loss", loss, on_epoch=True, on_step=True, sync_dist=True)
        return loss

    def forward(*args):
        # Delegates to the module-level `forward(self, batch)`; the positional inputs after `self`
        # are passed as one tuple (unwrapped again by `prep_model_input`).
        return forward(args[0], args[1:])

    def model_step(self, batch, batch_idx):
        """Return ``(model_output, batch[-1])``; the last batch element is the target."""
        out = self.forward(batch)
        return out, batch[-1]

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure,
    ):
        """Step the optimizer, then apply manual lr warmup and periodically log learning rates.

        During warmup (enabled and schedule is not ``OneCycleLR``), each group's lr is set to
        ``base_lr * min(1, (global_step + 1) / num_warmup_steps) ** warmup_exponent``. ``base_lr`` is the
        group's ``warmup_base_lr`` entry from ``configure_optimizers``, so the reduced chars_conv rate
        survives warmup. Groups without that entry (e.g. restored from an older checkpoint) fall back
        to ``hparams.learning_rate``.
        """
        optimizer.step(closure=optimizer_closure)
        step = self.trainer.global_step
        warmup_enabled = self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR"
        if warmup_enabled and step < self.num_warmup_steps:
            lr_scale = min(1.0, float(step + 1) / self.num_warmup_steps) ** self.warmup_exponent
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * pg.get("warmup_base_lr", self.hparams.learning_rate)
        if step % 10 == 0:
            for idx, pg in enumerate(optimizer.param_groups):
                self.log(f"lr_{idx}", pg["lr"], prog_bar=True, sync_dist=True)

    def lr_scheduler_step(self, scheduler: LRSchedulerTypeUnion, metric: Any | None) -> None:
        """Step the scheduler, except while the manual warmup is still running."""
        warmup_enabled = self.use_lr_warmup and self.hparams["cfg"]["lr_scheduling"] != "OneCycleLR"
        if warmup_enabled and self.trainer.global_step <= self.num_warmup_steps:
            return
        if metric is None:
            scheduler.step()
        else:
            scheduler.step(metric)

    def _get_preds_reals(self, out, y):
        """Turn raw outputs into predicted labels and targets into comparable labels.

        Returns:
            ``(preds, y_onecold, ignore_index_val)``; ``y_onecold`` is None when ``y`` is None.
        """
        if self.loss_func == "corn_loss":
            seq_len = out.shape[1]
            out, y = self._fold_in_seq_dim(out, y)
            preds = eo.rearrange(corn_label_from_logits(out), "(b s) -> b s", s=seq_len)
            if y is not None:
                y = eo.rearrange(y.squeeze(), "(b s) -> b s", s=seq_len)
        elif self.loss_func == "OrdinalRegLoss":
            # Outputs are rescaled from [0, 1] to [ord_reg_loss_min, ord_reg_loss_max].
            preds = out * (self.ord_reg_loss_max - self.ord_reg_loss_min)
            preds = (preds + self.ord_reg_loss_min).round().to(t.long)
        else:
            preds = t.argmax(out, dim=-1)

        if y is None:
            return preds, None, -100

        if self.using_one_hot_targets:
            y_onecold = t.argmax(y, dim=-1)
            ignore_index_val = 0
        elif self.loss_func == "OrdinalRegLoss":
            y_onecold = y * (self.ord_reg_loss_max - self.ord_reg_loss_min)
            y_onecold = (y_onecold + self.ord_reg_loss_min).round().to(t.long)
            # The smallest target label in the batch is treated as the ignore value.
            ignore_index_val = t.min(y_onecold).to(t.long)
        else:
            y_onecold = y
            ignore_index_val = -100
        if preds.ndim > y_onecold.ndim:
            preds = preds.squeeze()
        return preds, y_onecold, ignore_index_val

    def validation_step(self, batch, batch_idx):
        """Log accuracy (in percent) and validation loss."""
        out, y = self.model_step(batch, batch_idx)
        preds, y_onecold, ignore_index_val = self._get_preds_reals(out, y)
        if self.loss_func == "OrdinalRegLoss":
            y_onecold = y_onecold.flatten()
            keep = y_onecold != ignore_index_val
            preds = preds.flatten()[keep]
            y_onecold = y_onecold[keep]
            acc = (preds == y_onecold).sum() / len(y_onecold)
        else:
            acc = torchmetrics.functional.accuracy(
                preds,
                y_onecold.to(t.long),
                ignore_index=ignore_index_val,
                num_classes=self.num_classes,
                task="multiclass",
            )
        self.log("acc", acc * 100, prog_bar=True, sync_dist=True)
        loss = self._get_loss(out, y, batch)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        return loss

    def predict_step(self, batch, batch_idx):
        """Return ``(preds, y_onecold)`` for one batch."""
        out, y = self.model_step(batch, batch_idx)
        preds, y_onecold, _ = self._get_preds_reals(out, y)
        return preds, y_onecold

    def configure_optimizers(self):
        """AdamW with a reduced lr for the char-image backbone, plus the configured lr scheduler."""
        named_params = list(self.named_parameters())

        def is_chars_conv(name):
            # Char-image backbone weights train at a reduced rate; its classifier heads do not.
            return "chars_conv" in name and "classifier" not in name

        chars_conv_lr = self.learning_rate / self.chars_conv_lr_reduction_factor
        # "warmup_base_lr" records each group's target lr for the warmup in optimizer_step.
        grouped_parameters = [
            {
                "params": [p for n, p in named_params if is_chars_conv(n)],
                "lr": chars_conv_lr,
                "weight_decay": self.weight_decay,
                "warmup_base_lr": chars_conv_lr,
            },
            {
                "params": [p for n, p in named_params if not is_chars_conv(n)],
                "lr": self.learning_rate,
                "weight_decay": self.weight_decay,
                "warmup_base_lr": self.learning_rate,
            },
        ]
        opti = t.optim.AdamW(grouped_parameters, lr=self.learning_rate, weight_decay=self.weight_decay)
        if self.use_reduce_on_plateau:
            return {
                "optimizer": opti,
                "lr_scheduler": {
                    "scheduler": t.optim.lr_scheduler.ReduceLROnPlateau(opti, mode="min", patience=2, factor=0.5),
                    "monitor": "val_loss",
                    "frequency": 1,
                    "interval": "epoch",
                },
            }

        cfg = self.hparams["cfg"]
        step_interval = "step" if cfg["use_training_steps_for_end_and_lr_decay"] else "epoch"
        # The deprecated `verbose` scheduler argument is omitted; its default is already off.
        if cfg["use_reduce_on_plateau"]:
            scheduler = None
        elif cfg["lr_scheduling"] == "multistep":
            scheduler = t.optim.lr_scheduler.MultiStepLR(
                opti, milestones=cfg["multistep_milestones"], gamma=cfg["gamma_multistep"]
            )
            interval = step_interval
        elif cfg["lr_scheduling"] == "StepLR":
            scheduler = t.optim.lr_scheduler.StepLR(
                opti, step_size=cfg["gamma_step_size"], gamma=cfg["gamma_step_factor"]
            )
            interval = step_interval
        elif cfg["lr_scheduling"] == "anneal":
            scheduler = t.optim.lr_scheduler.CosineAnnealingLR(
                opti, 250, eta_min=cfg["min_lr_anneal"], last_epoch=-1
            )
            interval = "step"
        elif cfg["lr_scheduling"] == "ExponentialLR":
            scheduler = t.optim.lr_scheduler.ExponentialLR(opti, gamma=cfg["lr_sched_exp_fac"])
            interval = "step"
        else:
            scheduler = None

        if scheduler is None:
            return [opti]
        return {
            "optimizer": opti,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "global_step",
                "frequency": 1,
                "interval": interval,
            },
        }

    def on_fit_start(self) -> None:
        """Start the torch profiler if profiling is enabled."""
        if self.profile_torch_run:
            self.profilerr.start()
        return super().on_fit_start()

    def on_fit_end(self) -> None:
        """Stop the torch profiler if profiling is enabled."""
        if self.profile_torch_run:
            self.profilerr.stop()
        return super().on_fit_end()
def prep_model_input(self, batch):
    """Embed one batch into the backbone's input space.

    Module-level helper that receives the owning ``LitModel`` as ``self``.

    Args:
        self: the model whose projection layers / char encoders are used.
        batch: a batch tuple, or a 1-tuple wrapping one (as passed by ``LitModel.forward``).
            With ``use_char_embed_info`` the accepted layouts are ``(x, chars_coords, ims, mask, y)``,
            ``(x, ims, mask, y)`` (detected by a 4-dimensional ``batch[1]``) and
            ``(x, chars_coords, mask, y)``. Otherwise ``(x, mask, ..., y)`` or ``(x, mask, y)``.

    Returns:
        ``(x_embedded, attention_mask)``.
    """
    if len(batch) == 1:
        batch = batch[0]
    if self.use_char_embed_info:
        if len(batch) == 5:
            x, chars_coords, ims, attention_mask, _ = batch
        elif batch[1].ndim == 4:
            x, ims, attention_mask, _ = batch
        else:
            x, chars_coords, attention_mask, _ = batch
    elif len(batch) > 3:
        x, attention_mask = batch[0], batch[1]
    else:
        x, attention_mask, _ = batch

    if self.model_to_use != "cv_only_model" and not self.hparams.cfg["only_use_2nd_input_stream"]:
        x_embedded = self.project(x)
    else:
        x_embedded = x
    if not self.use_char_embed_info:
        return x_embedded, attention_mask

    if self.method_chars_into_model == "dense":
        # Keep only real characters: a position counts as padding when its first coordinate equals
        # input_padding_val. `!=` matches the bert branch below; `==` would keep only the padding.
        keep_mask = (chars_coords != self.input_padding_val)[:, :, 0]
        chars_coords_projected = self.chars_project_0(chars_coords).squeeze(-1) * keep_mask
        if self.chars_project_1.in_features == chars_coords_projected.shape[-1]:
            chars_coords_projected = self.chars_project_1(chars_coords_projected)
        else:
            chars_coords_projected = chars_coords_projected.mean(dim=-1)
            chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[2])
    elif self.method_chars_into_model == "bert":
        # Attention mask with an extra always-on slot for the prepended CLS position.
        chars_mask = chars_coords != self.input_padding_val
        chars_mask = t.cat(
            (
                t.ones(chars_mask[:, :1, 0].shape, dtype=t.long, device=chars_coords.device),
                chars_mask[:, :, 0].to(t.long),
            ),
            dim=1,
        )
        chars_coords_projected = self.chars_project(chars_coords)
        position_ids = t.arange(
            0, chars_coords_projected.shape[1] + 1, dtype=t.long, device=chars_coords_projected.device
        )
        token_type_ids = t.zeros(
            (chars_coords_projected.size()[0], chars_coords_projected.size()[1] + 1),
            dtype=t.long,
            device=chars_coords_projected.device,
        )  # +1 for CLS
        chars_coords_projected = t.cat(
            (t.ones_like(chars_coords_projected[:, :1, :]), chars_coords_projected), dim=1
        )  # to add CLS token
        chars_coords_projected = self.chars_bert(
            position_ids=position_ids,
            inputs_embeds=chars_coords_projected,
            token_type_ids=token_type_ids,
            attention_mask=chars_mask,
        )
        if hasattr(chars_coords_projected, "last_hidden_state"):
            chars_coords_projected = chars_coords_projected.last_hidden_state[:, 0, :]
        elif hasattr(chars_coords_projected, "logits"):
            chars_coords_projected = chars_coords_projected.logits
        else:
            chars_coords_projected = chars_coords_projected.hidden_states[-1][:, 0, :]
    elif self.method_chars_into_model == "resnet":
        chars_conv_out = self.chars_conv(ims)
        if isinstance(chars_conv_out, list):
            chars_conv_out = chars_conv_out[0]
        if hasattr(chars_conv_out, "logits"):
            chars_conv_out = chars_conv_out.logits
        chars_coords_projected = self.chars_classifier(chars_conv_out)

    # Broadcast the per-sample char features over every sequence position of the main stream.
    chars_coords_projected = chars_coords_projected.unsqueeze(1).repeat(1, x_embedded.shape[1], 1)
    if hasattr(self, "chars_project_class_output"):
        chars_coords_projected = self.chars_project_class_output(chars_coords_projected)
    if self.hparams.cfg["only_use_2nd_input_stream"]:
        x_embedded = chars_coords_projected
    elif self.method_to_include_char_positions == "concat":
        x_embedded = t.cat((x_embedded, chars_coords_projected), dim=-1)
    else:
        x_embedded = x_embedded + chars_coords_projected
    return x_embedded, attention_mask
def forward(self, batch):
    """Full forward pass of a ``LitModel``: embed the batch, run the backbone, apply the head.

    Module-level function that receives the model as ``self``. ``LitModel.forward`` calls it with its
    positional inputs packed into a tuple.

    Args:
        self: the owning model.
        batch: anything accepted by ``prep_model_input``.

    Returns:
        Backbone output mapped through ``self.linear`` (when the backbone exposes
        ``last_hidden_state``) and ``self.final_activation``.
    """
    # prep_model_input always returns exactly (x_embedded, attention_mask).
    x_embedded, attention_mask = prep_model_input(self, batch)
    position_ids = t.arange(0, x_embedded.shape[1], dtype=t.long, device=x_embedded.device)
    token_type_ids = t.zeros(x_embedded.size()[:-1], dtype=t.long, device=x_embedded.device)
    if self.layer_norm_after_in_projection:
        x_embedded = self.layer_norm_in(x_embedded)

    if self.model_to_use == "LSTM":
        bert_out = self.bert_model(x_embedded)
    elif self.model_to_use in ["ProphetNet", "T5", "FunnelModel"]:
        bert_out = self.bert_model(inputs_embeds=x_embedded, attention_mask=attention_mask)
    elif self.model_to_use == "xBERT":
        bert_out = self.bert_model(x_embedded, mask=attention_mask.to(bool))
    elif self.model_to_use == "cv_only_model":
        bert_out = self.bert_model(x_embedded)
    else:
        bert_out = self.bert_model(
            position_ids=position_ids,
            inputs_embeds=x_embedded,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

    if hasattr(bert_out, "last_hidden_state"):
        out = self.linear(bert_out.last_hidden_state)
    elif hasattr(bert_out, "logits"):
        out = bert_out.logits
    else:
        out = bert_out
    if self.model_to_use == "cv_only_model":
        # CVModel.forward already ends with its own final activation, built from the same
        # cfg["last_activation"]; applying ours as well would e.g. softmax twice.
        return out
    return self.final_activation(out)