Skip to content

API Reference

BankSource

Bases: Module

External-bank source: paste pre-staged crops from an instance bank.

The bank tensor ([B, K_bank, 5, h, w] packed as RGB+alpha+class-id) is set per training step via :meth:set_bank_batch. sample then uniform-randomly picks one crop per batch row (weighted multinomial selection is a planned extension), places it at the origin of a target-sized canvas (forming a synthetic source view with K_source = 1), and draws per-target (scale, translate, hflip) affine parameters. placement.source_idx = arange(B) so each target gathers from its own row of the source view.

Pinning K_source = 1 matches v0.3.0 "one paste per target" semantics; multi-crop-per-target is a future extension. Configurable placement geometry shares :class:BatchedPlacementConfig knobs with :class:IntraBatchSource.

Source code in src/segpaste/augmentation/source.py
class BankSource(nn.Module):
    """External-bank source: paste pre-staged crops from an instance bank.

    The bank tensor (``[B, K_bank, 5, h, w]`` packed as RGB+alpha+class-id)
    is set per training step via :meth:`set_bank_batch`. ``sample`` then
    uniform-randomly picks one crop per batch row (via ``randint``; weighted
    multinomial selection is a planned extension), places it at the origin
    of a target-sized canvas (forming a synthetic source view with
    ``K_source = 1``), and draws per-target ``(scale, translate, hflip)``
    affine parameters. ``placement.source_idx = arange(B)`` so each
    target gathers from its own row of the source view.

    Pinning ``K_source = 1`` matches v0.3.0 "one paste per target"
    semantics; multi-crop-per-target is a future extension. Configurable
    placement geometry shares :class:`BatchedPlacementConfig` knobs with
    :class:`IntraBatchSource`.
    """

    def __init__(self, placement_config: BatchedPlacementConfig | None = None) -> None:
        super().__init__()
        self.placement_config = placement_config or BatchedPlacementConfig()
        # Staged lazily each step via set_bank_batch; sample() raises if unset.
        self._bank_batch: Tensor | None = None

    def set_bank_batch(self, bank_batch: Tensor) -> None:
        """Stage a per-step bank tensor of shape ``[B, K_bank, 5, h, w]``."""
        # Only rank and the packed-channel count are validated here; the batch
        # dim is checked against the target inside sample().
        if bank_batch.ndim != 5 or bank_batch.shape[2] != 5:
            shape = tuple(bank_batch.shape)
            raise ValueError(f"bank_batch must be [B, K_bank, 5, h, w], got {shape}")
        self._bank_batch = bank_batch

    def sample(
        self,
        target: PaddedBatchedDenseSample,
        valid_extent: Tensor | None,
        source_eligible: Tensor | None,  # noqa: ARG002 — bank ignores panoptic gating
        generator: torch.Generator | None,
    ) -> tuple[PaddedBatchedDenseSample, BatchedPlacement]:
        """Build a ``K_source = 1`` source view plus placement from the bank.

        Raises:
            RuntimeError: if no bank was staged via :meth:`set_bank_batch`.
            ValueError: if the staged bank's batch dim mismatches the target's.
        """
        bank = self._bank_batch
        if bank is None:
            raise RuntimeError("BankSource requires set_bank_batch(...) before forward")
        b = target.batch_size
        if bank.shape[0] != b:
            raise ValueError(f"bank_batch B={bank.shape[0]} != target.batch_size {b}")
        k_bank = bank.shape[1]
        h_crop = bank.shape[3]
        w_crop = bank.shape[4]
        canvas_h, canvas_w = target.images.shape[-2:]
        device = bank.device

        # Uniform draw via randint avoids materializing a [B, K_bank] weight
        # tensor and the multinomial cumsum kernel. Future weighted-bank
        # support will substitute multinomial here.
        selected = torch.randint(0, k_bank, (b,), device=device, generator=generator)
        batch_arange = torch.arange(b, device=device)
        chosen = bank[batch_arange, selected]  # [B, 5, h_crop, w_crop]

        target_dtype = target.images.dtype
        # Crop is pasted at the canvas origin; everything outside it stays zero.
        img_canvas = torch.zeros(
            (b, 3, canvas_h, canvas_w), dtype=target_dtype, device=device
        )
        img_canvas[:, :, :h_crop, :w_crop] = chosen[:, 0:3].to(target_dtype)
        mask_canvas = torch.zeros(
            (b, 1, canvas_h, canvas_w), dtype=torch.bool, device=device
        )
        # Binarize the (possibly soft) alpha channel at 0.5.
        mask_canvas[:, :, :h_crop, :w_crop] = chosen[:, 3:4] > 0.5
        # Class id is read from pixel (0, 0) of the id plane — assumes the
        # plane is constant per crop; TODO(review) confirm against bank builder.
        labels = chosen[:, 4, 0, 0].to(torch.int64).unsqueeze(1)  # [B, 1]
        # Single origin-anchored box spanning the full crop extent.
        boxes_xyxy = torch.zeros((b, 1, 4), dtype=torch.float32, device=device)
        boxes_xyxy[:, 0, 2] = float(w_crop)
        boxes_xyxy[:, 0, 3] = float(h_crop)
        instance_ids = torch.zeros((b, 1), dtype=torch.int32, device=device)
        instance_valid_src = torch.ones((b, 1), dtype=torch.bool, device=device)
        source_view = PaddedBatchedDenseSample(
            images=tv_tensors.Image(img_canvas),
            boxes=boxes_xyxy,
            labels=labels,
            instance_valid=instance_valid_src,
            max_instances=1,
            instance_masks=mask_canvas,
            instance_ids=instance_ids,
        )

        placement = self._sample_placement(
            target,
            source_view,
            valid_extent=valid_extent,
            generator=generator,
            device=device,
            crop_h=h_crop,
            crop_w=w_crop,
        )
        return source_view, placement

    def _sample_placement(
        self,
        target: PaddedBatchedDenseSample,
        source_view: PaddedBatchedDenseSample,
        *,
        valid_extent: Tensor | None,
        generator: torch.Generator | None,
        device: torch.device,
        crop_h: int,
        crop_w: int,
    ) -> BatchedPlacement:
        """Per-target ``(scale, translate, hflip)`` draw for the bank crop.

        Mirrors the contiguous-canvas branch of :class:`BatchedPlacementSampler`
        but specialized to ``K_source = 1`` and ``source_idx = arange(B)``.
        Patch-aligned paste is deferred to a follow-up — bank geometry is
        already canvas-aligned by construction at the bank's ``(h_crop,
        w_crop)``, so there is no ``pad_to_multiple`` quantization issue.
        """
        config = self.placement_config
        b = target.batch_size
        canvas_h, canvas_w = target.images.shape[-2:]

        # Missing extent means the full canvas is valid for every sample.
        if valid_extent is None:
            ve = source_view.images.new_tensor(
                [float(canvas_h), float(canvas_w)]
            ).expand(b, 2)
        else:
            ve = valid_extent.to(device=device, dtype=torch.float32)
        tgt_h = ve[:, 0]
        tgt_w = ve[:, 1]

        # scale ~ U[smin, smax); hflip ~ Bernoulli(hflip_probability).
        smin, smax = config.scale_range
        scale = (
            torch.rand((b,), generator=generator, device=device) * (smax - smin) + smin
        )
        hflip = (
            torch.empty((b,), device=device, dtype=torch.float32)
            .bernoulli_(config.hflip_probability, generator=generator)
            .bool()
        )

        # Crop spans the whole [0, w_crop] x [0, h_crop] region; effective
        # right/bottom edges are the crop dimensions.
        max_scaled_x2 = float(crop_w) * scale
        max_scaled_y2 = float(crop_h) * scale
        # Translation keeps the scaled crop inside the valid target rect.
        max_ty = torch.clamp(tgt_h - max_scaled_y2, min=0.0)
        max_tx = torch.clamp(tgt_w - max_scaled_x2, min=0.0)
        ty = torch.rand((b,), generator=generator, device=device) * max_ty
        tx = torch.rand((b,), generator=generator, device=device) * max_tx
        translate = torch.stack([ty, tx], dim=-1)

        # Rows whose scaled crop cannot fit at all are dropped outright;
        # the remaining rows paste with probability paste_prob.
        fits = (max_scaled_y2 <= tgt_h) & (max_scaled_x2 <= tgt_w)
        do_paste = (
            torch.empty((b,), device=device, dtype=torch.float32)
            .bernoulli_(config.paste_prob, generator=generator)
            .bool()
        )
        paste_valid = (fits & do_paste).unsqueeze(-1)  # [B, 1]

        # Each target row gathers from its own row of the source view.
        source_idx = torch.arange(b, device=device, dtype=torch.int64)
        src_valid_extent = (
            source_view.images.new_tensor([float(crop_h), float(crop_w)])
            .expand(b, 2)
            .contiguous()
        )

        return BatchedPlacement(
            source_idx=source_idx,
            translate=translate,
            scale=scale,
            hflip=hflip,
            paste_valid=paste_valid,
            src_valid_extent=src_valid_extent,
        )

set_bank_batch(bank_batch)

Stage a per-step bank tensor of shape [B, K_bank, 5, h, w].

Source code in src/segpaste/augmentation/source.py
def set_bank_batch(self, bank_batch: Tensor) -> None:
    """Stage a per-step bank tensor of shape ``[B, K_bank, 5, h, w]``."""
    # Accept only rank-5 tensors whose packed-channel dim carries the five
    # RGB+alpha+class-id planes.
    well_formed = bank_batch.ndim == 5 and bank_batch.shape[2] == 5
    if not well_formed:
        shape = tuple(bank_batch.shape)
        raise ValueError(f"bank_batch must be [B, K_bank, 5, h, w], got {shape}")
    self._bank_batch = bank_batch

BatchCopyPaste

Bases: Module

Graph-compilable batched copy-paste augmentation.

