# Source code for atlas_gui.datasets.rlds

from atlas_gui.datasets.dataset import DatasetBase
import tensorflow_datasets as tfds
import os
import json
import numpy as np
from itertools import islice

class RLDS(DatasetBase):
    """
    Dataset handler for RLDS (Reinforcement Learning Datasets) format datasets.

    This class loads RLDS datasets via TensorFlow Datasets, and supports
    segment access, annotation I/O, and automatic structuring of episode
    information. Each segment corresponds to an entire episode.
    """

    def __init__(self, config, split="train"):
        """
        Initialize the RLDS dataset.

        Args:
            config (dict): Configuration dictionary. Must contain keys:
                - 'dataset_name': str, used to identify the dataset
                - 'annotation_dir': str, path to store annotation JSONs
                - 'fps': int or float, used to compute episode duration
                Optional keys:
                - 'download': bool, if True use tfds.load() to auto-download
                - 'data_dir': str, directory for downloaded datasets
                - 'text_keys': list of step keys; the first one is read from
                  the first step of each episode as the segment text
            split (str): Dataset split to load, e.g., 'train'.
                Defaults to 'train'.
        """
        super().__init__()
        self.config = config
        self.dataset_name = self.config['dataset_name']
        self.split = split
        self.dataset = None
        self.download_mode = self.config.get('download', False)
        self.data_dir = os.path.expanduser(self.config.get('data_dir', '~/tensorflow_datasets'))
        self._iterator = None
        # Index of the episode the next `next(self._iterator)` will yield.
        # Tracked separately from `current_segment_idx` so that the first
        # access after loading never mistakes "nothing consumed yet" for
        # "episode 0 already consumed".
        self._next_idx = 0
        self._current_episode = None
        self.current_segment_idx = 0
        self.annotation_dir = self.config['annotation_dir']
        os.makedirs(self.annotation_dir, exist_ok=True)

    def load_data(self, file_path=None):
        """
        Load the RLDS dataset and build the per-episode segment table.

        Args:
            file_path (str, optional): Path to the TFDS-formatted dataset
                directory. Required if download=False, ignored if
                download=True.

        Raises:
            ValueError: If download mode is off and no file_path is given.
        """
        if self.download_mode:
            # Auto-download mode: use tfds.load()
            print(f"Loading dataset '{self.dataset_name}' from tfds (data_dir={self.data_dir})...")
            self.file_path = self.data_dir
            self.dataset = tfds.load(
                self.dataset_name,
                split=self.split,
                data_dir=self.data_dir,
                shuffle_files=False,
            )
        else:
            # Local mode: use builder_from_directory()
            if file_path is None:
                raise ValueError("file_path required when download=False")
            self.file_path = file_path
            builder = tfds.builder_from_directory(builder_dir=file_path)
            # Honour the split requested at construction time, exactly as the
            # download branch does (it defaults to 'train').
            self.dataset = builder.as_dataset(split=self.split, shuffle_files=False)

        # load_segments_info() walks the dataset with its own iteration, so
        # the access iterator is created only afterwards.
        self.load_segments_info(file_path=self.file_path)
        self._reset_iterator()

    def _reset_iterator(self):
        """Restart episode iteration from the first episode."""
        self._iterator = iter(self.dataset)
        self._next_idx = 0

    def _fetch_episode(self, segment_idx):
        """Advance the episode iterator to `segment_idx` and return that episode."""
        # The iterator only moves forward: restart when it is missing or has
        # already passed the requested episode.
        if self._iterator is None or segment_idx < self._next_idx:
            self._reset_iterator()

        skip = segment_idx - self._next_idx
        try:
            episode = next(islice(self._iterator, skip, skip + 1))
        except StopIteration:
            # The iterator is exhausted; drop it so the next call restarts.
            self._iterator = None
            raise IndexError(f"Segment index {segment_idx} out of range.") from None

        self._next_idx = segment_idx + 1
        return episode

    @staticmethod
    def _stack_field(values, path=""):
        """Stack per-step values into arrays, recursing into nested dicts."""
        sample = values[0]
        if not isinstance(sample, dict):
            return np.stack([value.numpy() for value in values])

        stacked = {}
        for key in sample:
            child_path = f"{path}/{key}" if path else str(key)
            try:
                stacked[key] = RLDS._stack_field([value[key] for value in values], child_path)
            except Exception as e:
                # Best effort: a field that cannot be stacked is reported and
                # omitted instead of failing the whole episode.
                print(f"Warning: could not stack '{child_path}': {e}")
        return stacked

    def get_segment(self, segment_idx):
        """
        Return a specific segment (episode) from the dataset, stacked into
        NumPy arrays.

        Args:
            segment_idx (int): Index of the segment to load.

        Returns:
            dict: ``{"steps": stacked}`` where ``stacked`` mirrors the
            (possibly nested) structure of a step, with every leaf stacked
            across steps (e.g., observations, actions). Fields that could not
            be stacked are omitted; an episode without steps yields an empty
            ``stacked`` dict.

        Raises:
            ValueError: If the dataset has not been loaded.
            IndexError: If segment_idx is negative or past the last episode.
        """
        if self.dataset is None:
            raise ValueError("Dataset not loaded. Call load_data() first.")
        if segment_idx < 0:
            raise IndexError(f"Segment index {segment_idx} out of range.")

        episode = self._fetch_episode(segment_idx)
        self.current_segment_idx = segment_idx

        steps = list(episode['steps'])
        if steps:
            stacked = self._stack_field(steps)
        else:
            print(f"Warning: segment {segment_idx} contains no steps.")
            stacked = {}

        self._current_episode = {"steps": stacked}
        return self._current_episode

    @staticmethod
    def _count_steps(steps):
        """Return the number of steps, counting manually if len() is unsupported."""
        try:
            return len(steps)
        except TypeError:
            # len() is not available for this steps object; count by iterating.
            return sum(1 for _ in steps)

    def _extract_text(self, steps):
        """Read the first configured text key from the first step ('' on failure)."""
        text_keys = self.config.get('text_keys')
        if not text_keys:
            return ""
        try:
            sample_step = next(iter(steps))
            text_value = sample_step[text_keys[0]]
            if hasattr(text_value, 'numpy'):
                text_value = text_value.numpy()
            if isinstance(text_value, bytes):
                return text_value.decode("utf-8")
            return str(text_value)
        except Exception:
            # Deliberately best effort: a missing or undecodable text must
            # not prevent the segment from being listed.
            return ""

    def _extract_uid(self, episode, idx):
        """Derive a segment UID from episode metadata, else from the index."""
        fallback = f"{self.dataset_name}_episode_{idx}"
        try:
            if "episode_metadata" not in episode:
                return fallback
            metadata = episode["episode_metadata"]
            if "recording_folderpath" in metadata:
                folder_path = metadata["recording_folderpath"].numpy().decode("utf-8")
                return os.path.basename(os.path.dirname(os.path.dirname(folder_path)))
            if "episode_id" in metadata:
                # NOTE(review): if episode_id is a bytes scalar this yields a
                # "b'...'" string. Left as is because UIDs are the keys of
                # already-written annotation files.
                return str(metadata["episode_id"].numpy())
            if "file_path" in metadata:
                return os.path.basename(metadata["file_path"].numpy().decode("utf-8"))
        except Exception:
            pass  # Keep the index-based UID
        return fallback

    def load_segments_info(self, file_path=None):
        """
        Load metadata for each episode in the dataset.

        Creates a dictionary (keyed by the stringified episode index) where
        each segment entry includes:
            - index
            - start time
            - end time (based on FPS and number of steps)
            - language instruction (text)
            - unique ID (derived from episode metadata or index)

        Args:
            file_path (str, optional): Accepted for interface compatibility;
                not used by this implementation.

        Raises:
            ValueError: If the dataset has not been loaded.
        """
        if self.dataset is None:
            raise ValueError("Call load_data() before loading segments info.")

        fps = self.config['fps']
        segments = {}
        for idx, episode in enumerate(self.dataset):
            steps = episode["steps"]
            num_steps = self._count_steps(steps)
            segments[str(idx)] = {
                "index": idx,
                "start": 0.0,
                # Clamp so that an episode without steps does not get a
                # negative end time.
                "end": max(num_steps - 1, 0) / fps,
                "text": self._extract_text(steps),
                "uid": self._extract_uid(episode, idx),
            }
        self.segments_info = segments

    def get_max_timestamp(self):
        """
        Return the end timestamp of the current episode.

        Returns:
            float: The maximum timestamp for the current segment.
        """
        return self.segments_info[str(self.current_segment_idx)]['end']

    def _annotations_path(self):
        """Return the path of this dataset's annotation JSON file."""
        return os.path.join(self.annotation_dir, f"{self.dataset_name}_annotations.json")

    @staticmethod
    def _to_jsonable(obj):
        """Convert NumPy scalars/arrays (also inside dicts and lists) to Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {k: RLDS._to_jsonable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [RLDS._to_jsonable(v) for v in obj]
        return obj

    def write_annot_data(self, segment_idx, annots):
        """
        Write annotation data to a per-dataset JSON file.

        Annotations are stored using the UID of the segment as the key.

        Args:
            segment_idx (int): Index of the segment being annotated.
            annots (dict): Dictionary of annotations (must be JSON
                serializable after NumPy values are converted).
        """
        annotations_path = self._annotations_path()

        # Load existing annotations if the file exists
        if os.path.exists(annotations_path):
            with open(annotations_path, "r", encoding="utf-8") as f:
                all_annotations = json.load(f)
        else:
            all_annotations = {}

        uid = self.segments_info[str(segment_idx)]['uid']
        all_annotations[uid] = self._to_jsonable(annots)

        # Serialize before touching the file: if a value is not serializable
        # the existing annotations must survive untouched.
        payload = json.dumps(all_annotations, indent=2)

        # Write to a sibling temp file and swap it in, so an interrupted
        # write cannot leave a truncated annotations file behind.
        tmp_path = annotations_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, annotations_path)

    def load_annot_data(self, segment_idx):
        """
        Load annotation data for the given segment index, using UID-based
        lookup.

        Args:
            segment_idx (int): Index of the segment to load annotations for.

        Returns:
            dict: Annotation data for the given segment, or an empty dict if
            none found.
        """
        uid = self.segments_info[str(segment_idx)]['uid']
        annotations_path = self._annotations_path()
        if not os.path.exists(annotations_path):
            return {}
        with open(annotations_path, 'r', encoding="utf-8") as f:
            all_annotations = json.load(f)
        return all_annotations.get(uid, {})