# Source code for atlas_gui.utils.config

from collections.abc import Mapping

import yaml

from atlas_gui.datasets.dataset import DatasetBase
from atlas_gui.datasets.reassemble import Reassemble
from atlas_gui.datasets.rlds import RLDS
from atlas_gui.datasets.rosbag_ds import Rosbag
from atlas_gui.datasets.frames import Frames
from atlas_gui.datasets.video import Video

def load_config(path="config.yaml"):
    """
    Load a YAML configuration file.

    The file is opened in binary mode so PyYAML detects the stream encoding
    itself (UTF-8 unless a UTF-16 byte order mark is present). Text mode would
    use the platform locale encoding instead, which can mis-decode non-ASCII
    configuration values on some systems.

    Args:
        path (str): Path to the configuration file. Defaults to 'config.yaml'.

    Returns:
        dict: Parsed YAML configuration as a Python dictionary. Note that
        ``yaml.safe_load`` returns ``None`` for an empty file, and otherwise
        returns whatever top-level node the document contains.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        yaml.YAMLError: If the file content is not valid YAML.
    """
    # safe_load (not load) so a config file cannot construct arbitrary
    # Python objects.
    with open(path, "rb") as f:
        return yaml.safe_load(f)
def get_nested(data_dict, key_path):
    """
    Walk a nested dictionary along a slash-separated key path.

    ``get_nested(d, 'robot_state/joint_efforts')`` is equivalent to
    ``d['robot_state']['joint_efforts']``.

    Args:
        data_dict (dict): Root container to index into.
        key_path (str): Keys joined with '/', outermost first.

    Returns:
        Any: The value stored under the final key of the path.

    Raises:
        KeyError: If a key along the path is missing from a dict level
            (propagated from subscripting; no default is substituted).
    """
    node = data_dict
    for part in key_path.split('/'):
        node = node[part]
    return node
def has_nested_key(data, path):
    """
    Check whether a slash-separated key path exists in a nested mapping.

    Every level visited along the path must be a mapping that contains the
    next key. Any ``collections.abc.Mapping`` qualifies, not just ``dict``
    (e.g. ``types.MappingProxyType``). This keeps the check consistent with
    ``get_nested``, which subscripts whatever container it is given.

    Args:
        data (Mapping): The nested mapping to search.
        path (str): Keys joined with '/', outermost first.

    Returns:
        bool: True if every key along the path exists, False otherwise.
    """
    node = data
    for key in path.split('/'):
        # A non-mapping intermediate (scalar, list, tensor, ...) cannot hold
        # the next key, so the path does not exist.
        if not isinstance(node, Mapping) or key not in node:
            return False
        node = node[key]
    return True
def get_nested_np(data_dict, key_path):
    """
    Fetch a value from a nested dictionary and return its ``.numpy()`` form.

    Intended for RLDS datasets, whose leaf values are TensorFlow tensors that
    must be converted before use. The leaf value must provide a ``numpy()``
    method. For TensorFlow eager tensors this yields a NumPy array, or a
    NumPy scalar for 0-d tensors.

    Args:
        data_dict (dict): Root dictionary to index into.
        key_path (str): Keys joined with '/', outermost first.

    Returns:
        np.ndarray: Result of calling ``.numpy()`` on the nested value.

    Raises:
        KeyError: If a key along the path is missing from a dict level.
    """
    # str.split always yields at least one element, so ``last`` is defined.
    *parents, last = key_path.split('/')
    container = data_dict
    for segment in parents:
        container = container[segment]
    return container[last].numpy()
def create_dataset(dataset_type: str, config) -> DatasetBase:
    """
    Factory method that instantiates the dataset class for ``dataset_type``.

    Matching is case-insensitive. Supported types and the class each maps to:

    - ``'reassemble'`` -> ``Reassemble``
    - ``'rlds'`` -> ``RLDS``
    - ``'frames'`` -> ``Frames``
    - ``'video'`` -> ``Video``
    - ``'rosbag'`` -> ``Rosbag``

    Args:
        dataset_type (str): Name of the dataset type (see list above).
        config (dict): Configuration dictionary, passed unchanged to the
            selected dataset class constructor.

    Returns:
        DatasetBase: A new instance of the selected dataset class.

    Raises:
        ValueError: If ``dataset_type`` is not one of the supported types.
    """
    dataset_type = dataset_type.lower()
    # Dispatch table: the single place to register a new dataset backend.
    dataset_classes = {
        "reassemble": Reassemble,
        "rlds": RLDS,
        "frames": Frames,
        "video": Video,
        "rosbag": Rosbag,
    }
    dataset_cls = dataset_classes.get(dataset_type)
    if dataset_cls is None:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")
    return dataset_cls(config)