Source code in src/segpaste/augmentation/batch_copy_paste.py
class BatchCopyPaste(nn.Module):
    """Graph-compilable batched copy-paste augmentation."""

    config: BatchCopyPasteConfig
    thing_classes: Tensor
    stuff_classes: Tensor

    def __init__(
        self,
        config: BatchCopyPasteConfig | None = None,
        *,
        source_strategy: SourceStrategy | None = None,
    ) -> None:
        super().__init__()
        self.config = config or BatchCopyPasteConfig()
        self.source_strategy: SourceStrategy = source_strategy or build_source_strategy(
            self.config.source, self.config.placement
        )
        self.propagator = AffinePropagator()
        self.harmonizer = ImageHarmonizer(self.config.harmonize)
        self.compositor = TileCompositor(self.config.composite)

        if self.config.panoptic is not None:
            taxonomy = self.config.panoptic.taxonomy
            # Sorted, deterministic class-id buffers; registering them as
            # buffers moves them with the module across devices.
            things = sorted(
                cls for cls, kind in taxonomy.classes.items() if kind == "thing"
            )
            stuffs = sorted(
                cls for cls, kind in taxonomy.classes.items() if kind == "stuff"
            )
            self.register_buffer(
                "thing_classes", torch.tensor(things, dtype=torch.int64)
            )
            self.register_buffer(
                "stuff_classes", torch.tensor(stuffs, dtype=torch.int64)
            )
            # Histogram table must be addressable by every class id and the
            # ignore index.
            self._class_table_size = max([*things, *stuffs, taxonomy.ignore_index]) + 1
            self._pad_ignore_index = taxonomy.ignore_index
        else:
            self.register_buffer("thing_classes", torch.empty((0,), dtype=torch.int64))
            self.register_buffer("stuff_classes", torch.empty((0,), dtype=torch.int64))
            self._class_table_size = 0
            self._pad_ignore_index = 255

    def forward(
        self,
        padded: PaddedBatchedDenseSample,
        generator: torch.Generator | None = None,
    ) -> PaddedBatchedDenseSample:
        """Run one pass: sample source -> warp -> harmonize -> composite -> merge."""
        if padded.batch_size == 0:
            return padded

        if self.config.placement.pad_to_multiple is not None:
            padded = pad_canvas_to_multiple(
                padded,
                self.config.placement.pad_to_multiple,
                self._pad_ignore_index,
            )

        valid_extent = self._valid_extent(padded)
        source_eligible = self._source_eligible(padded)
        source_view, placement = self.source_strategy.sample(
            padded,
            valid_extent,
            source_eligible,
            generator,
        )
        warped = self.propagator(padded, source_view, placement)
        paste_mask = self._paste_mask(warped, placement)
        warped = self.harmonizer(padded, warped, paste_mask, generator)
        composited = self.compositor(padded, warped, paste_mask)
        if self.config.panoptic is not None:
            composited, warped = self._revert_stuff_collapse(
                padded, composited, warped, paste_mask
            )
        composited = drop_occluded_targets(
            padded, composited, self.config.min_residual_area_frac
        )
        return self._merge_slots(composited, warped, placement)

    def _source_eligible(self, padded: PaddedBatchedDenseSample) -> Tensor | None:
        """Per-slot ``[B, K]`` mask of thing-class instances eligible as sources.

        Returns ``None`` when no panoptic taxonomy is configured (no gating).
        """
        if self.config.panoptic is None:
            return None
        # BUGFIX: dropped ``assume_unique=True`` — ``padded.labels`` routinely
        # repeats a class id across instance slots (e.g. two cars in one
        # image), which violates that flag's contract and makes the membership
        # result undefined.
        return torch.isin(padded.labels, self.thing_classes)

    def _revert_stuff_collapse(
        self,
        padded: PaddedBatchedDenseSample,
        composited: PaddedBatchedDenseSample,
        warped: PaddedBatchedDenseSample,
        paste_mask: Tensor,
    ) -> tuple[PaddedBatchedDenseSample, PaddedBatchedDenseSample]:
        """Revert paste pixels where a pre-paste stuff class collapsed (ADR-0006 §3).

        Image, semantic, and panoptic modalities are reverted to ``padded`` on
        pixels that (a) carried a stuff class pre-paste whose post-paste area
        fell below ``tau_stuff_frac`` of its pre-paste area, and (b) were
        overwritten by paste. Target instance survivors are restored and
        warped paste masks cleared so the panoptic bijection on thing pixels
        still holds downstream.
        """
        panoptic = self.config.panoptic
        if (
            panoptic is None
            or padded.semantic_maps is None
            or composited.semantic_maps is None
            or self.stuff_classes.numel() == 0
        ):
            return composited, warped

        b, h, w = paste_mask.shape
        device = paste_mask.device
        n = self._class_table_size
        tau = panoptic.tau_stuff_frac

        # Per-sample class histograms before/after paste via scatter_add.
        # NOTE(review): assumes every semantic-map value is < n (class table
        # covers taxonomy + ignore_index) — out-of-range ids would index OOB.
        pre = padded.semantic_maps.as_subclass(Tensor).flatten(1)
        post = composited.semantic_maps.as_subclass(Tensor).flatten(1)
        ones = torch.ones((), dtype=torch.int64, device=device).expand_as(pre)
        before_hist = torch.zeros((b, n), dtype=torch.int64, device=device)
        after_hist = torch.zeros((b, n), dtype=torch.int64, device=device)
        before_hist.scatter_add_(1, pre, ones)
        after_hist.scatter_add_(1, post, ones)

        before = before_hist.index_select(1, self.stuff_classes).to(torch.float32)
        after = after_hist.index_select(1, self.stuff_classes).to(torch.float32)
        collapse = (before > 0) & (after < tau * before)

        # Scatter the per-stuff-class collapse verdicts into a dense [B, n]
        # lookup table, then gather per pixel to find pixels to revert.
        collapse_table = torch.zeros((b, n), dtype=torch.bool, device=device)
        collapse_table.scatter_(
            1, self.stuff_classes.unsqueeze(0).expand(b, -1), collapse
        )
        revert = collapse_table.gather(1, pre).view(b, h, w) & paste_mask

        rev3 = revert.unsqueeze(1)
        new_image = torch.where(rev3, padded.images, composited.images)
        new_sem = torch.where(
            revert,
            padded.semantic_maps.as_subclass(Tensor),
            composited.semantic_maps.as_subclass(Tensor),
        )
        if padded.panoptic_maps is not None and composited.panoptic_maps is not None:
            new_pano: Tensor | None = torch.where(
                revert,
                padded.panoptic_maps.as_subclass(Tensor),
                composited.panoptic_maps.as_subclass(Tensor),
            )
        else:
            new_pano = None

        # Restore target instance pixels uncovered by the revert, and clear
        # the same pixels from the warped paste masks.
        if composited.instance_masks is not None and padded.instance_masks is not None:
            new_target_masks: Tensor | None = composited.instance_masks | (
                padded.instance_masks & rev3
            )
        else:
            new_target_masks = composited.instance_masks

        new_warped_masks = (
            warped.instance_masks & ~rev3 if warped.instance_masks is not None else None
        )

        new_composited = replace(
            composited,
            images=tv_tensors.Image(new_image),
            semantic_maps=SemanticMap(new_sem),
            panoptic_maps=PanopticMap(new_pano) if new_pano is not None else None,
            instance_masks=new_target_masks,
        )
        new_warped = replace(warped, instance_masks=new_warped_masks)
        return new_composited, new_warped

    @staticmethod
    def _valid_extent(padded: PaddedBatchedDenseSample) -> Tensor | None:
        """Per-sample ``[B, 2]`` (h_v, w_v) bound on the unpadded image rect.

        Assumes the LSJ convention of a top-left valid rect with bottom/right
        zero-pad (:class:`FixedSizeCrop` via ``augmentation/lsj.py``).
        """
        if padded.padding_mask is None:
            return None
        not_pad = (~padded.padding_mask.as_subclass(Tensor)).squeeze(1)
        h_v = not_pad.any(dim=-1).sum(dim=-1).to(torch.float32)
        w_v = not_pad.any(dim=-2).sum(dim=-1).to(torch.float32)
        return torch.stack([h_v, w_v], dim=-1)

    @staticmethod
    def _paste_mask(
        warped: PaddedBatchedDenseSample, placement: BatchedPlacement
    ) -> Tensor:
        """Union of warped instance masks gated by ``paste_valid``.

        Returns a ``[B, H, W]`` bool tensor identifying pixels that any
        valid pasted slot contributes to.
        """
        b = warped.batch_size
        h, w = warped.images.shape[-2:]
        if warped.instance_masks is None:
            return torch.zeros((b, h, w), dtype=torch.bool, device=warped.images.device)
        gate = placement.paste_valid.view(
            placement.paste_valid.shape[0], placement.paste_valid.shape[1], 1, 1
        )
        return (warped.instance_masks & gate).any(dim=1)

    @staticmethod
    def _merge_slots(
        composited: PaddedBatchedDenseSample,
        warped: PaddedBatchedDenseSample,
        placement: BatchedPlacement,
    ) -> PaddedBatchedDenseSample:
        """Compact paste rows into the target's free slots.

        Output slot ``t`` carries the survivor-updated target row when
        ``composited.instance_valid[t]`` is ``True``, otherwise it
        receives the next pasted source row in source-slot order. When
        the number of pastes exceeds the target's free slot count, the
        surplus is dropped. The remap is computed via a ``[B, K, K]``
        rank-equality match — graph-clean and inexpensive at COCO scale.
        """
        pv = placement.paste_valid  # [B, K_s]
        free = ~composited.instance_valid  # [B, K_t]

        # Pair the n-th free target slot with the n-th source paste slot via
        # rank-equality. Ranks are unique per row, so each match row has
        # at most one True; argmax then names the source slot to gather.
        free_rank = free.long().cumsum(-1) - 1
        paste_rank = pv.long().cumsum(-1) - 1
        match = (
            free.unsqueeze(-1)
            & pv.unsqueeze(-2)
            & (free_rank.unsqueeze(-1) == paste_rank.unsqueeze(-2))
        )  # [B, K_t, K_s]
        receives = match.any(dim=-1)  # [B, K_t]
        # argmax returns 0 where no match — guarded by `receives` in `where`.
        src_k = match.long().argmax(dim=-1)  # [B, K_t] (indices into K_s)

        b, k_t = composited.instance_valid.shape
        batch_idx = torch.arange(b, device=pv.device).unsqueeze(-1).expand(b, k_t)

        def gather(src: Tensor, dst: Tensor) -> Tensor:
            # Broadcast `receives` over src's trailing dims (boxes [.,4],
            # masks [.,H,W], ids/labels [.]).
            sel = receives.view(b, k_t, *([1] * (src.ndim - 2)))
            return torch.where(sel, src[batch_idx, src_k], dst)

        merged_boxes = gather(warped.boxes, composited.boxes)
        merged_labels = gather(warped.labels, composited.labels)
        merged_masks = (
            gather(warped.instance_masks, composited.instance_masks)
            if warped.instance_masks is not None
            and composited.instance_masks is not None
            else composited.instance_masks
        )
        merged_ids = (
            gather(warped.instance_ids, composited.instance_ids)
            if warped.instance_ids is not None and composited.instance_ids is not None
            else composited.instance_ids
        )
        merged_valid = composited.instance_valid | receives

        return PaddedBatchedDenseSample(
            images=composited.images,
            boxes=merged_boxes,
            labels=merged_labels,
            instance_valid=merged_valid,
            max_instances=composited.max_instances,
            instance_masks=merged_masks,
            instance_ids=merged_ids,
            semantic_maps=composited.semantic_maps,
            panoptic_maps=composited.panoptic_maps,
            depth=composited.depth,
            depth_valid=composited.depth_valid,
            normals=composited.normals,
            padding_mask=composited.padding_mask,
            camera_intrinsics=composited.camera_intrinsics,
        )

