Source code for atlas_gui.datasets.video

from atlas_gui.datasets.dataset import DatasetBase
import os
import json
import numpy as np
import cv2


class Video(DatasetBase):
    """
    Dataset handler for video files (.mp4, .avi, .mkv).

    Supports three structures (auto-detected):

    - Single file: a direct file path or folder with one video
      (1 segment, 1 camera)
    - Folder of videos: each video file = one segment
      (N segments, 1 camera)
    - Multi-camera: folder of subfolders, each containing multiple video
      files (each subfolder = one segment, each video = one camera,
      basename = camera key)

    The ``config`` mapping is read for ``dataset_name``, ``annotation_dir``,
    ``fps`` and (optionally) ``camera_keys``. When ``camera_keys`` is missing
    or empty it is filled in on the passed-in ``config`` during structure
    detection.
    """

    VALID_EXTENSIONS = {'.mp4', '.avi', '.mkv'}

    def __init__(self, config):
        """Store the configuration and make sure the annotation dir exists.

        Args:
            config: Mapping providing at least ``dataset_name`` and
                ``annotation_dir``.
        """
        super().__init__()
        self.config = config
        self.dataset_name = config['dataset_name']
        self.annotation_dir = config['annotation_dir']
        self.current_segment_idx = 0
        self.structure = None  # 'single_file', 'folder', or 'multicam'
        self.segments = {}  # idx -> {'cameras': {cam_key: video_path}}
        os.makedirs(self.annotation_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def _is_video(self, filename):
        """Return True if *filename* has one of ``VALID_EXTENSIONS``.

        The comparison is case-insensitive.
        """
        return os.path.splitext(filename)[1].lower() in self.VALID_EXTENSIONS

    def _list_videos(self, directory):
        """Return the sorted names of the video files directly in *directory*."""
        return [
            name
            for name in sorted(os.listdir(directory))
            if self._is_video(name)
            and os.path.isfile(os.path.join(directory, name))
        ]

    def _single_camera_key(self):
        """Return the camera key used by the single-camera layouts.

        Uses the first configured key; when none is configured, falls back
        to ``'camera'`` and records it in ``self.config['camera_keys']``.
        """
        configured = self.config.get('camera_keys')
        if configured:
            return configured[0]
        self.config['camera_keys'] = ['camera']
        return 'camera'

    def _annotations_path(self):
        """Return the path of the JSON file holding this dataset's annotations."""
        return os.path.join(
            self.annotation_dir, f"{self.dataset_name}_annotations.json"
        )

    @staticmethod
    def _to_jsonable(obj):
        """Recursively convert NumPy values in *obj* to plain Python values.

        NumPy scalars become Python scalars, arrays become nested lists, and
        dicts / lists / tuples are converted element by element (tuples come
        back as lists, which is how JSON would store them anyway).
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {
                (key.item() if isinstance(key, np.generic) else key):
                    Video._to_jsonable(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [Video._to_jsonable(item) for item in obj]
        return obj

    def _decode_video(self, video_path):
        """Decode an entire video file into a numpy array of frames.

        Frames are read until the capture reports no more data and are then
        stacked along a new leading axis.

        Raises:
            ValueError: If no frame could be read from *video_path*.
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            # Release the capture even if reading raised part-way through.
            cap.release()
        if not frames:
            raise ValueError(f"Could not read any frames from {video_path}")
        return np.stack(frames)

    def _get_video_frame_count(self, video_path):
        """Get number of frames without decoding the entire video.

        The value comes from the container metadata
        (``cv2.CAP_PROP_FRAME_COUNT``); negative values are clamped to 0 so
        the result can always be treated as a count.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        return max(count, 0)

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def load_data(self, file_path):
        """Detect the dataset layout at *file_path* and index its segments.

        Args:
            file_path: Path to a video file or to a directory of videos /
                per-segment subfolders.
        """
        self.file_path = file_path
        self._detect_structure(file_path)
        self.load_segments_info(file_path)

    def _detect_structure(self, path):
        """Auto-detect the dataset structure and populate self.segments.

        Detection order:

        1. *path* is a file -> ``'single_file'``.
        2. At least one subfolder contains video files -> ``'multicam'``;
           subfolders without any video (e.g. an annotation directory) are
           skipped.
        3. *path* directly contains video files -> ``'folder'``.

        Raises:
            ValueError: If no video could be found in any of these layouts.
        """
        # Direct file path
        if os.path.isfile(path):
            self.structure = 'single_file'
            self.segments = {0: {'cameras': {self._single_camera_key(): path}}}
            return

        # Directory
        entries = sorted(os.listdir(path))
        video_files = [
            e for e in entries
            if os.path.isfile(os.path.join(path, e)) and self._is_video(e)
        ]
        subdirs = [e for e in entries if os.path.isdir(os.path.join(path, e))]

        # One {camera_key: video_path} mapping per subfolder that actually
        # holds videos. Skipping empty subfolders avoids segments with no
        # camera, which would break load_segments_info later on.
        multicam_segments = []
        for subdir in subdirs:
            sub_path = os.path.join(path, subdir)
            cameras = {
                os.path.splitext(vf)[0]: os.path.join(sub_path, vf)
                for vf in self._list_videos(sub_path)
            }
            if cameras:
                multicam_segments.append(cameras)

        if multicam_segments:
            # Multi-camera: subfolders with video files
            self.structure = 'multicam'
            if not self.config.get('camera_keys'):
                # Auto-detect camera keys from the first subfolder that
                # contains videos.
                self.config['camera_keys'] = list(multicam_segments[0])
            self.segments = {
                i: {'cameras': cameras}
                for i, cameras in enumerate(multicam_segments)
            }
        elif video_files:
            # Folder of videos: each video = one segment
            self.structure = 'folder'
            cam_key = self._single_camera_key()
            self.segments = {
                i: {'cameras': {cam_key: os.path.join(path, vf)}}
                for i, vf in enumerate(video_files)
            }
        else:
            raise ValueError(
                f"Could not detect dataset structure in '{path}'. "
                "Expected video files or subfolders containing videos."
            )

    def load_segments_info(self, file_path):
        """Build ``self.segments_info`` from the detected segments.

        Each entry is keyed by the segment index as a string and holds
        ``index``, ``start`` (always 0.0), ``end`` (timestamp of the last
        frame, computed from the first camera's frame count and
        ``config['fps']``), ``text`` (subfolder name for multicam, otherwise
        the video's base name) and ``uid``.

        Args:
            file_path: Unused; kept for interface compatibility.

        Raises:
            ValueError: If a segment has no video attached.
        """
        self.segments_info = {}
        for idx, seg_data in self.segments.items():
            video_path = next(iter(seg_data['cameras'].values()), None)
            if video_path is None:
                raise ValueError(f"Segment {idx} does not contain any video")

            n_frames = self._get_video_frame_count(video_path)
            duration = (
                (n_frames - 1) / self.config['fps'] if n_frames > 1 else 0.0
            )

            if self.structure == 'multicam':
                text = os.path.basename(os.path.dirname(video_path))
            else:
                text = os.path.splitext(os.path.basename(video_path))[0]

            self.segments_info[str(idx)] = {
                'index': idx,
                'start': 0.0,
                'end': duration,
                'text': text,
                'uid': f"segment_{idx}",
            }

    def get_segment(self, segment_idx):
        """Decode and return every configured camera of one segment.

        Also records *segment_idx* as the current segment and stores the
        result on ``self.data``.

        Args:
            segment_idx: Key into ``self.segments``.

        Returns:
            Dict mapping camera key -> array of decoded frames, for each key
            in ``config['camera_keys']`` that exists in the segment.
        """
        self.current_segment_idx = segment_idx
        cameras = self.segments[segment_idx]['cameras']
        result = {
            cam_key: self._decode_video(cameras[cam_key])
            for cam_key in self.config['camera_keys']
            if cam_key in cameras
        }
        self.data = result
        return result

    def get_max_timestamp(self):
        """Return the ``end`` timestamp of the current segment."""
        return self.segments_info[str(self.current_segment_idx)]['end']

    # ------------------------------------------------------------------
    # Annotations
    # ------------------------------------------------------------------

    def write_annot_data(self, segment_idx, annots):
        """Store *annots* for a segment in the dataset's annotation file.

        Existing annotations of other segments are preserved; the entry for
        this segment's uid is replaced.

        Args:
            segment_idx: Index of the segment the annotations belong to.
            annots: Annotation data; NumPy scalars and arrays anywhere inside
                it are converted to plain Python values before saving.
        """
        annotations_path = self._annotations_path()

        if os.path.exists(annotations_path):
            with open(annotations_path, "r", encoding="utf-8") as f:
                all_annotations = json.load(f)
        else:
            all_annotations = {}

        uid = self.segments_info[str(segment_idx)]['uid']
        all_annotations[uid] = self._to_jsonable(annots)

        # Serialise before opening the file for writing: if the data cannot
        # be encoded, the existing annotation file must not be truncated.
        payload = json.dumps(all_annotations, indent=2)
        with open(annotations_path, "w", encoding="utf-8") as f:
            f.write(payload)

    def load_annot_data(self, segment_idx):
        """Return the saved annotations of a segment.

        Returns:
            The stored annotations, or an empty dict when the annotation
            file does not exist or has no entry for this segment.
        """
        uid = self.segments_info[str(segment_idx)]['uid']
        annotations_path = self._annotations_path()

        if not os.path.exists(annotations_path):
            return {}

        with open(annotations_path, "r", encoding="utf-8") as f:
            all_annotations = json.load(f)

        return all_annotations.get(uid, {})