
API

Data Classes

The Well provides two main classes, WellDataset and WellDataModule, to handle the raw data stored in .hdf5 files. The WellDataset implements a map-style PyTorch Dataset. The WellDataModule provides dataloaders for training, validation, and testing. The tutorial provides a guide on how to use these classes in a training pipeline.

Dataset

The WellDataset is a map-style dataset. It converts the .hdf5 file structure expected by the Well into torch.Tensor data. It first processes metadata from the .hdf5 attributes to allow for retrieval of individual samples.
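
For illustration, here is a minimal sketch of constructing a dataset and drawing one sample. The base path and dataset name are placeholders, and the top-level import path is assumed from the headings below.

from the_well.data import WellDataset  # import path assumed from the heading below

# Placeholders: point well_base_path at a local copy of the Well data
# and pick any dataset name available there.
dataset = WellDataset(
    well_base_path="/path/to/the_well",
    well_dataset_name="turbulent_radiative_layer_2D",
    well_split_name="train",
    n_steps_input=4,
    n_steps_output=1,
    use_normalization=True,
)

print(len(dataset))  # number of (trajectory, time-window) samples across all files
sample = dataset[0]  # dict of tensors; empty entries are dropped
print(sample["input_fields"].shape)   # Ti x H [x W [x D]] x C
print(sample["output_fields"].shape)  # To x H [x W [x D]] x C

Each sample also contains constant_fields, input_scalars, output_scalars, constant_scalars, boundary_conditions, and grid tensors whenever they are non-empty.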

the_well.data.WellDataset

Bases: Dataset

Generic dataset for any Well data. Returns data in B x T x H [x W [x D]] x C format.

The train/valid/test split is assumed to occur at the folder level.

Takes in a path to a directory of HDF5 files to construct the dataset.

Parameters:

path (Optional[str], default: None): Path to a directory of HDF5 files; one of path or well_base_path + well_dataset_name must be specified.
normalization_path (str, default: '../stats.yaml'): Path to the normalization constants, assumed to be in the same format as the constructed data.
well_base_path (Optional[str], default: None): Path to the Well dataset directory; only used together with well_dataset_name.
well_dataset_name (Optional[str], default: None): Name of the Well dataset to load; overrides path if specified.
well_split_name (str, default: 'train'): Name of the split to load; options are 'train', 'valid', and 'test'.
include_filters (List[str], default: []): Only include files whose name contains at least one of these strings.
exclude_filters (List[str], default: []): Exclude any files whose name contains at least one of these strings.
use_normalization (bool, default: False): Whether to normalize data in the dataset.
max_rollout_steps (int, default: 100): Maximum number of output steps returned in full trajectory mode, mostly for memory reasons.
n_steps_input (int, default: 1): Number of steps to include in each input sample.
n_steps_output (int, default: 1): Number of steps to include in each output sample (y).
min_dt_stride (int, default: 1): Minimum time stride between the steps of a sample.
max_dt_stride (int, default: 1): Maximum time stride between the steps of a sample.
flatten_tensors (bool, default: True): Whether to flatten tensor-valued fields into channels.
cache_small (bool, default: True): Whether to cache small tensors in memory for faster access.
max_cache_size (float, default: 1e9): Maximum numel of a constant tensor to cache.
return_grid (bool, default: True): Whether to return grid coordinates.
boundary_return_type (str, default: 'padding'): How to return boundary conditions; options are 'padding', 'mask', 'exact', and 'none'. Currently only 'padding' is supported.
full_trajectory_mode (bool, default: False): Overrides n_steps_output so that each sample is the full trajectory starting from t0 instead of a fixed-length window, for long-rollout validation.
name_override (Optional[str], default: None): Override the dataset name (used for more precise logging).
transform (Optional[Augmentation], default: None): Transform to apply to the data, of the form f(data: TrajectoryData, metadata: TrajectoryMetadata) -> TrajectoryData, where data contains a piece of trajectory (fields, scalars, BCs, ...) and metadata contains additional information, including the dataset itself.
min_std (float, default: 0.0001): Minimum standard deviation for field normalization. If a field's standard deviation is lower than this value, it is replaced by this value.
storage_options (Optional[Dict], default: None): Options for the fsspec storage.
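
As a further sketch of the filtering and rollout options above (the path, dataset name, and filter string are placeholders), a full-trajectory dataset for long-rollout evaluation could be built as follows.

# Only files whose name contains "seed0" are used (placeholder filter).
rollout_dataset = WellDataset(
    well_base_path="/path/to/the_well",
    well_dataset_name="turbulent_radiative_layer_2D",
    well_split_name="valid",
    include_filters=["seed0"],
    n_steps_input=4,
    full_trajectory_mode=True,  # n_steps_output is overridden to span the trajectory
    max_rollout_steps=60,       # caps the number of output steps actually returned
)
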
Source code in the_well/data/datasets.py
class WellDataset(Dataset):
    """
    Generic dataset for any Well data. Returns data in B x T x H [x W [x D]] x C format.

    Train/Test/Valid is assumed to occur on a folder level.

    Takes in path to directory of HDF5 files to construct dset.

    Args:
        path:
            Path to directory of HDF5 files, one of path or well_base_path+well_dataset_name
            must be specified
        normalization_path:
            Path to normalization constants - assumed to be in same format as constructed data.
        well_base_path:
            Path to well dataset directory, only used with dataset_name
        well_dataset_name:
            Name of well dataset to load - overrides path if specified
        well_split_name:
            Name of split to load - options are 'train', 'valid', 'test'
        include_filters:
            Only include files whose name contains at least one of these strings
        exclude_filters:
            Exclude any files whose name contains at least one of these strings
        use_normalization:
            Whether to normalize data in the dataset
        n_steps_input:
            Number of steps to include in each sample
        n_steps_output:
            Number of steps to include in y
        min_dt_stride:
            Minimum stride between samples
        max_dt_stride:
            Maximum stride between samples
        flatten_tensors:
            Whether to flatten tensor valued field into channels
        cache_small:
            Whether to cache small tensors in memory for faster access
        max_cache_size:
            Maximum numel of constant tensor to cache
        return_grid:
            Whether to return grid coordinates
        boundary_return_type: options=['padding', 'mask', 'exact', 'none']
            How to return boundary conditions. Currently only padding supported.
        full_trajectory_mode:
            Overrides to return full trajectory starting from t0 instead of samples
                for long run validation.
        name_override:
            Override name of dataset (used for more precise logging)
        transform:
            Transform to apply to data. In the form `f(data: TrajectoryData, metadata:
            TrajectoryMetadata) -> TrajectoryData`, where `data` contains a piece of
            trajectory (fields, scalars, BCs, ...) and `metadata` contains additional
            information, including the dataset itself.
        min_std:
            Minimum standard deviation for field normalization. If a field standard
            deviation is lower than this value, it is replaced by this value.
        storage_options:
            Options for the fsspec storage.
    """

    def __init__(
        self,
        path: Optional[str] = None,
        normalization_path: str = "../stats.yaml",
        well_base_path: Optional[str] = None,
        well_dataset_name: Optional[str] = None,
        well_split_name: str = "train",
        include_filters: List[str] = [],
        exclude_filters: List[str] = [],
        use_normalization: bool = False,
        max_rollout_steps=100,
        n_steps_input: int = 1,
        n_steps_output: int = 1,
        min_dt_stride: int = 1,
        max_dt_stride: int = 1,
        flatten_tensors: bool = True,
        cache_small: bool = True,
        max_cache_size: float = 1e9,
        return_grid: bool = True,
        boundary_return_type: str = "padding",
        full_trajectory_mode: bool = False,
        name_override: Optional[str] = None,
        transform: Optional["Augmentation"] = None,
        min_std: float = 1e-4,
        storage_options: Optional[Dict] = None,
    ):
        super().__init__()
        assert path is not None or (
            well_base_path is not None and well_dataset_name is not None
        ), "Must specify path or well_base_path and well_dataset_name"
        if path is not None:
            self.data_path = path
            self.normalization_path = os.path.join(path, normalization_path)

        else:
            assert (
                well_dataset_name in WELL_DATASETS
            ), f"Dataset name {well_dataset_name} not in the expected list {WELL_DATASETS}."
            self.data_path = os.path.join(
                well_base_path, well_dataset_name, "data", well_split_name
            )
            self.normalization_path = os.path.join(
                well_base_path, well_dataset_name, "stats.yaml"
            )

        self.fs, _ = fsspec.url_to_fs(self.data_path, **(storage_options or {}))

        if use_normalization:
            with self.fs.open(self.normalization_path, mode="r") as f:
                stats = yaml.safe_load(f)

            self.means = {
                field: torch.as_tensor(val) for field, val in stats["mean"].items()
            }
            self.stds = {
                field: torch.clip(torch.as_tensor(val), min=min_std)
                for field, val in stats["std"].items()
            }

        # Input checks
        if boundary_return_type is not None and boundary_return_type not in ["padding"]:
            raise NotImplementedError("Only padding boundary conditions supported")
        if not flatten_tensors:
            raise NotImplementedError("Only flattened tensors supported right now")

        # Copy params
        self.well_dataset_name = well_dataset_name
        self.use_normalization = use_normalization
        self.include_filters = include_filters
        self.exclude_filters = exclude_filters
        self.max_rollout_steps = max_rollout_steps
        self.n_steps_input = n_steps_input
        self.n_steps_output = n_steps_output  # Gets overridden by full trajectory mode
        self.min_dt_stride = min_dt_stride
        self.max_dt_stride = max_dt_stride
        self.flatten_tensors = flatten_tensors
        self.return_grid = return_grid
        self.boundary_return_type = boundary_return_type
        self.full_trajectory_mode = full_trajectory_mode
        self.cache_small = cache_small
        self.max_cache_size = max_cache_size
        self.transform = transform
        if self.min_dt_stride < self.max_dt_stride and self.full_trajectory_mode:
            raise ValueError(
                "Full trajectory mode not supported with variable stride lengths"
            )
        # Check the directory has hdf5 that meet our exclusion criteria
        sub_files = self.fs.glob(self.data_path + "/*.h5") + self.fs.glob(
            self.data_path + "/*.hdf5"
        )
        # Check filters - only use file if include_filters are present and exclude_filters are not
        if len(self.include_filters) > 0:
            retain_files = []
            for include_string in self.include_filters:
                retain_files += [f for f in sub_files if include_string in f]
            sub_files = retain_files
        if len(self.exclude_filters) > 0:
            for exclude_string in self.exclude_filters:
                sub_files = [f for f in sub_files if exclude_string not in f]
        assert len(sub_files) > 0, "No HDF5 files found in path {}".format(
            self.data_path
        )
        self.files_paths = sub_files
        self.files_paths.sort()
        self.caches = [{} for _ in self.files_paths]
        # Build multi-index
        self.metadata = self._build_metadata()
        # Override name if necessary for logging
        if name_override is not None:
            self.dataset_name = name_override

    def _build_metadata(self):
        """Builds multi-file indices and checks that folder contains consistent dataset"""
        self.n_files = len(self.files_paths)
        self.n_trajectories_per_file = []
        self.n_steps_per_trajectory = []
        self.n_windows_per_trajectory = []
        self.file_index_offsets = [0]  # Used to track where each file starts
        # Things where we just care every file has same value
        size_tuples = set()
        names = set()
        ndims = set()
        bcs = set()
        lowest_steps = 1e9  # Note - we should never have 1e9 steps
        for index, file in enumerate(self.files_paths):
            with (
                self.fs.open(file, "rb", **IO_PARAMS["fsspec_params"]) as f,
                h5.File(f, "r", **IO_PARAMS["h5py_params"]) as _f,
            ):
                grid_type = _f.attrs["grid_type"]
                # Run sanity checks - all files should have same ndims, size_tuple, and names
                trajectories = int(_f.attrs["n_trajectories"])
                # Number of steps is always last dim of time
                steps = _f["dimensions"]["time"].shape[-1]
                size_tuple = [
                    _f["dimensions"][d].shape[-1]
                    for d in _f["dimensions"].attrs["spatial_dims"]
                ]
                ndims.add(_f.attrs["n_spatial_dims"])
                names.add(_f.attrs["dataset_name"])
                size_tuples.add(tuple(size_tuple))
                # Fast enough that I'd rather check each file rather than processing extra files before checking
                assert len(names) == 1, "Multiple dataset names found in specified path"
                assert len(ndims) == 1, "Multiple ndims found in specified path"
                assert (
                    len(size_tuples) == 1
                ), "Multiple resolutions found in specified path"

                # Track lowest amount of steps in case we need to use full_trajectory_mode
                lowest_steps = min(lowest_steps, steps)

                windows_per_trajectory = raw_steps_to_possible_sample_t0s(
                    steps, self.n_steps_input, self.n_steps_output, self.min_dt_stride
                )
                assert windows_per_trajectory > 0, (
                    f"{steps} steps is not enough steps for file {file}"
                    f" to allow {self.n_steps_input} input and {self.n_steps_output} output steps"
                    f" with a minimum stride of {self.min_dt_stride}"
                )
                self.n_trajectories_per_file.append(trajectories)
                self.n_steps_per_trajectory.append(steps)
                self.n_windows_per_trajectory.append(windows_per_trajectory)
                self.file_index_offsets.append(
                    self.file_index_offsets[-1] + trajectories * windows_per_trajectory
                )
                # Check BCs
                for bc in _f["boundary_conditions"].keys():
                    bcs.add(_f["boundary_conditions"][bc].attrs["bc_type"])

                if index == 0:
                    # Populate scalar names
                    self.scalar_names = []
                    self.constant_scalar_names = []

                    for scalar in _f["scalars"].attrs["field_names"]:
                        if _f["scalars"][scalar].attrs["time_varying"]:
                            self.scalar_names.append(scalar)
                        else:
                            self.constant_scalar_names.append(scalar)

                    # Populate field names
                    self.field_names = {i: [] for i in range(3)}
                    self.constant_field_names = {i: [] for i in range(3)}

                    for i in range(3):
                        ti = f"t{i}_fields"
                        # if _f[ti][field].attrs["symmetric"]:
                        # itertools.combinations_with_replacement
                        ti_field_dims = [
                            "".join(xyz)
                            for xyz in itertools.product(
                                _f["dimensions"].attrs["spatial_dims"],
                                repeat=i,
                            )
                        ]

                        for field in _f[ti].attrs["field_names"]:
                            for dims in ti_field_dims:
                                field_name = f"{field}_{dims}" if dims else field

                                if _f[ti][field].attrs["time_varying"]:
                                    self.field_names[i].append(field_name)
                                else:
                                    self.constant_field_names[i].append(field_name)
        # Full trajectory mode overrides the above and just sets each sample to "full"
        # trajectory where full = min(lowest_steps_per_file, max_rollout_steps)
        if self.full_trajectory_mode:
            self.n_steps_output = (
                lowest_steps // self.min_dt_stride
            ) - self.n_steps_input
            assert self.n_steps_output > 0, (
                f"Full trajectory mode not supported for dataset {names[0]} with {lowest_steps} minimum steps"
                f" and a minimum stride of {self.min_dt_stride} and {self.n_steps_input} input steps"
            )
            self.n_windows_per_trajectory = [1] * self.n_files
            self.n_steps_per_trajectory = [lowest_steps] * self.n_files
            self.file_index_offsets = np.cumsum([0] + self.n_trajectories_per_file)

        # Just to make sure it doesn't put us in file -1
        self.file_index_offsets[0] = -1
        self.files: List[h5.File | None] = [
            None for _ in self.files_paths
        ]  # We open file references as they come
        # Dataset length is last number of samples
        self.len = self.file_index_offsets[-1]
        self.n_spatial_dims = int(ndims.pop())  # Number of spatial dims
        self.size_tuple = tuple(map(int, size_tuples.pop()))  # Size of spatial dims
        self.dataset_name = names.pop()  # Name of dataset
        # BCs
        self.num_bcs = len(bcs)  # Number of boundary condition type included in data
        self.bc_types = list(bcs)  # List of boundary condition types

        return WellMetadata(
            dataset_name=self.dataset_name,
            n_spatial_dims=self.n_spatial_dims,
            grid_type=grid_type,
            spatial_resolution=self.size_tuple,
            scalar_names=self.scalar_names,
            constant_scalar_names=self.constant_scalar_names,
            field_names=self.field_names,
            constant_field_names=self.constant_field_names,
            boundary_condition_types=self.bc_types,
            n_files=self.n_files,
            n_trajectories_per_file=self.n_trajectories_per_file,
            n_steps_per_trajectory=self.n_steps_per_trajectory,
        )

    def _open_file(self, file_ind: int):
        _file = h5.File(
            self.fs.open(
                self.files_paths[file_ind], "rb", **IO_PARAMS["fsspec_params"]
            ),
            "r",
            **IO_PARAMS["h5py_params"],
        )
        self.files[file_ind] = _file

    def _check_cache(self, cache: Dict[str, Any], name: str, data: Any):
        if self.cache_small and data.numel() < self.max_cache_size:
            cache[name] = data

    def _pad_axes(
        self,
        field_data: Any,
        use_dims,
        time_varying: bool = False,
        tensor_order: int = 0,
    ):
        """Repeats data over axes not used in storage"""
        # Look at which dimensions currently are not used and tile based on their sizes
        expand_dims = (1,) if time_varying else ()
        expand_dims = expand_dims + tuple(
            [
                self.size_tuple[i] if not use_dim else 1
                for i, use_dim in enumerate(use_dims)
            ]
        )
        expand_dims = expand_dims + (1,) * tensor_order
        return torch.tile(field_data, expand_dims)

    def _reconstruct_fields(self, file, cache, sample_idx, time_idx, n_steps, dt):
        """Reconstruct space fields starting at index sample_idx, time_idx, with
        n_steps and dt stride."""
        variable_fields = {0: {}, 1: {}, 2: {}}
        constant_fields = {0: {}, 1: {}, 2: {}}
        # Iterate through field types and apply appropriate transforms to stack them
        for i, order_fields in enumerate(["t0_fields", "t1_fields", "t2_fields"]):
            field_names = file[order_fields].attrs["field_names"]
            for field_name in field_names:
                field = file[order_fields][field_name]
                use_dims = field.attrs["dim_varying"]
                # If the field is in the cache, use it, otherwise go through read/pad
                if field_name in cache:
                    field_data = cache[field_name]
                else:
                    field_data = field
                    # Index is built gradually since there can be different numbers of leading fields
                    multi_index = ()
                    if field.attrs["sample_varying"]:
                        multi_index = multi_index + (sample_idx,)
                    if field.attrs["time_varying"]:
                        multi_index = multi_index + (
                            slice(time_idx, time_idx + n_steps * dt, dt),
                        )
                    field_data = field_data[multi_index]
                    field_data = torch.as_tensor(field_data)
                    # Normalize
                    if self.use_normalization:
                        if field_name in self.means:
                            field_data = field_data - self.means[field_name]
                        if field_name in self.stds:
                            field_data = field_data / self.stds[field_name]
                    # If constant, try to cache
                    if (
                        not field.attrs["time_varying"]
                        and not field.attrs["sample_varying"]
                    ):
                        self._check_cache(cache, field_name, field_data)

                # Expand dims
                field_data = self._pad_axes(
                    field_data,
                    use_dims,
                    time_varying=field.attrs["time_varying"],
                    tensor_order=i,
                )

                if field.attrs["time_varying"]:
                    variable_fields[i][field_name] = field_data
                else:
                    constant_fields[i][field_name] = field_data

        return (variable_fields, constant_fields)

    def _reconstruct_scalars(self, file, cache, sample_idx, time_idx, n_steps, dt):
        """Reconstruct scalar values (not fields) starting at index sample_idx, time_idx, with
        n_steps and dt stride."""
        variable_scalars = {}
        constant_scalars = {}
        for scalar_name in file["scalars"].attrs["field_names"]:
            scalar = file["scalars"][scalar_name]

            if scalar_name in cache:
                scalar_data = cache[scalar_name]
            else:
                scalar_data = scalar
                # Build index gradually to account for different leading dims
                multi_index = ()
                if scalar.attrs["sample_varying"]:
                    multi_index = multi_index + (sample_idx,)
                if scalar.attrs["time_varying"]:
                    multi_index = multi_index + (
                        slice(time_idx, time_idx + n_steps * dt, dt),
                    )
                scalar_data = scalar_data[multi_index]
                scalar_data = torch.as_tensor(scalar_data)
                # If constant, try to cache
                if (
                    not scalar.attrs["time_varying"]
                    and not scalar.attrs["sample_varying"]
                ):
                    self._check_cache(cache, scalar_name, scalar_data)

            if scalar.attrs["time_varying"]:
                variable_scalars[scalar_name] = scalar_data
            else:
                constant_scalars[scalar_name] = scalar_data

        return (variable_scalars, constant_scalars)

    def _reconstruct_grids(self, file, cache, sample_idx, time_idx, n_steps, dt):
        """Reconstruct grid values starting at index sample_idx, time_idx, with
        n_steps and dt stride."""
        # Time
        if "time_grid" in cache:
            time_grid = cache["time_grid"]
        elif file["dimensions"]["time"].attrs["sample_varying"]:
            time_grid = torch.tensor(file["dimensions"]["time"][sample_idx, :])
        else:
            time_grid = torch.tensor(file["dimensions"]["time"][:])
            self._check_cache(cache, "time_grid", time_grid)
        # We have already sampled leading index if it existed so timegrid should be 1D
        time_grid = time_grid[time_idx : time_idx + n_steps * dt : dt]
        # Nothing should depend on absolute time - might change if we add weather
        time_grid = time_grid - time_grid.min()

        # Space - TODO - support time-varying grids or non-tensor product grids
        if "space_grid" in cache:
            space_grid = cache["space_grid"]
        else:
            space_grid = []
            sample_invariant = True
            for dim in file["dimensions"].attrs["spatial_dims"]:
                if file["dimensions"][dim].attrs["sample_varying"]:
                    sample_invariant = False
                    coords = torch.tensor(file["dimensions"][dim][sample_idx])
                else:
                    coords = torch.tensor(file["dimensions"][dim][:])
                space_grid.append(coords)
            space_grid = torch.stack(torch.meshgrid(*space_grid, indexing="ij"), -1)
            if sample_invariant:
                self._check_cache(cache, "space_grid", space_grid)
        return space_grid, time_grid

    def _padding_bcs(self, file, cache, sample_idx, time_idx, n_steps, dt):
        """Handles BC case where BC corresponds to a specific padding type

        Note/TODO - currently assumes boundaries to be axis-aligned and cover the entire
        domain. This is a simplification that will need to be addressed in the future.
        """
        if "boundary_output" in cache:
            boundary_output = cache["boundary_output"]
        else:
            bcs = file["boundary_conditions"]
            dim_indices = {
                dim: i for i, dim in enumerate(file["dimensions"].attrs["spatial_dims"])
            }
            boundary_output = torch.zeros(self.n_spatial_dims, 2)
            for bc_name in bcs.keys():
                bc = bcs[bc_name]
                bc_type = bc.attrs["bc_type"].upper()  # Enum is in upper case
                if len(bc.attrs["associated_dims"]) > 1:
                    raise NotImplementedError(
                        "Only axis-aligned boundaries supported for now. If your code is not using BCs, consider setting `boundary_return_type` to None."
                    )
                dim = bc.attrs["associated_dims"][0]
                mask = bc["mask"]
                if mask[0]:
                    boundary_output[dim_indices[dim]][0] = BoundaryCondition[
                        bc_type
                    ].value
                if mask[-1]:
                    boundary_output[dim_indices[dim]][1] = BoundaryCondition[
                        bc_type
                    ].value
            self._check_cache(cache, "boundary_output", boundary_output)
        return boundary_output

    def _reconstruct_bcs(self, file, cache, sample_idx, time_idx, n_steps, dt):
        """Needs work to support arbitrary BCs.

        Currently supports finite set of boundary condition types that describe
        the geometry of the domain. Implements these as mask channels. The total
        number of channels is determined by the number of BC types in the
        data.

        #TODO generalize boundary types
        """
        if self.boundary_return_type == "padding":
            return self._padding_bcs(file, cache, sample_idx, time_idx, n_steps, dt)
        else:
            raise NotImplementedError()

    def __getitem__(self, index):
        # Find specific file and local index
        file_idx = int(
            np.searchsorted(self.file_index_offsets, index, side="right") - 1
        )  # which file we are on
        windows_per_trajectory = self.n_windows_per_trajectory[file_idx]
        local_idx = index - max(
            self.file_index_offsets[file_idx], 0
        )  # First offset is -1
        sample_idx = local_idx // windows_per_trajectory
        time_idx = local_idx % windows_per_trajectory
        # open hdf5 file (and cache the open object)
        if self.files[file_idx] is None:
            self._open_file(file_idx)

        # If we gave a stride range, decide the largest size we can use given the sample location
        dt = self.min_dt_stride
        if self.max_dt_stride > self.min_dt_stride:
            effective_max_dt = maximum_stride_for_initial_index(
                time_idx,
                self.n_steps_per_trajectory[file_idx],
                self.n_steps_input,
                self.n_steps_output,
            )
            effective_max_dt = min(effective_max_dt, self.max_dt_stride)
            if effective_max_dt > self.min_dt_stride:
                # Randint is non-inclusive on the upper bound
                dt = np.random.randint(self.min_dt_stride, effective_max_dt + 1)
        # Fetch the data
        data = {}

        output_steps = min(self.n_steps_output, self.max_rollout_steps)
        data["variable_fields"], data["constant_fields"] = self._reconstruct_fields(
            self.files[file_idx],
            self.caches[file_idx],
            sample_idx,
            time_idx,
            self.n_steps_input + output_steps,
            dt,
        )
        data["variable_scalars"], data["constant_scalars"] = self._reconstruct_scalars(
            self.files[file_idx],
            self.caches[file_idx],
            sample_idx,
            time_idx,
            self.n_steps_input + output_steps,
            dt,
        )

        if self.boundary_return_type is not None:
            data["boundary_conditions"] = self._reconstruct_bcs(
                self.files[file_idx],
                self.caches[file_idx],
                sample_idx,
                time_idx,
                self.n_steps_input + output_steps,
                dt,
            )

        if self.return_grid:
            data["space_grid"], data["time_grid"] = self._reconstruct_grids(
                self.files[file_idx],
                self.caches[file_idx],
                sample_idx,
                time_idx,
                self.n_steps_input + output_steps,
                dt,
            )

        # Data transformation/augmentation
        if self.transform is not None:
            data = self.transform(
                cast(TrajectoryData, data),
                TrajectoryMetadata(
                    dataset=self,
                    file_idx=file_idx,
                    sample_idx=sample_idx,
                    time_idx=time_idx,
                    time_stride=dt,
                ),
            )

        # Concatenate fields and scalars
        for key in ("variable_fields", "constant_fields"):
            data[key] = [
                field.unsqueeze(-1).flatten(-order - 1)
                for order, fields in data[key].items()
                for _, field in fields.items()
            ]

            if data[key]:
                data[key] = torch.concatenate(data[key], dim=-1)
            else:
                data[key] = torch.tensor([])

        for key in ("variable_scalars", "constant_scalars"):
            data[key] = [scalar.unsqueeze(-1) for _, scalar in data[key].items()]

            if data[key]:
                data[key] = torch.concatenate(data[key], dim=-1)
            else:
                data[key] = torch.tensor([])

        # Input/Output split
        sample = {
            "input_fields": data["variable_fields"][
                : self.n_steps_input
            ],  # Ti x H x W x C
            "output_fields": data["variable_fields"][
                self.n_steps_input :
            ],  # To x H x W x C
            "constant_fields": data["constant_fields"],  # H x W x C
            "input_scalars": data["variable_scalars"][: self.n_steps_input],  # Ti x C
            "output_scalars": data["variable_scalars"][self.n_steps_input :],  # To x C
            "constant_scalars": data["constant_scalars"],  # C
        }

        if self.boundary_return_type is not None:
            sample["boundary_conditions"] = data["boundary_conditions"]  # N x 2

        if self.return_grid:
            sample["space_grid"] = data["space_grid"]  # H x W x D
            sample["input_time_grid"] = data["time_grid"][: self.n_steps_input]  # Ti
            sample["output_time_grid"] = data["time_grid"][self.n_steps_input :]  # To

        # Return only non-empty keys - maybe change this later
        return {k: v for k, v in sample.items() if v.numel() > 0}

    def __len__(self):
        return self.len

    def to_xarray(self, backend: Literal["numpy", "dask"] = "dask"):
        """Export the dataset to an Xarray Dataset by stacking all HDF5 files as Xarray datasets
        along the existing 'sample' dimension.

        Args:
            backend: 'numpy' for eager loading, 'dask' for lazy loading.

        Returns:
            xarray.Dataset:
                The stacked Xarray Dataset.

        Examples:
            To convert a dataset and plot the pressure for 5 different times for a single trajectory:
            >>> ds = dataset.to_xarray()
            >>> ds.pressure.isel(sample=0, time=[0, 10, 20, 30, 40]).plot(col='time', col_wrap=5)
        """

        import xarray as xr

        datasets = []
        total_samples = 0
        for file_idx in range(len(self.files_paths)):
            if self.files[file_idx] is None:
                self._open_file(file_idx)
            ds = hdf5_to_xarray(self.files[file_idx], backend=backend)
            # Ensure 'sample' dimension is always present
            if "sample" not in ds.sizes:
                ds = ds.expand_dims("sample")
            # Adjust the 'sample' coordinate
            if "sample" in ds.coords:
                n_samples = ds.sizes["sample"]
                ds = ds.assign_coords(sample=ds.coords["sample"] + total_samples)
                total_samples += n_samples
            datasets.append(ds)

        combined_ds = xr.concat(datasets, dim="sample")
        return combined_ds

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.data_path}>"
to_xarray(backend='dask')

Export the dataset to an Xarray Dataset by stacking all HDF5 files as Xarray datasets along the existing 'sample' dimension.

Parameters:

backend (Literal['numpy', 'dask'], default: 'dask'): 'numpy' for eager loading, 'dask' for lazy loading.

Returns:

xarray.Dataset: The stacked Xarray Dataset.

Examples:

To convert a dataset and plot the pressure for 5 different times for a single trajectory:

>>> ds = dataset.to_xarray()
>>> ds.pressure.isel(sample=0, time=[0, 10, 20, 30, 40]).plot(col='time', col_wrap=5)
Source code in the_well/data/datasets.py
def to_xarray(self, backend: Literal["numpy", "dask"] = "dask"):
    """Export the dataset to an Xarray Dataset by stacking all HDF5 files as Xarray datasets
    along the existing 'sample' dimension.

    Args:
        backend: 'numpy' for eager loading, 'dask' for lazy loading.

    Returns:
        xarray.Dataset:
            The stacked Xarray Dataset.

    Examples:
        To convert a dataset and plot the pressure for 5 different times for a single trajectory:
        >>> ds = dataset.to_xarray()
        >>> ds.pressure.isel(sample=0, time=[0, 10, 20, 30, 40]).plot(col='time', col_wrap=5)
    """

    import xarray as xr

    datasets = []
    total_samples = 0
    for file_idx in range(len(self.files_paths)):
        if self.files[file_idx] is None:
            self._open_file(file_idx)
        ds = hdf5_to_xarray(self.files[file_idx], backend=backend)
        # Ensure 'sample' dimension is always present
        if "sample" not in ds.sizes:
            ds = ds.expand_dims("sample")
        # Adjust the 'sample' coordinate
        if "sample" in ds.coords:
            n_samples = ds.sizes["sample"]
            ds = ds.assign_coords(sample=ds.coords["sample"] + total_samples)
            total_samples += n_samples
        datasets.append(ds)

    combined_ds = xr.concat(datasets, dim="sample")
    return combined_ds

DataModule

The WellDataModule provides the different dataloaders required for training, validation, and testing. It offers two kinds of dataloaders: default ones that yield batches with a fixed time horizon, and rollout ones that yield batches for evaluating rollout performance.
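
A minimal sketch of building a data module and iterating its loaders follows; the base path and dataset name are placeholders, and the import path is assumed from the heading below.

from the_well.data import WellDataModule  # import path assumed from the heading below

datamodule = WellDataModule(
    well_base_path="/path/to/the_well",                # placeholder
    well_dataset_name="turbulent_radiative_layer_2D",  # placeholder
    batch_size=8,
    n_steps_input=4,
    n_steps_output=1,
    use_normalization=True,
)

train_loader = datamodule.train_dataloader()          # fixed-horizon training batches
rollout_loader = datamodule.rollout_val_dataloader()  # full-trajectory batches, batch_size=1

batch = next(iter(train_loader))
print(batch["input_fields"].shape)   # B x Ti x H [x W [x D]] x C
print(batch["output_fields"].shape)  # B x To x H [x W [x D]] x C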

the_well.data.WellDataModule

Bases: AbstractDataModule

Data module class to yield batches of samples.

Parameters:

well_base_path (str, required): Path to the data folder containing the splits (train, validation, and test).
well_dataset_name (str, required): Name of the Well dataset to use.
batch_size (int, required): Size of the batches yielded by the dataloaders.
include_filters (List[str], default: []): Only file names containing any of these strings will be included.
exclude_filters (List[str], default: []): File names containing any of these strings will be excluded.
use_normalization (bool, default: False): Whether to normalize the data. Currently only mean/std normalization is supported.
max_rollout_steps (int, default: 100): Maximum number of steps to use for the rollout datasets, mostly for memory reasons.
n_steps_input (int, default: 1): Number of steps to use as input.
n_steps_output (int, default: 1): Number of steps to use as output.
min_dt_stride (int, default: 1): Minimum time stride to use for the dataset.
max_dt_stride (int, default: 1): Maximum time stride to use for the dataset. If this is greater than the minimum, a stride is chosen at random between them. Unused for validation and test, which use min_dt_stride as both the minimum and the maximum.
world_size (int, default: 1): Number of GPUs in use for distributed training.
data_workers (int, default: 4): Number of workers to use for data loading.
rank (int, default: 1): Rank of the current process in distributed training.
transform (Optional[Augmentation], default: None): Augmentation to apply to the data. If None, no augmentation is applied.
dataset_kws (Optional[Dict[Literal['train', 'val', 'rollout_val', 'test', 'rollout_test'], Dict[str, Any]]], default: None): Additional keyword arguments to pass to each dataset, as a dict of dicts.
storage_kwargs (Optional[Dict], default: None): Storage options passed to fsspec for accessing the raw data.
Source code in the_well/data/datamodule.py
class WellDataModule(AbstractDataModule):
    """Data module class to yield batches of samples.

    Args:
        well_base_path:
            Path to the data folder containing the splits (train, validation, and test).
        well_dataset_name:
            Name of the well dataset to use.
        batch_size:
            Size of the batches yielded by the dataloaders
        ---
        include_filters:
            Only file names containing any of these strings will be included.
        exclude_filters:
            File names containing any of these strings will be excluded.
        use_normalization:
            Whether to use normalization on the data. Currently only supports mean/std.
        max_rollout_steps:
            Maximum number of steps to use for the rollout dataset. Mostly for memory reasons.
        n_steps_input:
            Number of steps to use as input.
        n_steps_output:
            Number of steps to use as output.
        min_dt_stride:
            Minimum stride in time to use for the dataset.
        max_dt_stride:
            Maximum stride in time to use for the dataset. If this is greater than min, randomly choose between them.
                Note that this is unused for validation/test which uses "min_dt_stride" for both the min and max.
        world_size:
            Number of GPUs in use for distributed training.
        data_workers:
            Number of workers to use for data loading.
        rank:
            Rank of the current process in distributed training.
        transform:
            Augmentation to apply to the data. If None, no augmentation is applied.
        dataset_kws:
            Additional keyword arguments to pass to each dataset, as a dict of dicts.
        storage_kwargs:
            Storage options passed to fsspec for accessing the raw data.
    """

    def __init__(
        self,
        well_base_path: str,
        well_dataset_name: str,
        batch_size: int,
        include_filters: List[str] = [],
        exclude_filters: List[str] = [],
        use_normalization: bool = False,
        max_rollout_steps: int = 100,
        n_steps_input: int = 1,
        n_steps_output: int = 1,
        min_dt_stride: int = 1,
        max_dt_stride: int = 1,
        world_size: int = 1,
        data_workers: int = 4,
        rank: int = 1,
        transform: Optional[Augmentation] = None,
        dataset_kws: Optional[
            Dict[
                Literal["train", "val", "rollout_val", "test", "rollout_test"],
                Dict[str, Any],
            ]
        ] = None,
        storage_kwargs: Optional[Dict] = None,
    ):
        self.train_dataset = WellDataset(
            well_base_path=well_base_path,
            well_dataset_name=well_dataset_name,
            well_split_name="train",
            include_filters=include_filters,
            exclude_filters=exclude_filters,
            use_normalization=use_normalization,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            storage_options=storage_kwargs,
            min_dt_stride=min_dt_stride,
            max_dt_stride=max_dt_stride,
            transform=transform,
            **(
                dataset_kws["train"]
                if dataset_kws is not None and "train" in dataset_kws
                else {}
            ),
        )
        self.val_dataset = WellDataset(
            well_base_path=well_base_path,
            well_dataset_name=well_dataset_name,
            well_split_name="valid",
            include_filters=include_filters,
            exclude_filters=exclude_filters,
            use_normalization=use_normalization,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            storage_options=storage_kwargs,
            min_dt_stride=min_dt_stride,
            max_dt_stride=min_dt_stride,
            **(
                dataset_kws["val"]
                if dataset_kws is not None and "val" in dataset_kws
                else {}
            ),
        )
        self.rollout_val_dataset = WellDataset(
            well_base_path=well_base_path,
            well_dataset_name=well_dataset_name,
            well_split_name="valid",
            include_filters=include_filters,
            exclude_filters=exclude_filters,
            use_normalization=use_normalization,
            max_rollout_steps=max_rollout_steps,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            full_trajectory_mode=True,
            storage_options=storage_kwargs,
            min_dt_stride=min_dt_stride,
            max_dt_stride=min_dt_stride,
            **(
                dataset_kws["rollout_val"]
                if dataset_kws is not None and "rollout_val" in dataset_kws
                else {}
            ),
        )
        self.test_dataset = WellDataset(
            well_base_path=well_base_path,
            well_dataset_name=well_dataset_name,
            well_split_name="test",
            include_filters=include_filters,
            exclude_filters=exclude_filters,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            storage_options=storage_kwargs,
            min_dt_stride=min_dt_stride,
            max_dt_stride=min_dt_stride,
            **(
                dataset_kws["test"]
                if dataset_kws is not None and "test" in dataset_kws
                else {}
            ),
        )
        self.rollout_test_dataset = WellDataset(
            well_base_path=well_base_path,
            well_dataset_name=well_dataset_name,
            well_split_name="test",
            include_filters=include_filters,
            exclude_filters=exclude_filters,
            max_rollout_steps=max_rollout_steps,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            full_trajectory_mode=True,
            storage_options=storage_kwargs,
            min_dt_stride=min_dt_stride,
            max_dt_stride=min_dt_stride,
            **(
                dataset_kws["rollout_test"]
                if dataset_kws is not None and "rollout_test" in dataset_kws
                else {}
            ),
        )
        self.well_base_path = well_base_path
        self.well_dataset_name = well_dataset_name
        self.batch_size = batch_size
        self.world_size = world_size
        self.data_workers = data_workers
        self.rank = rank

    @property
    def is_distributed(self) -> bool:
        return self.world_size > 1

    def train_dataloader(self) -> DataLoader:
        """Generate a dataloader for training data.

        Returns:
            A dataloader
        """
        sampler = None
        if self.is_distributed:
            sampler = DistributedSampler(
                self.train_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,
            )
            logger.debug(
                f"Use {sampler.__class__.__name__} "
                f"({self.rank}/{self.world_size}) for training data"
            )
        shuffle = sampler is None

        return DataLoader(
            self.train_dataset,
            num_workers=self.data_workers,
            pin_memory=True,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=True,
            sampler=sampler,
        )

    def val_dataloader(self) -> DataLoader:
        """Generate a dataloader for validation data.

        Returns:
            A dataloader
        """
        sampler = None
        if self.is_distributed:
            sampler = DistributedSampler(
                self.val_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,
            )
            logger.debug(
                f"Use {sampler.__class__.__name__} "
                f"({self.rank}/{self.world_size}) for validation data"
            )
        shuffle = sampler is None  # Most valid epochs are short
        return DataLoader(
            self.val_dataset,
            num_workers=self.data_workers,
            pin_memory=True,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=True,
            sampler=sampler,
        )

    def rollout_val_dataloader(self) -> DataLoader:
        """Generate a dataloader for rollout validation data.

        Returns:
            A dataloader
        """
        sampler = None
        if self.is_distributed:
            sampler = DistributedSampler(
                self.rollout_val_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True,  # Since we're subsampling, don't want continuous
            )
            logger.debug(
                f"Use {sampler.__class__.__name__} "
                f"({self.rank}/{self.world_size}) for rollout validation data"
            )
        shuffle = sampler is None  # Most valid epochs are short
        return DataLoader(
            self.rollout_val_dataset,
            num_workers=self.data_workers,
            pin_memory=True,
            batch_size=1,
            shuffle=shuffle,  # Shuffling because most batches we take a small subsample
            drop_last=True,
            sampler=sampler,
        )

    def test_dataloader(self) -> DataLoader:
        """Generate a dataloader for test data.

        Returns:
            A dataloader
        """
        sampler = None
        if self.is_distributed:
            sampler = DistributedSampler(
                self.test_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False,
            )
            logger.debug(
                f"Use {sampler.__class__.__name__} "
                f"({self.rank}/{self.world_size}) for test data"
            )
        return DataLoader(
            self.test_dataset,
            num_workers=self.data_workers,
            pin_memory=True,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True,
            sampler=sampler,
        )

    def rollout_test_dataloader(self) -> DataLoader:
        """Generate a dataloader for rollout test data.

        Returns:
            A dataloader
        """
        sampler = None
        if self.is_distributed:
            sampler = DistributedSampler(
                self.rollout_test_dataset,
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False,
            )
            logger.debug(
                f"Use {sampler.__class__.__name__} "
                f"({self.rank}/{self.world_size}) for rollout test data"
            )
        return DataLoader(
            self.rollout_test_dataset,
            num_workers=self.data_workers,
            pin_memory=True,
            batch_size=1,  # min(self.batch_size, len(self.rollout_test_dataset)),
            shuffle=False,
            drop_last=True,
            sampler=sampler,
        )

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.well_dataset_name} on {self.well_base_path}>"
rollout_test_dataloader()

Generate a dataloader for rollout test data.

Returns:

DataLoader: A dataloader.

Source code in the_well/data/datamodule.py
def rollout_test_dataloader(self) -> DataLoader:
    """Generate a dataloader for rollout test data.

    Returns:
        A dataloader
    """
    sampler = None
    if self.is_distributed:
        sampler = DistributedSampler(
            self.rollout_test_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )
        logger.debug(
            f"Use {sampler.__class__.__name__} "
            f"({self.rank}/{self.world_size}) for rollout test data"
        )
    return DataLoader(
        self.rollout_test_dataset,
        num_workers=self.data_workers,
        pin_memory=True,
        batch_size=1,  # min(self.batch_size, len(self.rollout_test_dataset)),
        shuffle=False,
        drop_last=True,
        sampler=sampler,
    )
rollout_val_dataloader()

Generate a dataloader for rollout validation data.

Returns:

DataLoader: A dataloader.

Source code in the_well/data/datamodule.py
def rollout_val_dataloader(self) -> DataLoader:
    """Generate a dataloader for rollout validation data.

    Returns:
        A dataloader
    """
    sampler = None
    if self.is_distributed:
        sampler = DistributedSampler(
            self.rollout_val_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,  # Since we're subsampling, don't want continuous
        )
        logger.debug(
            f"Use {sampler.__class__.__name__} "
            f"({self.rank}/{self.world_size}) for rollout validation data"
        )
    shuffle = sampler is None  # Most valid epochs are short
    return DataLoader(
        self.rollout_val_dataset,
        num_workers=self.data_workers,
        pin_memory=True,
        batch_size=1,
        shuffle=shuffle,  # Shuffling because most batches we take a small subsample
        drop_last=True,
        sampler=sampler,
    )
test_dataloader()

Generate a dataloader for test data.

Returns:

DataLoader: A dataloader.

Source code in the_well/data/datamodule.py
def test_dataloader(self) -> DataLoader:
    """Generate a dataloader for test data.

    Returns:
        A dataloader
    """
    sampler = None
    if self.is_distributed:
        sampler = DistributedSampler(
            self.test_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=False,
        )
        logger.debug(
            f"Use {sampler.__class__.__name__} "
            f"({self.rank}/{self.world_size}) for test data"
        )
    return DataLoader(
        self.test_dataset,
        num_workers=self.data_workers,
        pin_memory=True,
        batch_size=self.batch_size,
        shuffle=False,
        drop_last=True,
        sampler=sampler,
    )
train_dataloader()

Generate a dataloader for training data.

Returns:

DataLoader: A dataloader.

Source code in the_well/data/datamodule.py
def train_dataloader(self) -> DataLoader:
    """Generate a dataloader for training data.

    Returns:
        A dataloader
    """
    sampler = None
    if self.is_distributed:
        sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,
        )
        logger.debug(
            f"Use {sampler.__class__.__name__} "
            f"({self.rank}/{self.world_size}) for training data"
        )
    shuffle = sampler is None

    return DataLoader(
        self.train_dataset,
        num_workers=self.data_workers,
        pin_memory=True,
        batch_size=self.batch_size,
        shuffle=shuffle,
        drop_last=True,
        sampler=sampler,
    )
val_dataloader()

Generate a dataloader for validation data.

Returns:

Type Description
DataLoader

A dataloader

Source code in the_well/data/datamodule.py
def val_dataloader(self) -> DataLoader:
    """Generate a dataloader for validation data.

    Returns:
        A dataloader
    """
    sampler = None
    if self.is_distributed:
        sampler = DistributedSampler(
            self.val_dataset,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True,
        )
        logger.debug(
            f"Use {sampler.__class__.__name__} "
            f"({self.rank}/{self.world_size}) for validation data"
        )
    shuffle = sampler is None  # Most valid epochs are short
    return DataLoader(
        self.val_dataset,
        num_workers=self.data_workers,
        pin_memory=True,
        batch_size=self.batch_size,
        shuffle=shuffle,
        drop_last=True,
        sampler=sampler,
    )
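
The snippet below is a minimal usage sketch (not verbatim from the library) showing how the dataloaders above plug into a training loop. The constructor arguments well_base_path, well_dataset_name, and batch_size are assumptions based on the dataset parameters documented earlier; check the actual WellDataModule signature in the_well/data/datamodule.py before relying on them.

import torch

from the_well.data import WellDataModule

# Hypothetical construction: the argument names here are assumptions, see the note above.
datamodule = WellDataModule(
    well_base_path="path/to/the_well/datasets",      # hypothetical local path
    well_dataset_name="turbulent_radiative_layer_2D",
    batch_size=4,
)

train_loader = datamodule.train_dataloader()  # see train_dataloader() above
val_loader = datamodule.val_dataloader()      # see val_dataloader() above

for batch in train_loader:
    # Each batch is expected to be a dictionary of tensors in
    # B x T x H [x W [x D]] x C format; inspect the keys on a real run.
    print({k: v.shape for k, v in batch.items() if torch.is_tensor(v)})
    break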

Metrics

The Well package implements a series of metrics to assess the performance of a trained model.
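
All metrics below share the same interface: a static eval(x, y, meta) that reduces over the trailing spatial dimensions and keeps the batch, time, and channel axes. The sketch below illustrates this on random tensors. The SimpleNamespace is a stand-in for WellMetadata, used here only because the spatial metrics documented below read nothing but meta.n_spatial_dims; in practice you would pass the metadata attached to a WellDataset.

import torch
from types import SimpleNamespace

from the_well.benchmark.metrics import MSE, VRMSE

meta = SimpleNamespace(n_spatial_dims=2)  # stand-in for WellMetadata, for illustration only
x = torch.randn(8, 4, 64, 64, 3)          # prediction, B x T x H x W x C
y = torch.randn(8, 4, 64, 64, 3)          # target, same shape

print(MSE.eval(x, y, meta).shape)    # torch.Size([8, 4, 3]): spatial dims reduced
print(VRMSE.eval(x, y, meta).shape)  # torch.Size([8, 4, 3])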

the_well.benchmark.metrics

LInfinity

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class LInfinity(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        L-Infinity Norm

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.

        Returns:
            L-Infinity norm between x and y.
        """
        spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
        return torch.max(
            torch.abs(x - y).flatten(start_dim=spatial_dims[0], end_dim=-2), dim=-2
        ).values
eval(x, y, meta) staticmethod

L-Infinity Norm

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required

Returns:

Type Description
Tensor

L-Infinity norm between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
) -> torch.Tensor:
    """
    L-Infinity Norm

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.

    Returns:
        L-Infinity norm between x and y.
    """
    spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
    return torch.max(
        torch.abs(x - y).flatten(start_dim=spatial_dims[0], end_dim=-2), dim=-2
    ).values
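
In other words, the reduction above is the supremum of the absolute residual over the spatial grid, taken independently for every batch element, time step, and channel:

$$\mathrm{L}_\infty(x, y) = \max_{s \in \text{grid}} \, \lvert x_s - y_s \rvert.$$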

MSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class MSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.

        Returns:
            Mean squared error between x and y.
        """
        n_spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
        return torch.mean((x - y) ** 2, dim=n_spatial_dims)
eval(x, y, meta) staticmethod

Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required

Returns:

Type Description
Tensor

Mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
) -> torch.Tensor:
    """
    Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.

    Returns:
        Mean squared error between x and y.
    """
    n_spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
    return torch.mean((x - y) ** 2, dim=n_spatial_dims)

NMSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class NMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
        eps: float = 1e-7,
        norm_mode: str = "norm",
    ) -> torch.Tensor:
        """
        Normalized Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.
            eps: Small value to avoid division by zero. Default is 1e-7.
            norm_mode:
                Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

        Returns:
            Normalized mean squared error between x and y.
        """
        n_spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
        if norm_mode == "norm":
            norm = torch.mean(y**2, dim=n_spatial_dims)
        elif norm_mode == "std":
            norm = torch.std(y, dim=n_spatial_dims) ** 2
        else:
            raise ValueError(f"Invalid norm_mode: {norm_mode}")
        return MSE.eval(x, y, meta) / (norm + eps)
eval(x, y, meta, eps=1e-07, norm_mode='norm') staticmethod

Normalized Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required
eps float

Small value to avoid division by zero. Default is 1e-7.

1e-07
norm_mode str

Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

'norm'

Returns:

Type Description
Tensor

Normalized mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
    eps: float = 1e-7,
    norm_mode: str = "norm",
) -> torch.Tensor:
    """
    Normalized Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.
        eps: Small value to avoid division by zero. Default is 1e-7.
        norm_mode:
            Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

    Returns:
        Normalized mean squared error between x and y.
    """
    n_spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
    if norm_mode == "norm":
        norm = torch.mean(y**2, dim=n_spatial_dims)
    elif norm_mode == "std":
        norm = torch.std(y, dim=n_spatial_dims) ** 2
    else:
        raise ValueError(f"Invalid norm_mode: {norm_mode}")
    return MSE.eval(x, y, meta) / (norm + eps)
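
Written out, the normalization above is computed per batch element, time step, and channel, with the sums running over the $N$ spatial grid points $s$ and $\varepsilon$ = eps:

$$\mathrm{NMSE}(x, y) = \frac{\tfrac{1}{N}\sum_{s}(x_s - y_s)^2}{\tfrac{1}{N}\sum_{s} y_s^2 + \varepsilon} \quad (\texttt{norm\_mode}=\texttt{"norm"}), \qquad \mathrm{NMSE}(x, y) = \frac{\tfrac{1}{N}\sum_{s}(x_s - y_s)^2}{\mathrm{Var}_s(y) + \varepsilon} \quad (\texttt{norm\_mode}=\texttt{"std"}),$$

where $\mathrm{Var}_s(y)$ is the sample variance over the spatial grid (torch.std squared, hence Bessel-corrected).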

NRMSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class NRMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
        eps: float = 1e-7,
        norm_mode: str = "norm",
    ) -> torch.Tensor:
        """
        Normalized Root Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.
            eps: Small value to avoid division by zero. Default is 1e-7.
            norm_mode: Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

        Returns:
            Normalized root mean squared error between x and y.

        """
        return torch.sqrt(NMSE.eval(x, y, meta, eps=eps, norm_mode=norm_mode))
eval(x, y, meta, eps=1e-07, norm_mode='norm') staticmethod

Normalized Root Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required
eps float

Small value to avoid division by zero. Default is 1e-7.

1e-07
norm_mode str

Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

'norm'

Returns:

Type Description
Tensor

Normalized root mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
    eps: float = 1e-7,
    norm_mode: str = "norm",
) -> torch.Tensor:
    """
    Normalized Root Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.
        eps: Small value to avoid division by zero. Default is 1e-7.
        norm_mode: Mode for computing the normalization factor. Can be 'norm' or 'std'. Default is 'norm'.

    Returns:
        Normalized root mean squared error between x and y.

    """
    return torch.sqrt(NMSE.eval(x, y, meta, eps=eps, norm_mode=norm_mode))

RMSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class RMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Root Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.

        Returns:
            Root mean squared error between x and y.
        """
        return torch.sqrt(MSE.eval(x, y, meta))
eval(x, y, meta) staticmethod

Root Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required

Returns:

Type Description
Tensor

Root mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
) -> torch.Tensor:
    """
    Root Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.

    Returns:
        Root mean squared error between x and y.
    """
    return torch.sqrt(MSE.eval(x, y, meta))

VMSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class VMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Variance Scaled Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.

        Returns:
            Variance-scaled mean squared error between x and y.
        """
        return NMSE.eval(x, y, meta, norm_mode="std")
eval(x, y, meta) staticmethod

Variance Scaled Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required

Returns:

Type Description
Tensor

Variance-scaled mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
) -> torch.Tensor:
    """
    Variance Scaled Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.

    Returns:
        Variance mean squared error between x and y.
    """
    return NMSE.eval(x, y, meta, norm_mode="std")

VRMSE

Bases: Metric

Source code in the_well/benchmark/metrics/spatial.py
class VRMSE(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor | np.ndarray,
        y: torch.Tensor | np.ndarray,
        meta: WellMetadata,
    ) -> torch.Tensor:
        """
        Root Variance Scaled Mean Squared Error

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.

        Returns:
            Variance-scaled root mean squared error between x and y.
        """
        return NRMSE.eval(x, y, meta, norm_mode="std")
eval(x, y, meta) staticmethod

Root Variance Scaled Mean Squared Error

Parameters:

Name Type Description Default
x Tensor | ndarray

Input tensor.

required
y Tensor | ndarray

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required

Returns:

Type Description
Tensor

Variance-scaled root mean squared error between x and y.

Source code in the_well/benchmark/metrics/spatial.py
@staticmethod
def eval(
    x: torch.Tensor | np.ndarray,
    y: torch.Tensor | np.ndarray,
    meta: WellMetadata,
) -> torch.Tensor:
    """
    Root Variance Scaled Mean Squared Error

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.

    Returns:
        Variance-scaled root mean squared error between x and y.
    """
    return NRMSE.eval(x, y, meta, norm_mode="std")
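
As the implementations above show, the variance-scaled metrics are thin wrappers around their normalized counterparts with norm_mode="std":

$$\mathrm{VMSE}(x, y) = \mathrm{NMSE}(x, y;\ \texttt{norm\_mode}=\texttt{"std"}) = \frac{\tfrac{1}{N}\sum_{s}(x_s - y_s)^2}{\mathrm{Var}_s(y) + \varepsilon}, \qquad \mathrm{VRMSE}(x, y) = \sqrt{\mathrm{VMSE}(x, y)}.$$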

binned_spectral_mse

Bases: Metric

Source code in the_well/benchmark/metrics/spectral.py
class binned_spectral_mse(Metric):
    @staticmethod
    def eval(
        x: torch.Tensor,
        y: torch.Tensor,
        meta: WellMetadata,
        bins: torch.Tensor | None = None,
        fourier_input: bool = False,
    ) -> dict[str, torch.Tensor]:
        """
        Binned Spectral Mean Squared Error.
        Corresponds to MSE computed after filtering over wavenumber bins in the Fourier domain.

        The default binning is a set of three (approximately) logarithmically spaced bins from 0 to pi.

        Note that MSE(x, y) should match the sum over frequency bins of the spectral MSE.

        Args:
            x: Input tensor.
            y: Target tensor.
            meta: Metadata for the dataset.
            bins:
                Tensor of bin edges. If None, a default set of three (approximately) logarithmically spaced bins from 0 to pi is used. The default is None.
            fourier_input:
                If True, x and y are assumed to be the Fourier transform of the input data. The default is False.

        Returns:
            A dictionary of per-bin spectral MSE and NMSE values between x and y.
        """
        spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
        spatial_shape = tuple(x.shape[dim] for dim in spatial_dims)
        prod_spatial_shape = np.prod(np.array(spatial_shape))
        ndims = meta.n_spatial_dims

        if bins is None:  # Default binning
            bins = torch.logspace(
                np.log10(2 * np.pi / max(spatial_shape)),
                np.log10(np.pi * np.sqrt(ndims) + 1e-6),
                4,
            ).to(x.device)  # Low, medium, and high frequency bins
            bins[0] = 0.0  # We start from zero
        _, ps_res_mean, _, counts = power_spectrum(
            x - y, meta, bins=bins, fourier_input=fourier_input, return_counts=True
        )

        # TODO - MAJOR DESIGN VIOLATION - BUT ITS FASTER TO IMPLEMENT THIS WAY TODAY...
        _, ps_true_mean, _, true_counts = power_spectrum(
            y, meta, bins=bins, fourier_input=fourier_input, return_counts=True
        )

        # Compute the mean squared error per bin (stems from Plancherel's formula)
        mse_per_bin = ps_res_mean * counts[:-1].unsqueeze(-1) / prod_spatial_shape**2
        true_energy_per_min = (
            ps_true_mean * true_counts[:-1].unsqueeze(-1) / prod_spatial_shape**2
        )
        nmse_per_bin = mse_per_bin / (true_energy_per_min + 1e-7)

        mse_dict = {
            f"spectral_error_mse_per_bin_{i}": mse_per_bin[..., i, :]
            for i in range(mse_per_bin.shape[-2])
        }
        nmse_dict = {
            f"spectral_error_nmse_per_bin_{i}": nmse_per_bin[..., i, :]
            for i in range(nmse_per_bin.shape[-2])
        }
        out_dict = mse_dict
        # Hacked to add this here for now - should be split with taking PS as input
        out_dict |= nmse_dict
        # TODO Figure out better way to handle multi-output losses
        return out_dict
eval(x, y, meta, bins=None, fourier_input=False) staticmethod

Binned Spectral Mean Squared Error. Corresponds to MSE computed after filtering over wavenumber bins in the Fourier domain.

The default binning is a set of three (approximately) logarithmically spaced bins from 0 to pi.

Note that MSE(x, y) should match the sum over frequency bins of the spectral MSE.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
y Tensor

Target tensor.

required
meta WellMetadata

Metadata for the dataset.

required
bins Tensor

Tensor of bin edges. If None, a default set of three (approximately) logarithmically spaced bins from 0 to pi is used. The default is None.

None
fourier_input bool

If True, x and y are assumed to be the Fourier transform of the input data. The default is False.

False

Returns:

Type Description
dict[str, Tensor]

A dictionary of per-bin spectral MSE and NMSE values between x and y.

Source code in the_well/benchmark/metrics/spectral.py
@staticmethod
def eval(
    x: torch.Tensor,
    y: torch.Tensor,
    meta: WellMetadata,
    bins: torch.Tensor | None = None,
    fourier_input: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Binned Spectral Mean Squared Error.
    Corresponds to MSE computed after filtering over wavenumber bins in the Fourier domain.

    The default binning is a set of three (approximately) logarithmically spaced bins from 0 to pi.

    Note that MSE(x, y) should match the sum over frequency bins of the spectral MSE.

    Args:
        x: Input tensor.
        y: Target tensor.
        meta: Metadata for the dataset.
        bins:
            Tensor of bin edges. If None, a default set of three (approximately) logarithmically spaced bins from 0 to pi is used. The default is None.
        fourier_input:
            If True, x and y are assumed to be the Fourier transform of the input data. The default is False.

    Returns:
        A dictionary of per-bin spectral MSE and NMSE values between x and y.
    """
    spatial_dims = tuple(range(-meta.n_spatial_dims - 1, -1))
    spatial_shape = tuple(x.shape[dim] for dim in spatial_dims)
    prod_spatial_shape = np.prod(np.array(spatial_shape))
    ndims = meta.n_spatial_dims

    if bins is None:  # Default binning
        bins = torch.logspace(
            np.log10(2 * np.pi / max(spatial_shape)),
            np.log10(np.pi * np.sqrt(ndims) + 1e-6),
            4,
        ).to(x.device)  # Low, medium, and high frequency bins
        bins[0] = 0.0  # We start from zero
    _, ps_res_mean, _, counts = power_spectrum(
        x - y, meta, bins=bins, fourier_input=fourier_input, return_counts=True
    )

    # TODO - MAJOR DESIGN VIOLATION - BUT ITS FASTER TO IMPLEMENT THIS WAY TODAY...
    _, ps_true_mean, _, true_counts = power_spectrum(
        y, meta, bins=bins, fourier_input=fourier_input, return_counts=True
    )

    # Compute the mean squared error per bin (stems from Plancherel's formula)
    mse_per_bin = ps_res_mean * counts[:-1].unsqueeze(-1) / prod_spatial_shape**2
    true_energy_per_min = (
        ps_true_mean * true_counts[:-1].unsqueeze(-1) / prod_spatial_shape**2
    )
    nmse_per_bin = mse_per_bin / (true_energy_per_min + 1e-7)

    mse_dict = {
        f"spectral_error_mse_per_bin_{i}": mse_per_bin[..., i, :]
        for i in range(mse_per_bin.shape[-2])
    }
    nmse_dict = {
        f"spectral_error_nmse_per_bin_{i}": nmse_per_bin[..., i, :]
        for i in range(nmse_per_bin.shape[-2])
    }
    out_dict = mse_dict
    # Hacked to add this here for now - should be split with taking PS as input
    out_dict |= nmse_dict
    # TODO Figure out better way to handle multi-output losses
    return out_dict
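
To make the default low / medium / high frequency split concrete, the sketch below reproduces the bin-edge construction from the snippet above for a 64 x 64 grid; it mirrors the code in binned_spectral_mse.eval and does not call the metric itself.

import numpy as np
import torch

spatial_shape = (64, 64)
ndims = len(spatial_shape)

# Four logarithmically spaced edges define three wavenumber bins.
bins = torch.logspace(
    np.log10(2 * np.pi / max(spatial_shape)),  # smallest non-zero wavenumber on the grid
    np.log10(np.pi * np.sqrt(ndims) + 1e-6),   # corner of the Nyquist square in 2D
    4,
)
bins[0] = 0.0  # the first bin starts at the zero mode

print(bins)  # 4 edges -> 3 bins, from lowest to highest wavenumber

The keys of the returned dictionary, spectral_error_mse_per_bin_{i} and spectral_error_nmse_per_bin_{i}, index these bins in order of increasing wavenumber.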