BatchedDenseSample dataclass

Canonical batched container for dense-label Copy-Paste.

Stacked fields share (H, W) across the batch — LSJ preprocessing is assumed to have homogenized sample shapes. Ragged fields keep one entry per sample. B == 0 is valid (empty batches produce zero-length lists and zero-batch-dim stacked tensors).

Source code in src/segpaste/types/batched_dense_sample.py
@dataclass(frozen=True, slots=True)
class BatchedDenseSample:
    """Canonical batched container for dense-label Copy-Paste.

    Stacked fields share ``(H, W)`` across the batch — LSJ preprocessing is
    assumed to have homogenized sample shapes. Ragged fields keep one entry
    per sample. ``B == 0`` is valid (empty batches produce zero-length lists
    and zero-batch-dim stacked tensors).
    """

    # Stacked: one tensor with a leading batch dim.
    images: tv_tensors.Image
    # Ragged: Python lists with one entry per sample.
    boxes: list[tv_tensors.BoundingBoxes]
    labels: list[torch.Tensor]
    # instance_masks / instance_ids are co-optional (both set or both None).
    instance_masks: list[InstanceMask] | None = None
    instance_ids: list[torch.Tensor] | None = None
    semantic_maps: SemanticMap | None = None
    panoptic_maps: PanopticMap | None = None
    # depth / depth_valid are co-optional (both set or both None).
    depth: torch.Tensor | None = None
    depth_valid: torch.Tensor | None = None
    normals: torch.Tensor | None = None
    padding_mask: PaddingMask | None = None
    camera_intrinsics: list[CameraIntrinsics] | None = None

    @skip_if_compiling
    def __post_init__(self) -> None:
        """Validate batch-length, rank, and co-optionality invariants.

        Skipped under graph compilation (``skip_if_compiling``) — the checks
        involve Python-level list handling and raises.
        """
        b = self.images.size(0)
        if len(self.boxes) != b or len(self.labels) != b:
            raise ValueError("boxes and labels must be length B")

        # XOR: exactly one of the pair being set is the invalid state.
        if (self.instance_masks is None) ^ (self.instance_ids is None):
            raise ValueError(
                "instance_masks and instance_ids must both be set or both None"
            )
        if (
            self.instance_masks is not None
            and self.instance_ids is not None
            and (len(self.instance_masks) != b or len(self.instance_ids) != b)
        ):
            raise ValueError("instance_masks and instance_ids must be length B")

        if (self.depth is None) ^ (self.depth_valid is None):
            raise ValueError("depth and depth_valid must both be set or both None")

        # Stacked tensors: enforce expected rank and a batch-sized leading dim.
        stacked_shape_checks = (
            ("semantic_maps", self.semantic_maps, 3),
            ("panoptic_maps", self.panoptic_maps, 3),
            ("depth", self.depth, 4),
            ("depth_valid", self.depth_valid, 4),
            ("normals", self.normals, 4),
            ("padding_mask", self.padding_mask, 4),
        )
        for name, tensor, expected_rank in stacked_shape_checks:
            if tensor is None:
                continue
            if tensor.ndim != expected_rank:
                raise ValueError(f"{name} must have rank {expected_rank}")
            if tensor.size(0) != b:
                raise ValueError(f"{name} must have batch dim {b}")

        if self.camera_intrinsics is not None and len(self.camera_intrinsics) != b:
            raise ValueError("camera_intrinsics must be length B")

    @property
    def batch_size(self) -> int:
        """Number of samples ``B`` (the images tensor's leading dim)."""
        return self.images.size(0)

    @staticmethod
    def from_samples(samples: list[DenseSample]) -> "BatchedDenseSample":
        """Stack a list of :class:`DenseSample` into a :class:`BatchedDenseSample`.

        All samples must share the same active modality set and the same
        ``(H, W)``. ``B == 0`` yields an empty-but-valid batch.
        """
        if not samples:
            return _empty_batch()

        # Sample 0 defines the modality set and canvas size; the rest must match.
        active = samples[0].active_modalities()
        h, w = samples[0].image.shape[-2:]
        for s in samples[1:]:
            if s.active_modalities() != active:
                raise ValueError("all samples must share the same active modality set")
            if s.image.shape[-2:] != (h, w):
                raise ValueError("all samples must share (H, W)")

        images = tv_tensors.Image(
            torch.stack([s.image.as_subclass(torch.Tensor) for s in samples])
        )
        boxes = [s.boxes for s in samples]
        labels = [s.labels for s in samples]

        instance_masks: list[InstanceMask] | None = None
        instance_ids: list[torch.Tensor] | None = None
        if Modality.INSTANCE in active:
            # Co-optionality on DenseSample makes None structurally impossible here.
            instance_masks = cast(
                list[InstanceMask], [s.instance_masks for s in samples]
            )
            instance_ids = cast(list[torch.Tensor], [s.instance_ids for s in samples])

        semantic_maps = _stack_optional(samples, "semantic_map", wrapper=SemanticMap)
        panoptic_maps = _stack_optional(samples, "panoptic_map", wrapper=PanopticMap)
        depth = _stack_optional(samples, "depth")
        depth_valid = _stack_optional(samples, "depth_valid")
        normals = _stack_optional(samples, "normals")
        padding_mask = _stack_optional(samples, "padding_mask", wrapper=PaddingMask)

        # Intrinsics are kept only when every sample carries them; a partial
        # set is silently dropped to None.
        intrinsics = [s.camera_intrinsics for s in samples]
        camera_intrinsics = (
            cast(list[CameraIntrinsics], intrinsics)
            if all(i is not None for i in intrinsics)
            else None
        )

        return BatchedDenseSample(
            images=images,
            boxes=boxes,
            labels=labels,
            instance_masks=instance_masks,
            instance_ids=instance_ids,
            semantic_maps=semantic_maps,
            panoptic_maps=panoptic_maps,
            depth=depth,
            depth_valid=depth_valid,
            normals=normals,
            padding_mask=padding_mask,
            camera_intrinsics=camera_intrinsics,
        )

    def to_padded(self, max_instances: int) -> PaddedBatchedDenseSample:
        """Pack ragged per-sample instance fields into K-padded tensors.

        Valid rows are written at slots ``[0, n_i)`` for each sample ``i`` and
        marked ``True`` in ``instance_valid``. Padded rows are zero-valued.
        Raises if any sample has more than ``max_instances`` objects.
        """
        b = self.batch_size
        k = max_instances
        device = self.images.device

        if b > 0:
            box_dtype = self.boxes[0].as_subclass(torch.Tensor).dtype
            label_dtype = self.labels[0].dtype
        else:
            # Empty batch: no sample to infer dtypes from, use defaults.
            box_dtype = torch.float32
            label_dtype = torch.int64

        boxes_padded = torch.zeros((b, k, 4), dtype=box_dtype, device=device)
        labels_padded = torch.zeros((b, k), dtype=label_dtype, device=device)
        instance_valid = torch.zeros((b, k), dtype=torch.bool, device=device)

        for i in range(b):
            n = self.boxes[i].size(0)
            if n > k:
                raise ValueError(
                    f"sample {i} has {n} instances, exceeds max_instances={k}"
                )
            if n > 0:
                boxes_padded[i, :n] = self.boxes[i].as_subclass(torch.Tensor)
                labels_padded[i, :n] = self.labels[i]
                instance_valid[i, :n] = True

        instance_masks_padded: torch.Tensor | None = None
        instance_ids_padded: torch.Tensor | None = None
        if self.instance_masks is not None and self.instance_ids is not None:
            h, w = self.images.shape[-2:]
            instance_masks_padded = torch.zeros(
                (b, k, h, w), dtype=torch.bool, device=device
            )
            instance_ids_padded = torch.zeros((b, k), dtype=torch.int32, device=device)
            for i in range(b):
                # NOTE(review): assumes mask count per sample equals the box
                # count bounded above — confirm DenseSample enforces that.
                n = self.instance_masks[i].size(0)
                if n > 0:
                    instance_masks_padded[i, :n] = self.instance_masks[i].as_subclass(
                        torch.Tensor
                    )
                    instance_ids_padded[i, :n] = self.instance_ids[i]

        # Intrinsics are flattened to a dense [B, 4] (fx, fy, cx, cy) tensor.
        camera_intrinsics_tensor: torch.Tensor | None = None
        if self.camera_intrinsics is not None:
            camera_intrinsics_tensor = torch.tensor(
                [[c.fx, c.fy, c.cx, c.cy] for c in self.camera_intrinsics],
                dtype=torch.float32,
                device=device,
            )

        return PaddedBatchedDenseSample(
            images=self.images,
            boxes=boxes_padded,
            labels=labels_padded,
            instance_valid=instance_valid,
            max_instances=k,
            instance_masks=instance_masks_padded,
            instance_ids=instance_ids_padded,
            semantic_maps=self.semantic_maps,
            panoptic_maps=self.panoptic_maps,
            depth=self.depth,
            depth_valid=self.depth_valid,
            normals=self.normals,
            padding_mask=self.padding_mask,
            camera_intrinsics=camera_intrinsics_tensor,
        )

    @staticmethod
    def from_padded(padded: PaddedBatchedDenseSample) -> "BatchedDenseSample":
        """Unpack a :class:`PaddedBatchedDenseSample` into a ragged batch.

        Uses ``instance_valid`` as the per-sample gather mask. Reconstructs
        ``tv_tensors.BoundingBoxes`` in XYXY format (the DenseSample canonical
        convention) and unpacks the ``[B, 4]`` intrinsics tensor back into
        :class:`CameraIntrinsics` instances.
        """
        b = padded.batch_size
        h, w = padded.images.shape[-2:]

        boxes: list[tv_tensors.BoundingBoxes] = []
        labels: list[torch.Tensor] = []
        for i in range(b):
            # Boolean gather drops the zero-valued padding rows.
            mask = padded.instance_valid[i]
            boxes.append(
                tv_tensors.BoundingBoxes(  # pyright: ignore[reportCallIssue]
                    padded.boxes[i][mask],
                    format=tv_tensors.BoundingBoxFormat.XYXY,
                    canvas_size=(h, w),
                )
            )
            labels.append(padded.labels[i][mask])

        instance_masks: list[InstanceMask] | None = None
        instance_ids: list[torch.Tensor] | None = None
        if padded.instance_masks is not None and padded.instance_ids is not None:
            instance_masks = []
            instance_ids = []
            for i in range(b):
                mask = padded.instance_valid[i]
                instance_masks.append(InstanceMask(padded.instance_masks[i][mask]))
                instance_ids.append(padded.instance_ids[i][mask])

        camera_intrinsics: list[CameraIntrinsics] | None = None
        if padded.camera_intrinsics is not None:
            # [B, 4] rows are (fx, fy, cx, cy) — the to_padded packing order.
            rows = cast(list[list[float]], padded.camera_intrinsics.tolist())
            camera_intrinsics = [
                CameraIntrinsics(fx=row[0], fy=row[1], cx=row[2], cy=row[3])
                for row in rows
            ]

        return BatchedDenseSample(
            images=padded.images,
            boxes=boxes,
            labels=labels,
            instance_masks=instance_masks,
            instance_ids=instance_ids,
            semantic_maps=padded.semantic_maps,
            panoptic_maps=padded.panoptic_maps,
            depth=padded.depth,
            depth_valid=padded.depth_valid,
            normals=padded.normals,
            padding_mask=padded.padding_mask,
            camera_intrinsics=camera_intrinsics,
        )

    def to_samples(self) -> list[DenseSample]:
        """Unstack back into per-sample :class:`DenseSample` objects."""
        images = self.images.as_subclass(torch.Tensor)
        out: list[DenseSample] = []
        for i in range(self.batch_size):
            # Only modalities present on the batch are forwarded to DenseSample.
            fields_dict: dict[str, Any] = {
                "image": tv_tensors.Image(images[i]),
                "boxes": self.boxes[i],
                "labels": self.labels[i],
            }
            if self.instance_masks is not None and self.instance_ids is not None:
                fields_dict["instance_masks"] = self.instance_masks[i]
                fields_dict["instance_ids"] = self.instance_ids[i]
            if self.semantic_maps is not None:
                fields_dict["semantic_map"] = SemanticMap(self.semantic_maps[i])
            if self.panoptic_maps is not None:
                fields_dict["panoptic_map"] = PanopticMap(self.panoptic_maps[i])
            if self.depth is not None:
                fields_dict["depth"] = self.depth[i]
            if self.depth_valid is not None:
                fields_dict["depth_valid"] = self.depth_valid[i]
            if self.normals is not None:
                fields_dict["normals"] = self.normals[i]
            if self.padding_mask is not None:
                fields_dict["padding_mask"] = PaddingMask(self.padding_mask[i])
            if self.camera_intrinsics is not None:
                fields_dict["camera_intrinsics"] = self.camera_intrinsics[i]
            out.append(DenseSample(**fields_dict))
        return out

