import html
from collections.abc import Mapping, Sequence

import numpy as np
import ipywidgets
import brambox as bb
import pandas as pd
from brambox.util._visual import setup_boxes

from ._viewer import Viewer
from .._util import cast_alpha, box_to_coords, mask_to_coords

try:
    import torch
except ImportError:
    torch = None
# Only the TorchViewer widget is part of this module's public (``from ... import *``) API.
__all__ = ['TorchViewer']
class TorchViewer(Viewer):
    """
    This widget can visualize a PyTorch dataset as bounding boxes drawn on top of the images. |br|
    Its arguments work a lot like :class:`ibb.BramboxViewer`.

    Args:
        data (torch.utils.data.Dataset):
            PyTorch dataset that should return images and bounding boxes (see Note).
        extract_data (callable):
            Function that receives ``data[index]`` and returns an ``(image, dataframe)`` tuple;
            Default (First Tensor and DataFrame in Sequence/Mapping)
        label (pandas.Series or callable):
            Label to write above the boxes; Default **class_label (confidence)**
        color (pandas.Series or callable):
            Color to use for drawing; Default **every class_label will get its own color, up to 10 labels**
        size (pandas.Series or callable):
            Thickness of the border of the bounding boxes; Default **3**
        alpha (pandas.Series or callable):
            Alpha fill value of the bounding boxes; Default **0**
        **kwargs (dict):
            Extra keyword arguments that will be passed to :class:`~ibb.widgets.Viewer`

    Note:
        The `label`, `color`, `size` and `alpha` arguments can also be tacked on to the `boxes` dataframe as columns.
        They can also be a single value, which will then be used for each bounding box. |br|
        Basically, as long as you can assign the value as a new column to the dataframe, it will work.
        Finally, if these values are callable, they get called with the boxes dataframe and should return a valid pandas series.

    Note:
        Tensor images are moved to the CPU and permuted with ``permute(1, 2, 0)`` (first dimension moved to the end)
        before being displayed, so they are presumably expected in channel-first layout.
    """

    # Columns that are not listed in the info pane (drawing metadata and geometry helpers).
    # The box coordinates are excluded from the sorted part and appended at the end of the table instead.
    _INFO_HIDDEN_COLUMNS = (
        'image', 'color', 'size', 'label', 'alpha', 'fill', 'points',
        'x_top_left', 'y_top_left', 'width', 'height',
        'boxcoords', 'maskcoords',
    )
    _INFO_COORD_COLUMNS = ['x_top_left', 'y_top_left', 'width', 'height']

    def __init__(self, data, extract_data=None, label=True, color=None, size=3, alpha=0, **kwargs):
        assert torch is not None, 'PyTorch is required for this widget'
        self.data = data
        self.extract_data = extract_data if callable(extract_data) else default_extract_data
        self.columns = (label, color, size, alpha)

        # Metadata
        self.info = False                       # whether the info pane is visible
        self.clicked = None                     # row (pandas.Series) of the currently selected box, if any
        self.draw_box_max = 2                   # number of draw modes to cycle through (2: none/box, 3: none/box/mask)
        self.draw_box_text = ['none', 'box', 'mask']
        self.draw_box = self.draw_box_max - 1   # current draw mode (index into draw_box_text), starts at 'box'

        # Example dataframe for setup.
        # draw_box_max is still 2 during this call, so the example boxes get no 'maskcoords' column.
        _, _, boxes = self.get_data(0)
        kwargs['_example_boxes'] = boxes
        self.draw_box_max = 3 if 'segmentation' in boxes.columns else 2

        # ImageCanvas arguments
        if 'hover_style' not in kwargs:
            kwargs['hover_style'] = {'alpha': .5}
        if 'click_style' not in kwargs:
            max_size = boxes['size'].max()
            if pd.isna(max_size):
                # The first item has no boxes, so there is no size to derive the click style from:
                # use the scalar `size` argument if one was given, otherwise the documented default of 3.
                max_size = size if isinstance(size, (int, float)) else 3
            kwargs['click_style'] = {'size': max_size + 2}

        # Widget init
        if 'total' not in kwargs:
            kwargs['total'] = len(self.data)
        super().__init__(**kwargs)

        # Setup handlers
        self.main[0].observe(self.on_click, 'clicked')
        self.main[0].observe(self.on_poly, 'polygons')

    def __init_header__(self, kwargs):
        """ Append save, draw-mode toggle and info-pane toggle buttons to the parent's header widgets. """
        w_btn_save = ipywidgets.Button(icon='picture-o', tooltip='save image')
        w_btn_save.add_class('ibb-square-button')
        w_btn_save.on_click(self.on_save)

        w_btn_box = ipywidgets.Button(
            icon='square-o',
            tooltip=f'toggle none/box/mask [{self.draw_box_text[self.draw_box]}]',
        )
        w_btn_box.add_class('ibb-square-button')
        w_btn_box.on_click(self.on_box)

        w_btn_info = ipywidgets.Button(icon='bars', tooltip='toggle info pane')
        w_btn_info.add_class('ibb-square-button')
        w_btn_info.on_click(self.on_info)

        return [*super().__init_header__(kwargs), ipywidgets.HBox([w_btn_save, w_btn_box, w_btn_info])]

    def __init_main__(self, kwargs):
        """ Append an HTML info bar (hidden unless ``self.info`` is set) to the parent's main widgets. """
        w_info_bar = ipywidgets.HTML(placeholder='info')
        w_info_bar.add_class('ibb-infobar')
        if not self.info:
            w_info_bar.add_class('ibb-hide')

        return [*super().__init_main__(kwargs), w_info_bar]

    def __init_side__(self, kwargs):
        """
        Create a vertical confidence-threshold slider,
        but only when the example boxes (``kwargs['_example_boxes']``) have a ``confidence`` column.
        """
        self.conf_enabled = 'confidence' in kwargs['_example_boxes']
        if not self.conf_enabled:
            return []

        w_conf_slider = ipywidgets.FloatSlider(
            value=0, min=0, max=1, step=0.01,
            orientation='vertical',
            continuous_update=False,
            readout=True,
            readout_format='.0%',
            tooltip='confidence threshold to filter objects',
        )
        w_conf_slider.add_class('ibb-conf-slider')
        w_conf_slider.observe(self.on_threshold, 'value')

        return [w_conf_slider]

    def get_data(self, index):
        """
        Load dataset item ``index`` and prepare it for display.

        Args:
            index (int): Index into ``self.data``.

        Returns:
            tuple: ``(label, image, boxes)`` where

            - label is the single distinct value of the ``image`` column, or ``''`` when there is no such column
              or it does not hold exactly one distinct value;
            - image is a ``numpy.ndarray`` (Tensor images are moved to the CPU and permuted with ``(1, 2, 0)``);
            - boxes is the dataframe returned by ``setup_boxes``, with ``color`` turned into an ``'rgb...'`` string,
              ``alpha`` passed through ``cast_alpha`` and a ``boxcoords`` column added
              (plus ``maskcoords`` when ``self.draw_box_max == 3``).
        """
        img, boxes = self.extract_data(self.data[index])

        # Image setup
        if isinstance(img, torch.Tensor):
            img = img.cpu().permute(1, 2, 0).contiguous()
        img = np.asarray(img)

        # Dataframe setup: callable arguments are evaluated on the raw boxes of this item
        label, color, size, alpha = (value(boxes) if callable(value) else value for value in self.columns)
        boxes = setup_boxes(boxes, label=label, color=color, size=size, alpha=alpha)
        boxes['color'] = 'rgb' + boxes['color'].astype(str)
        boxes['alpha'] = boxes['alpha'].apply(cast_alpha)
        boxes['boxcoords'] = boxes.apply(box_to_coords, axis=1)
        if self.draw_box_max == 3:
            boxes['maskcoords'] = boxes.apply(mask_to_coords, axis=1)

        # Label: the image name, only if all boxes belong to exactly one image
        lbl = ''
        if 'image' in boxes.columns:
            img_names = boxes['image'].unique()
            if len(img_names) == 1:
                lbl = img_names[0]

        return lbl, img, boxes

    def draw_boxes(self, boxes):
        """
        Send the boxes to the canvas as polygons (or clear it when the draw mode is 'none').

        Args:
            boxes (pandas.DataFrame): Boxes as returned by :meth:`get_data`, possibly filtered.
        """
        if not self.draw_box:
            self.main[0].polygons = None
            return

        polygons = boxes[['color', 'size', 'alpha']].copy()
        coord_col = 'maskcoords' if self.draw_box == 2 else 'boxcoords'
        polygons['coords'] = boxes[coord_col].apply(lambda c: c.tolist())
        polygons['label'] = self._box_labels(boxes)
        self.main[0].polygons = polygons.to_dict('records')

    def _box_labels(self, boxes):
        """ Return the text drawn above each box, as a pandas.Series aligned with ``boxes``. """
        if self.columns[0] is not True and 'label' in boxes.columns:
            # A custom `label` argument was given; get_data passed it to setup_boxes,
            # which presumably stored it in the 'label' column.
            return boxes['label']

        # Default: "class_label (confidence)". astype(str) so categorical class labels can be concatenated.
        labels = boxes['class_label'].astype(str)
        if 'confidence' in boxes.columns:
            labels = labels + boxes['confidence'].apply(lambda num: f' ({num:.2%})')
        return labels

    def on_index(self, change):
        """
        Show dataset item ``change['new']``: update the header text and canvas image,
        filter the boxes with the confidence slider (when enabled) and draw them.
        """
        self.header[0].value, self.main[0].image, self.current_all_boxes = self.get_data(change['new'])
        if self.conf_enabled:
            self.current_boxes = self.current_all_boxes[self.current_all_boxes['confidence'] >= self.side[0].value]
        else:
            self.current_boxes = self.current_all_boxes

        self.draw_boxes(self.current_boxes.copy())

    def on_save(self, btn):
        """ Ask the canvas to save the current image. """
        self.main[0].save = True

    def on_box(self, btn):
        """ Cycle to the next draw mode (none/box[/mask]), update the button tooltip and redraw. """
        self.draw_box = (self.draw_box + 1) % self.draw_box_max
        btn.tooltip = f'toggle none/box/mask [{self.draw_box_text[self.draw_box]}]'
        self.redraw()

    def on_info(self, btn):
        """ Toggle the visibility of the info pane. """
        self.info = not self.info
        if self.info:
            self.main[-1].remove_class('ibb-hide')
        else:
            self.main[-1].add_class('ibb-hide')

    def on_click(self, change):
        """
        Show the attributes of the clicked box as an HTML table in the info pane. |br|
        ``change['new']`` is the position of the clicked box in ``self.current_boxes``, or ``None`` to clear it.
        """
        position = change['new']
        if position is None:
            self.main[-1].value = ''
            return

        self.clicked = self.current_boxes.iloc[position]
        columns = sorted(self.clicked.index.difference(list(self._INFO_HIDDEN_COLUMNS))) + self._INFO_COORD_COLUMNS

        rows = []
        for col in columns:
            value = self.clicked[col]
            if col == 'segmentation':
                # Geometries are summarized as "<type name> (<coordinate count - 1>)" instead of printed in full
                coords = value.exterior.coords if hasattr(value, 'exterior') else value.coords
                text = f'{type(value).__name__} ({len(coords) - 1})'
            else:
                text = f'{value}'
            # Escape, so values such as '<unk>' are shown literally instead of being parsed as HTML
            rows.append(f'<tr><td>{html.escape(str(col))}</td><td>{html.escape(text)}</td></tr>')

        self.main[-1].value = '<table>' + ''.join(rows) + '</table>'

    def on_poly(self, change):
        """
        Try to keep a box selected after the polygons changed (new item, draw mode or threshold). |br|
        Setting ``self.main[0].clicked`` also triggers :meth:`on_click`, which refreshes the info pane.
        """
        if self.clicked is None or change['new'] is None:
            return

        boxes = self.current_boxes

        # Option 1 : Same object (keep clicked when toggling box/mask)
        # NOTE(review): if every dataset item's dataframe has its own index (e.g. starting at 0),
        # this also matches an unrelated box with the same index label after changing item - confirm intended.
        index = self.clicked.name
        if index in boxes.index:
            self.main[0].clicked = boxes.index.get_loc(index)
            self.clicked = boxes.loc[index]
            return

        # Option 2 : Object with same class and id (useful in tracking context); needs an 'id' column
        if 'id' in boxes.columns and 'id' in self.clicked.index:
            same_object = (
                (boxes['class_label'] == self.clicked['class_label'])
                & (boxes['id'] == self.clicked['id'])
            )
            if same_object.any():
                position = int(same_object.values.argmax())
                self.main[0].clicked = position
                self.clicked = boxes.iloc[position]
                return

        # Default: Reset clicked
        self.clicked = None

    def on_threshold(self, change):
        """ Redraw when the confidence threshold slider changes. """
        self.redraw()
