[YOLO Improvements] Swapping Out IoU Loss Functions: Focal IoU Loss (Based on MMYOLO)

The Focal IoU Loss Function

In object detection, assessing the quality of predicted bounding boxes is an essential step. The traditional IoU (Intersection over Union) loss measures how much a predicted box overlaps its ground-truth box, but it falls short in certain situations. For example, when the predicted and ground-truth boxes barely overlap, the gradient of the IoU loss becomes very small, making optimization difficult. In addition, the imbalance between positive and negative samples is a common problem in object detection and can hurt training.

To address these problems, researchers proposed Focal IoU Loss. This loss aims to balance the contributions of high-quality and low-quality samples while emphasizing the importance of hard samples during training.

Design Principles

The design of Focal IoU Loss rests on two main ideas:

  1. Balancing the contributions of high-quality and low-quality samples: a modulating factor gives high-quality samples (those with larger IoU values) more weight in the loss computation, while low-quality samples (those with smaller IoU values) receive relatively less. This limits the negative impact that low-quality samples have on training.
  2. Emphasizing hard samples during training: similar to Focal Loss, Focal IoU Loss also re-weights challenging samples (those with low IoU values) so that the model pays more attention to them during training, which can improve generalization; a concrete formulation follows this list.
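
Concretely, the Focal-CIoU variant implemented later in this post modulates the plain CIoU term by an IoU-dependent factor, in the style of Focal-EIoU:

Focal-CIoU = IoU^gamma * ( IoU - ( ρ^2 / c^2 + alpha * v ) )

where gamma (0.5 in the code below) controls how strongly each sample's contribution is scaled by its box quality.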

Computation Steps

The computation of Focal IoU Loss proceeds roughly as follows:

  1. Compute the IoU between each predicted bounding box and its ground-truth box.
  2. Divide the samples into high-quality and low-quality groups according to their IoU values; the exact threshold can be tuned for the task at hand.
  3. Introduce a modulating factor α to balance the contributions of the two groups. Its value can also be tuned and usually lies between 0 and 1.
  4. For high-quality samples, use the original IoU loss directly; for low-quality samples, assign a smaller weight based on the IoU value.
  5. Take the weighted sum of all per-sample losses to obtain the final Focal IoU Loss; a minimal sketch of these steps follows this list.
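
A minimal, self-contained PyTorch sketch of these steps is shown below. It is independent of MMYOLO; the function name, the xyxy box format, and the choice of gamma are illustrative assumptions, not the exact implementation used later in this post.

import torch

def focal_iou_loss(pred: torch.Tensor,
                   target: torch.Tensor,
                   gamma: float = 0.5,
                   eps: float = 1e-7) -> torch.Tensor:
    """pred/target: (n, 4) boxes in xyxy format; returns (n,) losses."""
    # Step 1: plain IoU between each pred/target pair.
    lt = torch.max(pred[:, :2], target[:, :2])  # intersection top-left
    rb = torch.min(pred[:, 2:], target[:, 2:])  # intersection bottom-right
    wh = (rb - lt).clamp(min=0)
    overlap = wh[:, 0] * wh[:, 1]
    area_pred = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_target = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    ious = overlap / (area_pred + area_target - overlap + eps)
    # Steps 2-4: IoU**gamma acts as a soft, continuous quality weight,
    # shrinking the contribution of low-IoU (low-quality) samples.
    per_sample = torch.pow(ious, gamma) * (1.0 - ious)
    # Step 5: aggregation (weighted sum or mean) is left to the caller,
    # e.g. per_sample.mean().
    return per_sample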

Adding the Focal IoU Loss Function (Based on MMYOLO)

Since MMYOLO does not ship a Focal IoU loss, we add the Focal IoU computation and a corresponding iou_mode to mmyolo/models/losses/iou_loss.py. After making the change, run the following in a terminal:

python setup.py install

and then adjust the configuration file.
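
Only the loss_bbox block of the bbox head needs to change. A minimal version of that block (it appears again in the full config below) looks like this:

loss_bbox=dict(
    type='IoULoss',
    focal=True,        # switch on the focal re-weighting
    iou_mode='ciou',   # i.e. Focal-CIoU
    bbox_format='xywh',
    eps=1e-7,
    reduction='mean',
    loss_weight=loss_bbox_weight * (3 / num_det_layers),
    return_iou=True),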

The full mmyolo/models/losses/iou_loss.py, including all the IoU loss replacements from the earlier posts in this series, is shown below (for Focal IoU only Focal-CIoU is implemented; the other variants follow the same pattern). Note that forward() must pass focal=self.focal through to bbox_overlaps, otherwise the focal switch in the config has no effect:

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.losses.utils import weight_reduce_loss
from mmdet.structures.bbox import HorizontalBoxes

from mmyolo.registry import MODELS


class WIoU_Scale:
    ''' monotonous: {
            None: origin v1
            True: monotonic FM v2
            False: non-monotonic FM v3
        }
        momentum: The momentum of running mean'''

    iou_mean = 1.
    monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)
    _is_train = True

    def __init__(self, iou):
        self.iou = iou
        self._update(self)

    @classmethod
    def _update(cls, self):
        if cls._is_train:
            cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                           cls._momentum * self.iou.detach().mean().item()

    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                return beta / alpha
        return 1

def bbox_overlaps(pred: torch.Tensor,
                  target: torch.Tensor,
                  iou_mode: str = 'ciou',
                  bbox_format: str = 'xywh',
                  siou_theta: float = 4.0,
                  gamma: float = 0.5,
                  eps: float = 1e-7,
                  focal: bool = False) -> torch.Tensor:
    r"""Calculate overlap between two set of bboxes.
    `Implementation of paper `Enhancing Geometric Factors into
    Model Learning and Inference for Object Detection and Instance
    Segmentation <https://arxiv.org/abs/2005.03572>`_.

    In the CIoU implementation of YOLOv5 and MMDetection, there is a slight
    difference in the way the alpha parameter is computed.

    mmdet version:
        alpha = (ious > 0.5).float() * v / (1 - ious + v)
    YOLOv5 version:
        alpha = v / (v - ious + (1 + eps))

    Args:
        pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
            or (x, y, w, h),shape (n, 4).
        target (Tensor): Corresponding gt bboxes, shape (n, 4).
        iou_mode (str): Options are ('iou', 'ciou', 'giou', 'siou', 'diou',
            'eiou', 'innersiou', 'innerciou', 'shapeiou', 'wiou').
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        siou_theta (float): siou_theta for SIoU when calculating the shape
            cost. Defaults to 4.0.
        gamma (float): Focusing exponent used by the focal variant.
            Defaults to 0.5.
        eps (float): Eps to avoid log(0).
        focal (bool): If True, apply the focal re-weighting (implemented
            for 'ciou' only, i.e. Focal-CIoU). Defaults to False.

    Returns:
        Tensor: shape (n, ).
    """
    assert iou_mode in ('iou', 'ciou', 'giou', 'siou', 'diou', 'eiou',
                        'innersiou', 'innerciou', 'shapeiou', 'wiou')
    assert bbox_format in ('xyxy', 'xywh')
    if bbox_format == 'xywh':
        pred = HorizontalBoxes.cxcywh_to_xyxy(pred)
        target = HorizontalBoxes.cxcywh_to_xyxy(target)

    bbox1_x1, bbox1_y1 = pred[..., 0], pred[..., 1]
    bbox1_x2, bbox1_y2 = pred[..., 2], pred[..., 3]
    bbox2_x1, bbox2_y1 = target[..., 0], target[..., 1]
    bbox2_x2, bbox2_y2 = target[..., 2], target[..., 3]

    # Overlap
    overlap = (torch.min(bbox1_x2, bbox2_x2) -
               torch.max(bbox1_x1, bbox2_x1)).clamp(0) * \
              (torch.min(bbox1_y2, bbox2_y2) -
               torch.max(bbox1_y1, bbox2_y1)).clamp(0)

    # Union
    w1, h1 = bbox1_x2 - bbox1_x1, bbox1_y2 - bbox1_y1
    w2, h2 = bbox2_x2 - bbox2_x1, bbox2_y2 - bbox2_y1
    union = (w1 * h1) + (w2 * h2) - overlap + eps

    h1 = bbox1_y2 - bbox1_y1 + eps
    h2 = bbox2_y2 - bbox2_y1 + eps

    # IoU
    ious = overlap / union

    # enclose area
    enclose_x1y1 = torch.min(pred[..., :2], target[..., :2])
    enclose_x2y2 = torch.max(pred[..., 2:], target[..., 2:])
    enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)

    enclose_w = enclose_wh[..., 0]  # cw
    enclose_h = enclose_wh[..., 1]  # ch

    if iou_mode == 'ciou':
        # CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )

        # calculate enclose area (c^2)
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # calculate ρ^2(b_pred,b_gt):
        # euclidean distance between b_pred(bbox2) and b_gt(bbox1)
        # center point, because bbox format is xyxy -> left-top xy and
        # right-bottom xy, so need to / 4 to get center point.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)

        # Width and height ratio (v)
        wh_ratio = (4 / (math.pi**2)) * torch.pow(
            torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

        with torch.no_grad():
            alpha = wh_ratio / (wh_ratio - ious + (1 + eps))

        if focal:
            # Focal-CIoU: modulate the CIoU term by IoU^gamma so that
            # low-quality (low-IoU) samples contribute less to the loss.
            ious = torch.pow(ious, gamma) * (
                ious - ((rho2 / enclose_area) + (alpha * wh_ratio)))
        else:
            # CIoU
            ious = ious - ((rho2 / enclose_area) + (alpha * wh_ratio))

    elif iou_mode == 'giou':
        # GIoU = IoU - ( (A_c - union) / A_c )
        convex_area = enclose_w * enclose_h + eps  # convex area (A_c)
        ious = ious - (convex_area - union) / convex_area

    elif iou_mode == 'siou':
        # SIoU: https://arxiv.org/pdf/2205.12740.pdf
        # SIoU = IoU - ( (Distance Cost + Shape Cost) / 2 )

        # calculate sigma (σ):
        # euclidean distance between bbox2(pred) and bbox1(gt) center point,
        # sigma_cw = b_cx_gt - b_cx
        sigma_cw = (bbox2_x1 + bbox2_x2) / 2 - (bbox1_x1 + bbox1_x2) / 2 + eps
        # sigma_ch = b_cy_gt - b_cy
        sigma_ch = (bbox2_y1 + bbox2_y2) / 2 - (bbox1_y1 + bbox1_y2) / 2 + eps
        # sigma = √( (sigma_cw ** 2) - (sigma_ch ** 2) )
        sigma = torch.pow(sigma_cw**2 + sigma_ch**2, 0.5)

        # choose minimize alpha, sin(alpha)
        sin_alpha = torch.abs(sigma_ch) / sigma
        sin_beta = torch.abs(sigma_cw) / sigma
        sin_alpha = torch.where(sin_alpha <= math.sin(math.pi / 4), sin_alpha,
                                sin_beta)

        # Angle cost = 1 - 2 * ( sin^2 ( arcsin(x) - (pi / 4) ) )
        angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)

        # Distance cost = Σ_(t=x,y) (1 - e ^ (- γ ρ_t))
        rho_x = (sigma_cw / enclose_w)**2  # ρ_x
        rho_y = (sigma_ch / enclose_h)**2  # ρ_y
        gamma = 2 - angle_cost  # γ
        distance_cost = (1 - torch.exp(-1 * gamma * rho_x)) + (
            1 - torch.exp(-1 * gamma * rho_y))

        # Shape cost = Ω = Σ_(t=w,h) ( ( 1 - ( e ^ (-ω_t) ) ) ^ θ )
        omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)  # ω_w
        omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)  # ω_h
        shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w),
                               siou_theta) + torch.pow(
                                   1 - torch.exp(-1 * omiga_h), siou_theta)

        ious = ious - ((distance_cost + shape_cost) * 0.5)
    elif iou_mode == 'diou':
        # DIoU = IoU - ( ρ^2(b_pred,b_gt) / c^2 )

        # calculate enclose area (c^2)
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # calculate ρ^2(b_pred,b_gt):
        # euclidean distance between b_pred(bbox2) and b_gt(bbox1)
        # center point, because bbox format is xyxy -> left-top xy and
        # right-bottom xy, so need to / 4 to get center point.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)
        ious = ious - (rho2 / enclose_area)
    elif iou_mode == "eiou":

        # EIoU = IoU - ( ρ^2/c^2 + ρ_w^2/c_w^2 + ρ_h^2/c_h^2 )

        # calculate enclose area (c^2)
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # calculate ρ^2(b_pred,b_gt):
        # euclidean distance between b_pred(bbox2) and b_gt(bbox1)
        # center point, because bbox format is xyxy -> left-top xy and
        # right-bottom xy, so need to / 4 to get center point.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)
        rho_w2 = ((bbox2_x2 - bbox2_x1) - (bbox1_x2 - bbox1_x1)) ** 2
        rho_h2 = ((bbox2_y2 - bbox2_y1) - (bbox1_y2 - bbox1_y1)) ** 2
        cw2 = enclose_w ** 2 + eps
        ch2 = enclose_h ** 2 + eps
        ious = ious - (rho2 / enclose_area + rho_w2 / cw2 + rho_h2 / ch2)
    elif iou_mode == "innersiou":
        ratio = 1.0  # inner-box scale ratio
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        x1 = bbox1_x1 + w1_
        y1 = bbox1_y1 + h1_
        x2 = bbox2_x1 + w2_
        y2 = bbox2_y1 + h2_

        inner_b1_x1, inner_b1_x2, inner_b1_y1, inner_b1_y2 = x1 - w1_ * ratio, x1 + w1_ * ratio, \
                                                             y1 - h1_ * ratio, y1 + h1_ * ratio
        inner_b2_x1, inner_b2_x2, inner_b2_y1, inner_b2_y2 = x2 - w2_ * ratio, x2 + w2_ * ratio, \
                                                             y2 - h2_ * ratio, y2 + h2_ * ratio
        inner_inter = (torch.min(inner_b1_x2, inner_b2_x2) - torch.max(inner_b1_x1, inner_b2_x1)).clamp(0) * \
                      (torch.min(inner_b1_y2, inner_b2_y2) - torch.max(inner_b1_y1, inner_b2_y1)).clamp(0)
        inner_union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inner_inter + eps
        inner_iou = inner_inter / inner_union

        cw = torch.max(bbox1_x2, bbox2_x2) - torch.min(bbox1_x1, bbox2_x1)  # convex width
        ch = torch.max(bbox1_y2, bbox2_y2) - torch.min(bbox1_y1, bbox2_y1)  # convex height
        s_cw = (bbox2_x1 + bbox2_x2 - bbox1_x1 - bbox1_x2) * 0.5 + eps
        s_ch = (bbox2_y1 + bbox2_y2 - bbox1_y1 - bbox1_y2) * 0.5 + eps
        sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
        sin_alpha_1 = torch.abs(s_cw) / sigma
        sin_alpha_2 = torch.abs(s_ch) / sigma
        threshold = pow(2, 0.5) / 2
        sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
        angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
        rho_x = (s_cw / cw) ** 2
        rho_y = (s_ch / ch) ** 2
        gamma = angle_cost - 2
        distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
        omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
        omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
        shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
        ious = inner_iou - 0.5 * (distance_cost + shape_cost)
    elif iou_mode == "innerciou":
        ratio = 1.0  # inner-box scale ratio
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        x1 = bbox1_x1 + w1_
        y1 = bbox1_y1 + h1_
        x2 = bbox2_x1 + w2_
        y2 = bbox2_y1 + h2_

        inner_b1_x1, inner_b1_x2, inner_b1_y1, inner_b1_y2 = x1 - w1_ * ratio, x1 + w1_ * ratio, \
                                                             y1 - h1_ * ratio, y1 + h1_ * ratio
        inner_b2_x1, inner_b2_x2, inner_b2_y1, inner_b2_y2 = x2 - w2_ * ratio, x2 + w2_ * ratio, \
                                                             y2 - h2_ * ratio, y2 + h2_ * ratio
        inner_inter = (torch.min(inner_b1_x2, inner_b2_x2) - torch.max(inner_b1_x1, inner_b2_x1)).clamp(0) * \
                      (torch.min(inner_b1_y2, inner_b2_y2) - torch.max(inner_b1_y1, inner_b2_y1)).clamp(0)
        inner_union = w1 * ratio * h1 * ratio + w2 * ratio * h2 * ratio - inner_inter + eps
        inner_iou = inner_inter / inner_union

        # CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )

        # calculate enclose area (c^2)
        enclose_area = enclose_w**2 + enclose_h**2 + eps

        # calculate ρ^2(b_pred,b_gt):
        # euclidean distance between b_pred(bbox2) and b_gt(bbox1)
        # center point, because bbox format is xyxy -> left-top xy and
        # right-bottom xy, so need to / 4 to get center point.
        rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
        rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                           (bbox1_y1 + bbox1_y2))**2 / 4
        rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)

        # Width and height ratio (v)
        wh_ratio = (4 / (math.pi**2)) * torch.pow(
            torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)

        with torch.no_grad():
            alpha = wh_ratio / (wh_ratio - ious + (1 + eps))

        # innerCIoU
        ious = inner_iou - ((rho2 / enclose_area) + (alpha * wh_ratio))
    elif iou_mode == "shapeiou":
        scale = 0
        # Shape-Distance
        ww = 2 * torch.pow(w2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
        hh = 2 * torch.pow(h2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
        cw = torch.max(bbox1_x2, bbox2_x2) - torch.min(bbox1_x1, bbox2_x1)  # convex width
        ch = torch.max(bbox1_y2, bbox2_y2) - torch.min(bbox1_y1, bbox2_y1)  # convex height
        c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
        center_distance_x = ((bbox2_x1 + bbox2_x2 - bbox1_x1 - bbox1_x2) ** 2) / 4
        center_distance_y = ((bbox2_y1 + bbox2_y2 - bbox1_y1 - bbox1_y2) ** 2) / 4
        center_distance = hh * center_distance_x + ww * center_distance_y
        distance = center_distance / c2

        # Shape-Shape
        omiga_w = hh * torch.abs(w1 - w2) / torch.max(w1, w2)
        omiga_h = ww * torch.abs(h1 - h2) / torch.max(h1, h2)
        shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)

        # Shape-IoU
        ious = ious - distance - 0.5 * shape_cost
    elif iou_mode == "wiou":
        # CIoU = IoU - ( (ρ^2(b_pred,b_gt) / c^2) + (alpha x v) )
            enclose_area = enclose_w**2 + enclose_h**2 + eps

            # calculate ρ^2(b_pred,b_gt):
            # euclidean distance between b_pred(bbox2) and b_gt(bbox1)
            # center point, because bbox format is xyxy -> left-top xy and
            # right-bottom xy, so need to / 4 to get center point.
            rho2_left_item = ((bbox2_x1 + bbox2_x2) - (bbox1_x1 + bbox1_x2))**2 / 4
            rho2_right_item = ((bbox2_y1 + bbox2_y2) -
                               (bbox1_y1 + bbox1_y2))**2 / 4
            rho2 = rho2_left_item + rho2_right_item  # rho^2 (ρ^2)
            obj = WIoU_Scale(ious)
            wise_iou_loss1 = getattr(obj,'_scaled_loss')(obj)
            wise_iou_loss2 = (1-ious)* torch.exp((rho2 / enclose_area))

            return wise_iou_loss1,wise_iou_loss2,ious.clamp(min=-1.0, max=1.0)

    return ious.clamp(min=-1.0, max=1.0)


@MODELS.register_module()
class IoULoss(nn.Module):
    """IoULoss.

    Computing the IoU loss between a set of predicted bboxes and target bboxes.
    Args:
        iou_mode (str): Options are "ciou".
            Defaults to "ciou".
        bbox_format (str): Options are "xywh" and "xyxy".
            Defaults to "xywh".
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Weight of loss.
        focal (bool): If True, use the focal variant of the selected IoU
            loss (currently Focal-CIoU only). Defaults to False.
        return_iou (bool): If True, return loss and iou.
    """

    def __init__(self,
                 iou_mode: str = 'ciou',
                 bbox_format: str = 'xywh',
                 eps: float = 1e-7,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 focal: bool = False,
                 return_iou: bool = True):
        super().__init__()
        assert bbox_format in ('xywh', 'xyxy')
        assert iou_mode in ('ciou', 'siou', 'giou', 'diou', 'eiou',
                            'innersiou', 'innerciou', 'shapeiou', 'wiou')
        self.iou_mode = iou_mode
        self.bbox_format = bbox_format
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_iou = return_iou
        self.focal = focal

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        avg_factor: Optional[float] = None,
        reduction_override: Optional[Union[str, bool]] = None
    ) -> Tuple[Union[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Forward function.

        Args:
            pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2)
                or (x, y, w, h),shape (n, 4).
            target (Tensor): Corresponding gt bboxes, shape (n, 4).
            weight (Tensor, optional): Element-wise weights.
            avg_factor (float, optional): Average factor when computing the
                mean of losses.
            reduction_override (str, bool, optional): Same as built-in losses
                of PyTorch. Defaults to None.
        Returns:
            loss or tuple(loss, iou):
        """
        if weight is not None and not torch.any(weight > 0):
            if pred.dim() == weight.dim() + 1:
                weight = weight.unsqueeze(1)
            return (pred * weight).sum()  # 0
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if weight is not None and weight.dim() > 1:
            weight = weight.mean(-1)

        iou = bbox_overlaps(
            pred,
            target,
            iou_mode=self.iou_mode,
            bbox_format=self.bbox_format,
            eps=self.eps,
            focal=self.focal)
        if isinstance(iou, tuple):
            # 'wiou' mode returns (wise-iou scale, distance-weighted loss, iou)
            loss = self.loss_weight * weight_reduce_loss(
                1.0 - iou[2], weight, reduction, avg_factor)
            loss += weight_reduce_loss((iou[0] * iou[1]).mean(), weight,
                                       reduction, avg_factor)
            iou = iou[2]
        else:
            loss = self.loss_weight * weight_reduce_loss(
                1.0 - iou, weight, reduction, avg_factor)

        if self.return_iou:
            return loss, iou
        else:
            return loss
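
After reinstalling, a quick smoke test can confirm that the focal path is wired up. This is a hedged sketch: the random boxes are illustrative, and the import assumes the file lives at mmyolo/models/losses/iou_loss.py.

import torch
from mmyolo.models.losses.iou_loss import IoULoss

loss_fn = IoULoss(iou_mode='ciou', focal=True, bbox_format='xyxy')
xy = torch.rand(8, 2) * 100
wh = torch.rand(8, 2) * 50 + 1.0
pred = torch.cat([xy, xy + wh], dim=-1)     # (8, 4) valid xyxy boxes
target = pred + 0.1 * torch.randn(8, 4)     # slightly perturbed "gt" boxes
loss, iou = loss_fn(pred, target)           # return_iou=True by default
print(loss.item(), iou.shape)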

Modified configuration file (using configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py as an example)

_base_ = ['../_base_/default_runtime.py', '../_base_/det_p5_tta.py']

# ========================Frequently modified parameters======================
# -----data related-----
data_root = 'data/coco/'  # Root path of data
# Path of train annotation file
train_ann_file = 'annotations/instances_train2017.json'
train_data_prefix = 'train2017/'  # Prefix of train image path
# Path of val annotation file
val_ann_file = 'annotations/instances_val2017.json'
val_data_prefix = 'val2017/'  # Prefix of val image path

num_classes = 80  # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 16
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----model related-----
# Basic size of multi-scale prior box
anchors = [
    [(10, 13), (16, 30), (33, 23)],  # P3/8
    [(30, 61), (62, 45), (59, 119)],  # P4/16
    [(116, 90), (156, 198), (373, 326)]  # P5/32
]

# -----train val related-----
# Base learning rate for optim_wrapper. Corresponding to 8xb16=128 bs
base_lr = 0.01
max_epochs = 300  # Maximum training epochs

model_test_cfg = dict(
    # The config of multi-label for multi-class prediction.
    multi_label=True,
    # The number of boxes before NMS
    nms_pre=30000,
    score_thr=0.001,  # Threshold to filter out boxes.
    nms=dict(type='nms', iou_threshold=0.65),  # NMS type and threshold
    max_per_img=300)  # Max number of detections of each image

# ========================Possible modified parameters========================
# -----data related-----
img_scale = (640, 640)  # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv5CocoDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=val_batch_size_per_gpu,
    img_size=img_scale[0],
    # The image scale of padding should be divided by pad_size_divisor
    size_divisor=32,
    # Additional paddings for pixel scale
    extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.5
# Strides of multi-scale prior box
strides = [8, 16, 32]
num_det_layers = 3  # The number of model output scales
norm_cfg = dict(type='BN', momentum=0.03, eps=0.001)  # Normalization config

# -----train val related-----
affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
loss_cls_weight = 0.5
loss_bbox_weight = 0.05
loss_obj_weight = 1.0
prior_match_thr = 4.  # Priori box matching threshold
# The obj loss weights of the three output layers
obj_level_weights = [4., 1., 0.4]
lr_factor = 0.01  # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_checkpoint_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ===============================Unmodified in most cases====================
model = dict(
    type='YOLODetector',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[0., 0., 0.],
        std=[255., 255., 255.],
        bgr_to_rgb=True),
    backbone=dict(
        # use the YOLOv8 backbone in place of the YOLOv5 one
        type='YOLOv8CSPDarknet',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=[256, 512, 1024],
        out_channels=[256, 512, 1024],
        num_csp_blocks=3,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            num_classes=num_classes,
            in_channels=[256, 512, 1024],
            widen_factor=widen_factor,
            featmap_strides=strides,
            num_base_priors=3),
        prior_generator=dict(
            type='mmdet.YOLOAnchorGenerator',
            base_sizes=anchors,
            strides=strides),
        # scaled based on number of detection layers
        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_cls_weight *
            (num_classes / 80 * 3 / num_det_layers)),
        # modify this block to swap in a different IoU loss (Focal-CIoU here)
        loss_bbox=dict(
            type='IoULoss',
            focal=True,
            iou_mode='ciou',
            bbox_format='xywh',
            eps=1e-7,
            reduction='mean',
            loss_weight=loss_bbox_weight * (3 / num_det_layers),
            return_iou=True),
        loss_obj=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=True,
            reduction='mean',
            loss_weight=loss_obj_weight *
            ((img_scale[0] / 640)**2 * 3 / num_det_layers)),
        prior_match_thr=prior_match_thr,
        obj_level_weights=obj_level_weights),
    test_cfg=model_test_cfg)

albu_train_transforms = [
    dict(type='Blur', p=0.01),
    dict(type='MedianBlur', p=0.01),
    dict(type='ToGray', p=0.01),
    dict(type='CLAHE', p=0.01)
]

pre_transform = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
    *pre_transform,
    dict(
        type='Mosaic',
        img_scale=img_scale,
        pad_val=114.0,
        pre_transform=pre_transform),
    dict(
        type='YOLOv5RandomAffine',
        max_rotate_degree=0.0,
        max_shear_degree=0.0,
        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2),
        border_val=(114, 114, 114)),
    dict(
        type='mmdet.Albu',
        transforms=albu_train_transforms,
        bbox_params=dict(
            type='BboxParams',
            format='pascal_voc',
            label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
        keymap={
            'img': 'image',
            'gt_bboxes': 'bboxes'
        }),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                   'flip_direction'))
]

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file=train_ann_file,
        data_prefix=dict(img=train_data_prefix),
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        pipeline=train_pipeline))

test_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
    dict(
        type='LetterResize',
        scale=img_scale,
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor', 'pad_param'))
]

val_dataloader = dict(
    batch_size=val_batch_size_per_gpu,
    num_workers=val_num_workers,
    persistent_workers=persistent_workers,
    pin_memory=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        test_mode=True,
        data_prefix=dict(img=val_data_prefix),
        ann_file=val_ann_file,
        pipeline=test_pipeline,
        batch_shapes_cfg=batch_shapes_cfg))

test_dataloader = val_dataloader

param_scheduler = None
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD',
        lr=base_lr,
        momentum=0.937,
        weight_decay=weight_decay,
        nesterov=True,
        batch_size_per_gpu=train_batch_size_per_gpu),
    constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
    param_scheduler=dict(
        type='YOLOv5ParamSchedulerHook',
        scheduler_type='linear',
        lr_factor=lr_factor,
        max_epochs=max_epochs),
    checkpoint=dict(
        type='CheckpointHook',
        interval=save_checkpoint_intervals,
        save_best='auto',
        max_keep_ckpts=max_keep_ckpts))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        strict_load=False,
        priority=49)
]

val_evaluator = dict(
    type='mmdet.CocoMetric',
    proposal_nums=(100, 1, 10),
    ann_file=data_root + val_ann_file,
    metric='bbox')
test_evaluator = val_evaluator

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=save_checkpoint_intervals)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
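
With the loss swapped in and the package reinstalled, training runs as usual, for example:

python tools/train.py configs/yolov5/yolov5_s-v61_syncbn_8xb16-300e_coco.py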