from_padded(padded) staticmethod

Unpack a :class:PaddedBatchedDenseSample into a ragged batch.

Uses instance_valid as the per-sample gather mask. Reconstructs tv_tensors.BoundingBoxes in XYXY format (the DenseSample canonical convention) and unpacks the [B, 4] intrinsics tensor back into :class:CameraIntrinsics instances.

Source code in src/segpaste/types/batched_dense_sample.py
@staticmethod
def from_padded(padded: PaddedBatchedDenseSample) -> "BatchedDenseSample":
    """Unpack a :class:`PaddedBatchedDenseSample` into a ragged batch.

    Uses ``instance_valid`` as the per-sample gather mask. Reconstructs
    ``tv_tensors.BoundingBoxes`` in XYXY format (the DenseSample canonical
    convention) and unpacks the ``[B, 4]`` intrinsics tensor back into
    :class:`CameraIntrinsics` instances.
    """
    num_samples = padded.batch_size
    height, width = padded.images.shape[-2:]
    # One gather mask per batch row, reused across every ragged field.
    keep = [padded.instance_valid[i] for i in range(num_samples)]

    boxes = [
        tv_tensors.BoundingBoxes(  # pyright: ignore[reportCallIssue]
            padded.boxes[i][keep[i]],
            format=tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=(height, width),
        )
        for i in range(num_samples)
    ]
    labels = [padded.labels[i][keep[i]] for i in range(num_samples)]

    instance_masks: list[InstanceMask] | None = None
    instance_ids: list[torch.Tensor] | None = None
    if padded.instance_masks is not None and padded.instance_ids is not None:
        instance_masks = [
            InstanceMask(padded.instance_masks[i][keep[i]])
            for i in range(num_samples)
        ]
        instance_ids = [padded.instance_ids[i][keep[i]] for i in range(num_samples)]

    camera_intrinsics: list[CameraIntrinsics] | None = None
    if padded.camera_intrinsics is not None:
        # [B, 4] rows are (fx, fy, cx, cy).
        camera_intrinsics = [
            CameraIntrinsics(fx=row[0], fy=row[1], cx=row[2], cy=row[3])
            for row in cast(list[list[float]], padded.camera_intrinsics.tolist())
        ]

    return BatchedDenseSample(
        images=padded.images,
        boxes=boxes,
        labels=labels,
        instance_masks=instance_masks,
        instance_ids=instance_ids,
        semantic_maps=padded.semantic_maps,
        panoptic_maps=padded.panoptic_maps,
        depth=padded.depth,
        depth_valid=padded.depth_valid,
        normals=padded.normals,
        padding_mask=padded.padding_mask,
        camera_intrinsics=camera_intrinsics,
    )

from_samples(samples) staticmethod

Stack a list of :class:DenseSample into a :class:BatchedDenseSample.

All samples must share the same active modality set and the same (H, W). B == 0 yields an empty-but-valid batch.

Source code in src/segpaste/types/batched_dense_sample.py
@staticmethod
def from_samples(samples: list[DenseSample]) -> "BatchedDenseSample":
    """Stack :class:`DenseSample` objects into a :class:`BatchedDenseSample`.

    Every sample must carry the same active modality set and the same
    ``(H, W)``; an empty list yields an empty-but-valid batch.
    """
    if not samples:
        return _empty_batch()

    reference = samples[0]
    active = reference.active_modalities()
    h, w = reference.image.shape[-2:]
    for other in samples[1:]:
        if other.active_modalities() != active:
            raise ValueError("all samples must share the same active modality set")
        if other.image.shape[-2:] != (h, w):
            raise ValueError("all samples must share (H, W)")

    images = tv_tensors.Image(
        torch.stack([s.image.as_subclass(torch.Tensor) for s in samples])
    )

    instance_masks: list[InstanceMask] | None = None
    instance_ids: list[torch.Tensor] | None = None
    if Modality.INSTANCE in active:
        # Co-optionality on DenseSample makes None structurally impossible here.
        instance_masks = cast(list[InstanceMask], [s.instance_masks for s in samples])
        instance_ids = cast(list[torch.Tensor], [s.instance_ids for s in samples])

    intrinsics = [s.camera_intrinsics for s in samples]
    all_calibrated = all(i is not None for i in intrinsics)

    return BatchedDenseSample(
        images=images,
        boxes=[s.boxes for s in samples],
        labels=[s.labels for s in samples],
        instance_masks=instance_masks,
        instance_ids=instance_ids,
        semantic_maps=_stack_optional(samples, "semantic_map", wrapper=SemanticMap),
        panoptic_maps=_stack_optional(samples, "panoptic_map", wrapper=PanopticMap),
        depth=_stack_optional(samples, "depth"),
        depth_valid=_stack_optional(samples, "depth_valid"),
        normals=_stack_optional(samples, "normals"),
        padding_mask=_stack_optional(samples, "padding_mask", wrapper=PaddingMask),
        camera_intrinsics=(
            cast(list[CameraIntrinsics], intrinsics) if all_calibrated else None
        ),
    )

to_padded(max_instances)

Pack ragged per-sample instance fields into K-padded tensors.

Valid rows are written at slots [0, n_i) for each sample i and marked True in instance_valid. Padded rows are zero-valued. Raises if any sample has more than max_instances objects.