def default_extract_data(output):
    """
    Default ``extract_data`` function of :class:`TorchViewer`. |br|
    Extract the image and annotation dataframe from the output of ``dataset[index]``.

    Args:
        output (torch.Tensor, Sequence or Mapping):
            Output of a dataset item. For a Mapping, its values are searched.

    Returns:
        tuple: ``(image, annotations)``. If ``output`` is a single Tensor, it is returned as-is together with an
        empty brambox annotation dataframe (``bb.util.new('anno')``). Otherwise ``image`` is a clone of the first
        ``torch.Tensor`` and ``annotations`` a copy of the first ``pandas.DataFrame`` found, falling back to
        an empty brambox annotation dataframe when there is none.

    Raises:
        TypeError: If ``output`` is not a Tensor, Sequence or Mapping, or if it contains no Tensor.
    """
    if isinstance(output, torch.Tensor):
        return output, bb.util.new('anno')
    if not isinstance(output, (Sequence, Mapping)):
        raise TypeError(f'Unknown Dataset output: {type(output)}')

    # Search a separate name, so error messages still report the caller's type (e.g. dict, not dict_values)
    items = output.values() if isinstance(output, Mapping) else output

    img = next((o for o in items if isinstance(o, torch.Tensor)), None)
    if img is None:
        raise TypeError(f'Could not find Tensor in output: {type(output)}')

    anno = next((o for o in items if isinstance(o, pd.DataFrame)), None)
    anno = bb.util.new('anno') if anno is None else anno.copy()

    return img.clone(), anno