Source code in src/segpaste/types/batched_dense_sample.py
def to_padded(self, max_instances: int) -> PaddedBatchedDenseSample:
    """Pack ragged per-sample instance fields into K-padded tensors.

    Valid rows are written at slots ``[0, n_i)`` for each sample ``i`` and
    marked ``True`` in ``instance_valid``. Padded rows are zero-valued.
    Raises if any sample has more than ``max_instances`` objects.
    """
    b = self.batch_size
    k = max_instances
    device = self.images.device

    if b > 0:
        box_dtype = self.boxes[0].as_subclass(torch.Tensor).dtype
        label_dtype = self.labels[0].dtype
    else:
        # B == 0: no row to read dtypes from, so fall back to conventional ones.
        box_dtype = torch.float32
        label_dtype = torch.int64

    boxes_padded = torch.zeros((b, k, 4), dtype=box_dtype, device=device)
    labels_padded = torch.zeros((b, k), dtype=label_dtype, device=device)
    instance_valid = torch.zeros((b, k), dtype=torch.bool, device=device)

    # Scatter each sample's n_i valid rows into the first n_i slots.
    for i in range(b):
        n = self.boxes[i].size(0)
        if n > k:
            raise ValueError(
                f"sample {i} has {n} instances, exceeds max_instances={k}"
            )
        if n > 0:
            boxes_padded[i, :n] = self.boxes[i].as_subclass(torch.Tensor)
            labels_padded[i, :n] = self.labels[i]
            instance_valid[i, :n] = True

    # Instance masks/ids are co-optional; padded rows stay all-False / zero.
    instance_masks_padded: torch.Tensor | None = None
    instance_ids_padded: torch.Tensor | None = None
    if self.instance_masks is not None and self.instance_ids is not None:
        h, w = self.images.shape[-2:]
        instance_masks_padded = torch.zeros(
            (b, k, h, w), dtype=torch.bool, device=device
        )
        instance_ids_padded = torch.zeros((b, k), dtype=torch.int32, device=device)
        for i in range(b):
            n = self.instance_masks[i].size(0)
            if n > 0:
                instance_masks_padded[i, :n] = self.instance_masks[i].as_subclass(
                    torch.Tensor
                )
                instance_ids_padded[i, :n] = self.instance_ids[i]

    # Flatten the CameraIntrinsics dataclasses into a dense [B, 4] tensor.
    camera_intrinsics_tensor: torch.Tensor | None = None
    if self.camera_intrinsics is not None:
        camera_intrinsics_tensor = torch.tensor(
            [[c.fx, c.fy, c.cx, c.cy] for c in self.camera_intrinsics],
            dtype=torch.float32,
            device=device,
        )

    return PaddedBatchedDenseSample(
        images=self.images,
        boxes=boxes_padded,
        labels=labels_padded,
        instance_valid=instance_valid,
        max_instances=k,
        instance_masks=instance_masks_padded,
        instance_ids=instance_ids_padded,
        semantic_maps=self.semantic_maps,
        panoptic_maps=self.panoptic_maps,
        depth=self.depth,
        depth_valid=self.depth_valid,
        normals=self.normals,
        padding_mask=self.padding_mask,
        camera_intrinsics=camera_intrinsics_tensor,
    )

to_samples()

Unstack back into per-sample :class:DenseSample objects.

Source code in src/segpaste/types/batched_dense_sample.py
def to_samples(self) -> list[DenseSample]:
    """Unstack the batch into one :class:`DenseSample` per row."""
    raw = self.images.as_subclass(torch.Tensor)
    samples: list[DenseSample] = []
    for idx in range(self.batch_size):
        kwargs: dict[str, Any] = {
            "image": tv_tensors.Image(raw[idx]),
            "boxes": self.boxes[idx],
            "labels": self.labels[idx],
        }
        # Instance masks and ids are co-optional: emit both or neither.
        if self.instance_masks is not None and self.instance_ids is not None:
            kwargs["instance_masks"] = self.instance_masks[idx]
            kwargs["instance_ids"] = self.instance_ids[idx]
        # Per-pixel maps are re-wrapped in their tv_tensor subclasses.
        for key, batched, wrap in (
            ("semantic_map", self.semantic_maps, SemanticMap),
            ("panoptic_map", self.panoptic_maps, PanopticMap),
            ("padding_mask", self.padding_mask, PaddingMask),
        ):
            if batched is not None:
                kwargs[key] = wrap(batched[idx])
        # Plain tensors (and the per-sample intrinsics list) index through.
        for key, batched in (
            ("depth", self.depth),
            ("depth_valid", self.depth_valid),
            ("normals", self.normals),
            ("camera_intrinsics", self.camera_intrinsics),
        ):
            if batched is not None:
                kwargs[key] = batched[idx]
        samples.append(DenseSample(**kwargs))
    return samples

CameraIntrinsics dataclass

Pinhole camera intrinsics in pixel coordinates.

Required on a :class:DenseSample when any composite is constructed with metric_depth=True.

Source code in src/segpaste/types/dense_sample.py
@dataclass(frozen=True, slots=True)
class CameraIntrinsics:
    """Pinhole camera intrinsics in pixel coordinates.

    Required on a :class:`DenseSample` when any composite is constructed with
    ``metric_depth=True``.
    """

    fx: float  # focal length along x, in pixels
    fy: float  # focal length along y, in pixels
    cx: float  # principal point x, in pixels
    cy: float  # principal point y, in pixels

DenseSample dataclass

Canonical per-sample container for dense-label Copy-Paste.

Modality-specific fields are None when their modality is not active. Use :meth:active_modalities to derive the active set.

Source code in src/segpaste/types/dense_sample.py
@dataclass(frozen=True, slots=True)
class DenseSample:
    """Canonical per-sample container for dense-label Copy-Paste.

    Modality-specific fields are ``None`` when their modality is not active.
    Use :meth:`active_modalities` to derive the active set.
    """

    image: tv_tensors.Image  # [C, H, W]
    boxes: tv_tensors.BoundingBoxes  # [N, 4], xyxy
    labels: torch.Tensor  # [N], int64
    # [N] int32; co-optional with instance_masks
    instance_ids: torch.Tensor | None = None
    instance_masks: InstanceMask | None = None  # [N, H, W], bool
    semantic_map: SemanticMap | None = None  # [H, W], int64
    panoptic_map: PanopticMap | None = None  # [H, W], int64
    depth: torch.Tensor | None = None  # [1, H, W], float32
    depth_valid: torch.Tensor | None = None  # [1, H, W], bool
    normals: torch.Tensor | None = None  # [3, H, W], float32
    padding_mask: PaddingMask | None = None  # [1, H, W], bool
    camera_intrinsics: CameraIntrinsics | None = None
    metric_depth: bool = False

    # NOTE(review): skip_if_compiling presumably elides these eager-only
    # checks under graph compilation — confirm against the decorator.
    @skip_if_compiling
    def __post_init__(self) -> None:
        """Validate cross-field shape/dtype invariants; raises ValueError."""
        h, w = self.image.shape[-2:]

        if self.boxes.size(0) != self.labels.size(0):
            raise ValueError("boxes and labels must have same number of objects")

        # instance_masks and instance_ids are co-optional and must agree on N.
        if self.instance_masks is not None:
            if self.instance_masks.size(0) != self.boxes.size(0):
                raise ValueError(
                    "instance_masks and boxes must have same number of objects"
                )
            if self.instance_masks.shape[-2:] != (h, w):
                raise ValueError("instance_masks must share H, W with image")
            if self.instance_ids is None:
                raise ValueError(
                    "instance_ids must be provided when instance_masks is set"
                )
            if self.instance_ids.dtype != torch.int32:
                raise ValueError("instance_ids dtype must be int32")
            if self.instance_ids.shape != (self.boxes.size(0),):
                raise ValueError("instance_ids must be shape [N] matching boxes count")
        elif self.instance_ids is not None:
            raise ValueError("instance_ids requires instance_masks (co-optional)")

        # Every dense map must share the image's spatial extent.
        if self.semantic_map is not None and self.semantic_map.shape[-2:] != (h, w):
            raise ValueError("semantic_map must share H, W with image")

        if self.panoptic_map is not None and self.panoptic_map.shape[-2:] != (h, w):
            raise ValueError("panoptic_map must share H, W with image")

        if self.depth is not None and self.depth.shape[-2:] != (h, w):
            raise ValueError("depth must share H, W with image")

        if self.depth_valid is not None and self.depth_valid.shape[-2:] != (h, w):
            raise ValueError("depth_valid must share H, W with image")

        if self.normals is not None and self.normals.shape[-2:] != (h, w):
            raise ValueError("normals must share H, W with image")

        if self.padding_mask is not None and self.padding_mask.shape[1:] != (h, w):
            raise ValueError("padding_mask must share H, W with image")

        # Depth consistency: both fields together, or neither.
        if (self.depth is None) ^ (self.depth_valid is None):
            raise ValueError("depth and depth_valid must both be set or both be None")

        # Metric depth requires calibrated intrinsics (ADR-0007 §1).
        if (
            self.metric_depth
            and self.depth is not None
            and self.camera_intrinsics is None
        ):
            raise ValueError(
                "metric_depth=True requires camera_intrinsics when depth is set"
            )

    def active_modalities(self) -> set[Modality]:
        """Return the set of active modalities for this sample."""
        active: set[Modality] = {Modality.IMAGE}
        if self.instance_masks is not None:
            active.add(Modality.INSTANCE)
        if self.semantic_map is not None:
            active.add(Modality.SEMANTIC)
        if self.panoptic_map is not None:
            active.add(Modality.PANOPTIC)
        if self.depth is not None:
            active.add(Modality.DEPTH)
        if self.normals is not None:
            active.add(Modality.NORMALS)
        return active

    def to_dict(self) -> dict[str, Any]:
        """Round-trippable dict representation. Omits ``None`` fields."""
        return {
            f.name: value
            for f in fields(self)
            if (value := getattr(self, f.name)) is not None
        }

    @staticmethod
    def from_dict(data: Mapping[str, Any]) -> "DenseSample":
        """Inverse of :meth:`to_dict`; keys not naming a field are dropped."""
        names = {f.name for f in fields(DenseSample)}
        return DenseSample(**{k: v for k, v in data.items() if k in names})

active_modalities()

Return the set of active modalities for this sample.

Source code in src/segpaste/types/dense_sample.py
def active_modalities(self) -> set[Modality]:
    """Return the set of active modalities for this sample."""
    gated = (
        (Modality.INSTANCE, self.instance_masks),
        (Modality.SEMANTIC, self.semantic_map),
        (Modality.PANOPTIC, self.panoptic_map),
        (Modality.DEPTH, self.depth),
        (Modality.NORMALS, self.normals),
    )
    # IMAGE is unconditionally active; the rest gate on their field being set.
    return {Modality.IMAGE} | {m for m, value in gated if value is not None}

to_dict()

Round-trippable dict representation. Omits None fields.

Source code in src/segpaste/types/dense_sample.py
def to_dict(self) -> dict[str, Any]:
    """Round-trippable dict representation. Omits ``None`` fields."""
    out: dict[str, Any] = {}
    for f in fields(self):
        value = getattr(self, f.name)
        if value is not None:
            out[f.name] = value
    return out

FixedSizeCrop

Bases: Transform

Source code in src/segpaste/augmentation/lsj.py
class FixedSizeCrop(Transform):
    """Crop inputs to a fixed ``(output_height, output_width)`` canvas.

    Inputs smaller than the output size end up with bottom/right pad bands,
    which are rewritten with a type-appropriate fill value (see
    :meth:`_pad_value_for`).
    """

    def __init__(
        self,
        output_height: int,
        output_width: int,
        img_pad_value: float | int = 0,
        seg_pad_value: int = 255,
    ) -> None:
        """Crops the given image to a fixed size.

        Args:
            output_height (int): Desired output height.
            output_width (int): Desired output width.
            img_pad_value (float | int): Fill value written into image/video
                padding bands. Defaults to 0.
            seg_pad_value (int): Fill value written into segmentation-style
                mask padding bands. Defaults to 255.
        """
        super().__init__()
        self.output_height = output_height
        self.output_width = output_width

        self.img_pad_value = img_pad_value
        self.seg_pad_value = seg_pad_value

    def make_params(
        self, flat_inputs: list[tv_tensors.TVTensor | torch.Tensor]
    ) -> dict[str, int]:
        """Draw a random top-left crop offset within the input bounds."""
        inpt_h, inpt_w = query_size(flat_inputs)

        # random.randint already returns an int, so wrapping it in round()
        # (as the previous revision did) was a no-op.
        offset_top = random.randint(0, max(0, inpt_h - self.output_height))
        offset_left = random.randint(0, max(0, inpt_w - self.output_width))

        return {"offset_top": offset_top, "offset_left": offset_left}

    def transform(
        self, inpt: tv_tensors.TVTensor | torch.Tensor, params: dict[str, int]
    ) -> Any:
        """Crop ``inpt`` and overwrite any pad band with the proper fill."""
        h, w = F.get_size(inpt)
        cropped = self._call_kernel(
            F.crop,
            inpt,
            top=params["offset_top"],
            left=params["offset_left"],
            height=self.output_height,
            width=self.output_width,
        )
        pad_value = self._pad_value_for(cropped)
        if pad_value is not None:
            # When the input was smaller than the output size, F.crop pads
            # the excess; rewrite those bands with the configured fill value.
            if h < self.output_height:
                cropped[..., h:, :] = pad_value
            if w < self.output_width:
                cropped[..., :, w:] = pad_value

        return cropped

    def _pad_value_for(
        self, cropped: tv_tensors.TVTensor | torch.Tensor
    ) -> bool | int | float | None:
        """Return the pad fill for ``cropped``'s type, or None to leave it."""
        if isinstance(cropped, tv_tensors.Image | tv_tensors.Video):
            return self.img_pad_value
        if isinstance(cropped, PaddingMask):
            return self.seg_pad_value
        if isinstance(cropped, tv_tensors.Mask):
            # Bool per-instance masks must pad False; 255 coerces to True and
            # would absorb the whole pad band into every instance.
            return False if cropped.dtype == torch.bool else self.seg_pad_value
        return None

__init__(output_height, output_width, img_pad_value=0, seg_pad_value=255)

Crops the given image to a fixed size.

Parameters:

Name Type Description Default
output_height int

Desired output height.

required
output_width int

Desired output width.

required
Source code in src/segpaste/augmentation/lsj.py
def __init__(
    self,
    output_height: int,
    output_width: int,
    img_pad_value: float | int = 0,
    seg_pad_value: int = 255,
) -> None:
    """Crops the given image to a fixed size.

    Args:
        output_height (int): Desired output height.
        output_width (int): Desired output width.
        img_pad_value (float | int): Fill value written into image/video
            padding bands. Defaults to 0.
        seg_pad_value (int): Fill value written into segmentation-style
            mask padding bands. Defaults to 255.
    """
    super().__init__()
    self.output_height = output_height
    self.output_width = output_width

    self.img_pad_value = img_pad_value
    self.seg_pad_value = seg_pad_value

InstanceBank

Bases: Protocol

Read-only sequence of class-labeled instance crops.

Concrete backends (MemmapBank, LMDBBank, WebDatasetBank) live under :mod:segpaste._internal.bank until promotion. The Protocol is the only public name; users construct backends via the scripts/build_instance_bank.py CLI (PR5) or import the backend class directly from segpaste._internal.bank.

Implementations must be safe to call from DataLoader workers — i.e. re-entrant after __init__ and free of un-pickle-able state — so num_workers > 0 is supported.

Source code in src/segpaste/_internal/bank/protocol.py
@runtime_checkable
class InstanceBank(Protocol):
    """Read-only sequence of class-labeled instance crops.

    Concrete backends (``MemmapBank``, ``LMDBBank``, ``WebDatasetBank``)
    live under :mod:`segpaste._internal.bank` until promotion. The
    Protocol is the only public name; users construct backends via the
    ``scripts/build_instance_bank.py`` CLI (PR5) or import the backend
    class directly from ``segpaste._internal.bank``.

    Implementations must be safe to call from DataLoader workers — i.e.
    re-entrant after ``__init__`` and free of un-pickle-able state — so
    ``num_workers > 0`` is supported.
    """

    # Sequence protocol: len() plus random access by crop index.
    def __len__(self) -> int: ...
    def __getitem__(self, idx: int) -> BankCrop: ...

    @property
    def class_frequencies(self) -> torch.Tensor:
        """``int64 [num_classes]`` count of crops per class (zero-indexed)."""
        ...

    @property
    def crop_class_ids(self) -> torch.Tensor:
        """``int64 [N]`` class id per crop. Zero-copy on memmap backends;
        loaded once at open. Lets :class:`BankSampler` build per-crop
        weights without an O(N) pass through ``__getitem__``."""
        ...

    @property
    def crop_size(self) -> tuple[int, int]:
        """``(h, w)`` after preprocessing — the same for every crop."""
        ...

    @property
    def has_embeddings(self) -> bool:
        """Whether ``BankCrop.embedding`` is populated for every crop."""
        ...

    @property
    def version(self) -> str:
        """Stable ``{format}@{sha256[:12]}`` identifier for cache keys."""
        ...

class_frequencies property

int64 [num_classes] count of crops per class (zero-indexed).

crop_class_ids property

int64 [N] class id per crop. Zero-copy on memmap backends; loaded once at open. Lets :class:BankSampler build per-crop weights without an O(N) pass through __getitem__.

crop_size property

(h, w) after preprocessing — the same for every crop.

has_embeddings property

Whether BankCrop.embedding is populated for every crop.

version property

Stable {format}@{sha256[:12]} identifier for cache keys.

InstanceMask

Bases: Mask

Per-instance binary masks. Shape [N, H, W], dtype bool.

Source code in src/segpaste/types/dense_sample.py
class InstanceMask(Mask):
    """Per-instance binary masks. Shape [N, H, W], dtype bool."""

    if TYPE_CHECKING:

        # Narrows the static return type of Mask.__new__ to InstanceMask;
        # no runtime effect (guarded by TYPE_CHECKING).
        def __new__(cls, data: Any, **kwargs: Any) -> "InstanceMask": ...

IntraBatchSource

Bases: Module

v0.3.0-equivalent source: sample sources from the same batch.

Wraps :class:BatchedPlacementSampler and returns target itself as the source view. The diagonal-masked multinomial inside the sampler guarantees source_idx[i] != i for B > 1. Default constructor matches v0.3.0 defaults; pass a :class:BatchedPlacementConfig for non-default placement parameters.

Source code in src/segpaste/augmentation/source.py
class IntraBatchSource(nn.Module):
    """v0.3.0-equivalent source: draw paste sources from the current batch.

    Thin wrapper over :class:`BatchedPlacementSampler`: the ``target`` batch
    doubles as the source view, and the sampler's diagonal-masked multinomial
    keeps ``source_idx[i] != i`` whenever ``B > 1``. With no config the
    v0.3.0 defaults apply; pass a :class:`BatchedPlacementConfig` to change
    the placement parameters.
    """

    def __init__(self, config: BatchedPlacementConfig | None = None) -> None:
        super().__init__()
        self.placement_sampler = BatchedPlacementSampler(config)

    def sample(
        self,
        target: PaddedBatchedDenseSample,
        valid_extent: Tensor | None,
        source_eligible: Tensor | None,
        generator: torch.Generator | None,
    ) -> tuple[PaddedBatchedDenseSample, BatchedPlacement]:
        sampled = self.placement_sampler(
            target,
            generator,
            valid_extent=valid_extent,
            source_eligible=source_eligible,
        )
        # The batch itself serves as the source view.
        return target, sampled

Modality

Bases: Enum

Dense-sample modalities. IMAGE is always active; the others gate fields.

Source code in src/segpaste/types/dense_sample.py
class Modality(Enum):
    """Dense-sample modalities. IMAGE is always active; the others gate fields."""

    IMAGE = "image"
    INSTANCE = "instance"  # gates instance_masks / instance_ids
    PANOPTIC = "panoptic"  # gates panoptic_map
    SEMANTIC = "semantic"  # gates semantic_map
    DEPTH = "depth"  # gates depth / depth_valid
    NORMALS = "normals"  # gates normals

PaddedBatchedDenseSample dataclass

Fully-rectangular batched container for graph-compilable augmentation.

Source code in src/segpaste/types/padded_batched_dense_sample.py
@dataclass(frozen=True, slots=True)
class PaddedBatchedDenseSample:
    """Fully-rectangular batched container for graph-compilable augmentation."""

    images: tv_tensors.Image  # [B, C, H, W]
    boxes: torch.Tensor  # [B, K, 4] float, xyxy; invalid rows zeroed.
    labels: torch.Tensor  # [B, K] int64; invalid rows zeroed.
    instance_valid: torch.Tensor  # [B, K] bool
    max_instances: int
    instance_masks: torch.Tensor | None = None  # [B, K, H, W] bool
    instance_ids: torch.Tensor | None = None  # [B, K] int32
    semantic_maps: SemanticMap | None = None  # [B, H, W] int64
    panoptic_maps: PanopticMap | None = None  # [B, H, W] int64
    depth: torch.Tensor | None = None  # [B, 1, H, W] float32
    depth_valid: torch.Tensor | None = None  # [B, 1, H, W] bool
    normals: torch.Tensor | None = None  # [B, 3, H, W] float32
    padding_mask: PaddingMask | None = None  # [B, 1, H, W] bool
    camera_intrinsics: torch.Tensor | None = None  # [B, 4] float32 (fx, fy, cx, cy)

    # NOTE(review): skip_if_compiling presumably elides these eager-only
    # checks under graph compilation — confirm against the decorator.
    @skip_if_compiling
    def __post_init__(self) -> None:
        """Validate rectangular shape/dtype invariants; raises ValueError."""
        b = self.images.size(0)
        k = self.max_instances
        if self.images.ndim != 4:
            raise ValueError("images must have rank 4 [B, C, H, W]")
        h, w = self.images.shape[-2:]

        if self.boxes.shape != (b, k, 4):
            raise ValueError("boxes must have shape [B, K, 4]")
        if self.labels.shape != (b, k):
            raise ValueError("labels must have shape [B, K]")
        if self.instance_valid.shape != (b, k):
            raise ValueError("instance_valid must have shape [B, K]")
        if self.instance_valid.dtype != torch.bool:
            raise ValueError("instance_valid dtype must be bool")

        # Instance masks and ids are co-optional and must agree in shape.
        if (self.instance_masks is None) ^ (self.instance_ids is None):
            raise ValueError(
                "instance_masks and instance_ids must both be set or both None"
            )
        if self.instance_masks is not None and self.instance_ids is not None:
            if self.instance_masks.shape != (b, k, h, w):
                raise ValueError("instance_masks must have shape [B, K, H, W]")
            if self.instance_masks.dtype != torch.bool:
                raise ValueError("instance_masks dtype must be bool")
            if self.instance_ids.shape != (b, k):
                raise ValueError("instance_ids must have shape [B, K]")
            if self.instance_ids.dtype != torch.int32:
                raise ValueError("instance_ids dtype must be int32")

        if (self.depth is None) ^ (self.depth_valid is None):
            raise ValueError("depth and depth_valid must both be set or both None")

        # (field name, tensor, expected rank) for uniformly-stacked dense fields.
        stacked_shape_checks = (
            ("semantic_maps", self.semantic_maps, 3),
            ("panoptic_maps", self.panoptic_maps, 3),
            ("depth", self.depth, 4),
            ("depth_valid", self.depth_valid, 4),
            ("normals", self.normals, 4),
            ("padding_mask", self.padding_mask, 4),
        )
        for name, tensor, expected_rank in stacked_shape_checks:
            if tensor is None:
                continue
            if tensor.ndim != expected_rank:
                raise ValueError(f"{name} must have rank {expected_rank}")
            if tensor.size(0) != b:
                raise ValueError(f"{name} must have batch dim {b}")
            if tensor.shape[-2:] != (h, w):
                raise ValueError(f"{name} must share (H, W) with images")

        if self.camera_intrinsics is not None and self.camera_intrinsics.shape != (
            b,
            4,
        ):
            raise ValueError("camera_intrinsics must have shape [B, 4]")

    @property
    def batch_size(self) -> int:
        """Number of samples ``B`` in the batch."""
        return self.images.size(0)

PaddingMask

Bases: Mask

Unlike tv_tensor.Mask, PaddingMask is not associated with any object.

It is used to indicate padded parts of an Image. Unlike tv_tensor.Mask, it is forwarded unchanged by this package's reimplementation of SanitizeBoundingBoxes.

Source code in src/segpaste/types/data_structures.py
class PaddingMask(Mask):
    """Unlike tv_tensor.Mask, PaddingMask is not associated with any object.

    It is used to indicate padded parts of an Image. Unlike tv_tensor.Mask,
    it is forwarded unchanged by this package's reimplementation of
    SanitizeBoundingBoxes.
    """

    @classmethod
    def from_tensor(cls, data: torch.Tensor) -> "PaddingMask":
        """Wrap a bool tensor as :class:`PaddingMask` with the static type preserved.

        ``Mask.__new__`` is annotated to return ``Mask``; this factory exists
        purely to recover ``PaddingMask`` typing without smearing ``cast`` at
        every call site.
        """
        return cast(PaddingMask, cls(data))

from_tensor(data) classmethod

Wrap a bool tensor as :class:PaddingMask with the static type preserved.

Mask.__new__ is annotated to return Mask; this factory exists purely to recover PaddingMask typing without smearing cast at every call site.

Source code in src/segpaste/types/data_structures.py
@classmethod
def from_tensor(cls, data: torch.Tensor) -> "PaddingMask":
    """Wrap a bool tensor as :class:`PaddingMask`, keeping the static type.

    ``Mask.__new__`` is typed as returning ``Mask``; this factory recovers
    the ``PaddingMask`` annotation so callers need no ``cast`` of their own.
    """
    wrapped = cls(data)
    return cast(PaddingMask, wrapped)

PanopticMap

Bases: Mask

Per-pixel panoptic id encoding. Shape [H, W], dtype int64.

Source code in src/segpaste/types/dense_sample.py
class PanopticMap(Mask):
    """Per-pixel panoptic id encoding. Shape [H, W], dtype int64."""

    if TYPE_CHECKING:

        # Narrows the static return type of Mask.__new__ to PanopticMap;
        # no runtime effect (guarded by TYPE_CHECKING).
        def __new__(cls, data: Any, **kwargs: Any) -> "PanopticMap": ...

PanopticSchema

Bases: Protocol

Panoptic class taxonomy, passed explicitly at composite construction.

Source code in src/segpaste/types/dense_sample.py
@runtime_checkable
class PanopticSchema(Protocol):
    """Panoptic class taxonomy, passed explicitly at composite construction."""

    # class id -> whether that class is a countable "thing" or amorphous "stuff"
    classes: Mapping[int, Literal["thing", "stuff"]]
    # label value reserved for ignored pixels
    ignore_index: int
    # upper bound on instances a single image may carry
    max_instances_per_image: int

PresetConfig

Bases: _FrozenStrict

A registered dataset preset (ADR-0009 §3).

Field additions are allowed (additive-only per ADR-0001 Part (i)); renames or removals are breaking.

Source code in src/segpaste/presets/_base.py
class PresetConfig(_FrozenStrict):
    """A registered dataset preset (ADR-0009 §3).

    Field additions are allowed (additive-only per ADR-0001 Part (i));
    renames or removals are breaking.
    """

    # NOTE(review): _FrozenStrict is presumably a frozen, extra-forbidding
    # pydantic base — confirm against segpaste.presets._base.

    name: str
    """Stable identifier; matches the registry key."""

    description: str
    """One-paragraph human-readable rationale."""

    batch_copy_paste: BatchCopyPasteConfig = Field(default_factory=BatchCopyPasteConfig)
    """The augmentation hyperparameters this preset pins."""

    target_modalities: tuple[Modality, ...]
    """Dense-sample modalities this preset expects to see."""

    sign_off: SignOff | None = None
    """Audit trail for the local sign-off ritual (ADR-0009 §5)."""

batch_copy_paste = Field(default_factory=BatchCopyPasteConfig) class-attribute instance-attribute

The augmentation hyperparameters this preset pins.

description instance-attribute

One-paragraph human-readable rationale.

name instance-attribute

Stable identifier; matches the registry key.

sign_off = None class-attribute instance-attribute

Audit trail for the local sign-off ritual (ADR-0009 §5).

target_modalities instance-attribute

Dense-sample modalities this preset expects to see.

RandomResize

Bases: Transform

Source code in src/segpaste/augmentation/lsj.py
class RandomResize(Transform):
    """Aspect-ratio-preserving random resize toward a scaled target box."""

    def __init__(
        self,
        min_scale: float,
        max_scale: float,
        target_height: int,
        target_width: int,
    ) -> None:
        """Randomly resize the input while keeping its aspect ratio.

        A scale factor is drawn uniformly from ``[min_scale, max_scale]``
        per call; the input is resized to the largest size that fits inside
        the target box scaled by that factor.
        """

        super().__init__()
        self.target_height = target_height
        self.target_width = target_width
        self.min_scale = min_scale
        self.max_scale = max_scale

    def make_params(
        self,
        flat_inputs: list[tv_tensors.TVTensor | torch.Tensor],  # noqa: ARG002
    ) -> dict[str, float]:
        # One factor per call, shared by every tensor in the sample so image
        # and annotations stay geometrically aligned.
        return {"scale": random.uniform(self.min_scale, self.max_scale)}

    def transform(
        self, inpt: tv_tensors.TVTensor | torch.Tensor, params: dict[str, float]
    ) -> Any:
        height, width = F.get_size(inpt)
        scale: float = params["scale"]

        # Largest uniform factor that fits the input inside the scaled target
        # box — this is what preserves the input's aspect ratio.
        fit = min(
            self.target_height * scale / height,
            self.target_width * scale / width,
        )
        return self._call_kernel(
            F.resize, inpt, [round(height * fit), round(width * fit)]
        )

__init__(min_scale, max_scale, target_height, target_width)

Randomly resize the input image while preserving aspect ratio.

The final size is obtained by scaling the target height and width with a random factor.

Source code in src/segpaste/augmentation/lsj.py
def __init__(
    self,
    min_scale: float,
    max_scale: float,
    target_height: int,
    target_width: int,
) -> None:
    """Randomly resize the input image while preserving aspect ratio.

    The final size is obtained by scaling the target height and width with a random
    factor drawn uniformly from ``[min_scale, max_scale]`` at transform time
    (see ``make_params``); the input is fit inside that scaled box.
    """

    super().__init__()
    # Stored verbatim; all resizing math happens later in ``transform``.
    self.min_scale = min_scale
    self.max_scale = max_scale
    self.target_height = target_height
    self.target_width = target_width

SanitizeBoundingBoxes

Bases: SanitizeBoundingBoxes

Source code in src/segpaste/augmentation/lsj.py
class SanitizeBoundingBoxes(tv_SanitizeBoundingBoxes):
    """Sanitizer variant that forwards :class:`PaddingMask` inputs untouched."""

    def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
        """Unlike the original SanitizeBoundingBoxes, this transform can also handle
        PaddingMask and will forward them unchanged."""

        labels = params["labels"]
        # Identity check (``is``): only the tensors registered as labels are
        # filtered positionally alongside the surviving boxes.
        if labels is not None and any(inpt is entry for entry in labels):
            return inpt[params["valid"]]

        # PaddingMask is deliberately passed through even though it is a Mask
        # subclass — its geometry must not be row-filtered.
        if isinstance(inpt, PaddingMask):
            return inpt
        if isinstance(inpt, tv_tensors.BoundingBoxes | tv_tensors.Mask):
            return tv_tensors.wrap(inpt[params["valid"]], like=inpt)
        return inpt

transform(inpt, params)

Unlike the original SanitizeBoundingBoxes, this transform can also handle PaddingMask inputs and will forward them unchanged.

Source code in src/segpaste/augmentation/lsj.py
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
    """Unlike the original SanitizeBoundingBoxes, this transform can also handle
    PaddingMask and will forward them unchanged."""

    # Identity check (``is``, not ``==``): only tensors registered as labels
    # are filtered positionally with the surviving boxes.
    is_label = params["labels"] is not None and any(
        inpt is label for label in params["labels"]
    )
    if is_label:
        return inpt[params["valid"]]

    # PaddingMask is excluded here on purpose so it falls through unfiltered,
    # even though it is a Mask subclass.
    is_bounding_boxes_or_mask = isinstance(
        inpt, tv_tensors.BoundingBoxes | tv_tensors.Mask
    ) and not isinstance(inpt, PaddingMask)
    if not is_bounding_boxes_or_mask:
        return inpt

    return tv_tensors.wrap(inpt[params["valid"]], like=inpt)

SemanticMap

Bases: Mask

Per-pixel semantic class ids. Shape [H, W], dtype int64, ignore = 255.

Source code in src/segpaste/types/dense_sample.py
class SemanticMap(Mask):
    """Per-pixel semantic class ids. Shape [H, W], dtype int64, ignore = 255."""

    if TYPE_CHECKING:
        # Type-checker-only override: constructing a SemanticMap is typed as
        # returning SemanticMap rather than the Mask base class. No runtime
        # effect — the guard keeps this out of the real class body.
        def __new__(cls, data: Any, **kwargs: Any) -> "SemanticMap": ...

SourceStrategy

Bases: Protocol

Picks the source view and per-target placement for one forward step.

Implementations may be nn.Module subclasses (so child modules and buffers register correctly) or plain callables — :func:runtime_checkable structural typing only requires sample. The return contract is fixed: source_view row-aligned with target along the batch dim, and placement.source_idx indexing into source_view.

Source code in src/segpaste/augmentation/source.py
@runtime_checkable
class SourceStrategy(Protocol):
    """Picks the source view and per-target placement for one forward step.

    Implementations may be ``nn.Module`` subclasses (so child modules and
    buffers register correctly) or plain callables — :func:`runtime_checkable`
    structural typing only requires ``sample``. The return contract is fixed:
    ``source_view`` row-aligned with ``target`` along the batch dim, and
    ``placement.source_idx`` indexing into ``source_view``.
    """

    def sample(
        self,
        # Batch receiving pastes; the returned source_view is row-aligned with it.
        target: PaddedBatchedDenseSample,
        # Per-row valid (unpadded) extent — presumably [B, 2] (H, W); TODO confirm.
        valid_extent: Tensor | None,
        # Optional eligibility mask restricting which instances may be pasted —
        # exact semantics are implementation-defined; verify against callers.
        source_eligible: Tensor | None,
        # RNG for reproducible draws; None presumably falls back to the default
        # torch generator — confirm in implementations.
        generator: torch.Generator | None,
    ) -> tuple[PaddedBatchedDenseSample, BatchedPlacement]: ...

create_coco_dataloader(image_folder, label_path, transforms, batch_size=4, collate_fn=_identity_collate)

Create a COCO DataLoader preconfigured for segpaste pipelines.

Parameters:

Name Type Description Default
image_folder str

Directory containing the COCO image files.

required
label_path str

Path to the COCO JSON annotations file.

required
transforms Transform

Transform applied to each sample.

required
batch_size int

Batch size for the returned DataLoader.

4
collate_fn Any

Collate function; defaults to an identity collate that yields list[DenseSample] — wrap the result through :meth:BatchedDenseSample.from_samples.to_padded(K) → :class:BatchCopyPaste to apply augmentation.

_identity_collate

Returns:

Type Description
DataLoader[DenseSample]

A DataLoader yielding :class:DenseSample instances.

Source code in src/segpaste/integrations/coco.py
def create_coco_dataloader(
    image_folder: str,
    label_path: str,
    transforms: v2.Transform,
    batch_size: int = 4,
    collate_fn: Any = _identity_collate,
    *,
    shuffle: bool = False,
) -> torch.utils.data.DataLoader[DenseSample]:
    """Create a COCO DataLoader preconfigured for segpaste pipelines.

    Args:
        image_folder (str): Directory containing the COCO image files.
        label_path (str): Path to the COCO JSON annotations file.
        transforms (v2.Transform): Transform applied to each sample.
        batch_size (int): Batch size for the returned DataLoader.
        collate_fn: Collate function; defaults to an identity collate that
            yields ``list[DenseSample]`` — wrap the result through
            :meth:`BatchedDenseSample.from_samples` → ``.to_padded(K)`` →
            :class:`BatchCopyPaste` to apply augmentation.
        shuffle (bool): Whether the DataLoader reshuffles the dataset every
            epoch. Keyword-only; defaults to ``False`` to preserve the
            previous (deterministic-order) behavior.

    Returns:
        A DataLoader yielding :class:`DenseSample` instances.
    """

    dataset = CocoDetectionV2(
        image_folder=image_folder,
        label_path=label_path,
        transforms=transforms,
    )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

get_preset(name)

Return the registered preset for name.

Raises:

Type Description
KeyError

if name is not registered.

Source code in src/segpaste/presets/__init__.py
def get_preset(name: str) -> PresetConfig:
    """Look up and return the preset registered under *name*.

    Raises:
        KeyError: if *name* is not registered.
    """
    # EAFP: attempt the lookup, re-raising with a helpful message on miss.
    try:
        return _REGISTRY[name]
    except KeyError:
        raise KeyError(f"unknown preset {name!r}; registered: {list_presets()}") from None

list_presets()

Sorted tuple of registered preset names. The result is a value snapshot.

Source code in src/segpaste/presets/__init__.py
def list_presets() -> tuple[str, ...]:
    """Registered preset names, sorted; returned as an immutable snapshot."""
    names = sorted(_REGISTRY.keys())
    return tuple(names)

make_large_scale_jittering(output_size, min_scale=0.1, max_scale=2.0, img_pad_value=0, seg_pad_value=255)

Factory function to create a LargeScaleJittering transform.

Parameters:

Name Type Description Default
output_size int or tuple

The desired output size (height, width) of the crop.

required
min_scale float

The minimum scale factor for resizing.

0.1
max_scale float

The maximum scale factor for resizing.

2.0
img_pad_value float or int

Fill value for image padding.

0
seg_pad_value int

Fill value for segmentation mask padding.

255

Returns:

Type Description
Transform

A Compose transform implementing Large Scale Jittering.

Source code in src/segpaste/augmentation/lsj.py
def make_large_scale_jittering(
    output_size: int | tuple[int, int],
    min_scale: float = 0.1,
    max_scale: float = 2.0,
    img_pad_value: float | int = 0,
    seg_pad_value: int = 255,
) -> Transform:
    """Build a Large Scale Jittering pipeline: random resize, then fixed crop.

    Args:
        output_size (int or tuple): Desired output size (height, width) of the
            crop; a single ``int`` is used for both dimensions.
        min_scale (float): The minimum scale factor for resizing.
        max_scale (float): The maximum scale factor for resizing.
        img_pad_value (float or int): Fill value for image padding.
        seg_pad_value (int): Fill value for segmentation mask padding.

    Returns:
        A Compose transform implementing Large Scale Jittering.
    """
    # Normalize a scalar size into a square (height, width) pair.
    height, width = (
        (output_size, output_size) if isinstance(output_size, int) else output_size
    )

    resize_step = RandomResize(
        min_scale=min_scale,
        max_scale=max_scale,
        target_height=height,
        target_width=width,
    )
    crop_step = FixedSizeCrop(
        output_height=height,
        output_width=width,
        img_pad_value=img_pad_value,
        seg_pad_value=seg_pad_value,
    )
    return Compose([resize_step, crop_step])

register_preset(config)

Register config under config.name.

Raises:

Type Description
ValueError

if config.name is already registered.

Source code in src/segpaste/presets/__init__.py
def register_preset(config: PresetConfig) -> None:
    """Add *config* to the registry, keyed by ``config.name``.

    Raises:
        ValueError: if ``config.name`` is already registered.
    """
    key = config.name
    # Registration is write-once: duplicates are a programming error.
    if key in _REGISTRY:
        raise ValueError(f"preset {config.name!r} is already registered")
    _REGISTRY[key] = config