donkey

Efficiently generate cross validation dataset with Xarray

I'd like to use a statistical tool that requires an Xarray Dataset as input. My dataset is quite large, so I typically use polars for memory and compute efficiency; as a result, I am not familiar with the optimal way to do things in Xarray.

I've written a function that takes an Xarray Dataset. The Dataset has a replicate dimension with labels (e.g. A, B, C). For cross-validation, I'd like the function to generate a permutation of the dataset in which the data are drawn from different replicates.

To clarify: every treatment is present in every replicate, and the cross-validation should cover all dataset combinations in which each treatment's data come from a different replicate. If the position in the tuple is the treatment and the letter is the replicate it was taken from, the resulting datasets could look like (A, B, C); (B, A, C); (C, A, B); ...
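To make the relabelling concrete, here is a small sketch under one reading of the scheme: each treatment gets its own independent permutation of the replicate labels, which is what generate_replicate_labels below does per group. The helper name `random_relabeling` is hypothetical, and it uses NumPy's `default_rng` rather than `check_random_state`. Under this reading, 3 replicates and 3 treatments give (3!)^3 = 216 possible relabellings, the original one included.

```python
import math

import numpy as np

treatments = ["bottom", "middle", "top"]
replicates = ["A", "B", "C"]


def random_relabeling(seed):
    """Hypothetical helper: one independent permutation of replicate labels per treatment.

    mapping[t][old] is the replicate label that the data of treatment t,
    originally stored under replicate `old`, receives in the shuffled dataset.
    """
    rng = np.random.default_rng(seed)
    return {t: dict(zip(replicates, rng.permutation(replicates).tolist())) for t in treatments}


# Each treatment is permuted independently: (n_replicates!) ** n_treatments outcomes
n_relabelings = math.factorial(len(replicates)) ** len(treatments)
print(n_relabelings)  # 216
print(random_relabeling(seed=0))
```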

Minimal example:

import pandas as pd
import numpy as np
import xarray as xr

# Coordinates
treatments = ['bottom', 'middle', 'top']
replicates = ['A', 'B', 'C']
positions = ['12011', '182060', '24920', '32590', '42600', '51650', '60540', '68420', '69999', '70210']
variable_types = ['type1', 'type2', 'type3', 'type4']

samples = {
    ('bottom', 'A'): 'b03', ('bottom', 'B'): 'b04', ('bottom', 'C'): 'b05',
    ('middle', 'A'): 'b06', ('middle', 'B'): 'b07', ('middle', 'C'): 'b08',
    ('top', 'A'): 'b09', ('top', 'B'): 'b10', ('top', 'C'): 'b11'
}

# Create a DataFrame with all combinations of treatment, replicate, position, and variable_type
data_df = pd.DataFrame(
    [(treatment, replicate, position, variable_type)
     for treatment in treatments
     for replicate in replicates
     for position in positions
     for variable_type in variable_types],
    columns=['treatment', 'replicate', 'position', 'variable_type']
)

# Add a column for the value (Abundance) with random values
np.random.seed(42)
data_df['value'] = np.random.rand(len(data_df))

# Add a column for the sample based on treatment and replicate
data_df['sample'] = data_df.apply(lambda row: samples[(row['treatment'], row['replicate'])], axis=1)

# Create the xarray Dataset using the provided structure
xr_abs = xr.Dataset(dict(
    Abundance=xr.DataArray.from_series(
        data_df[["treatment", "replicate", "position", "variable_type", "value"]]
        .set_index(["treatment", "replicate", "position", "variable_type"])["value"]),
    Sample=xr.DataArray.from_series(
        data_df[["treatment", "replicate", "sample"]]
        .drop_duplicates()
        .set_index(["treatment", "replicate"])["sample"])
))

My implementation:

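from sklearn.utils import check_random_state  # used by generate_replicate_labels (assumed to be scikit-learn's helper)


def shuffle_dataset(dataset, shuffle_seed, replicate_col, group_cols):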
    # pull out DataArrays as Series
    abun_df = dataset.Abundance.to_series().reset_index().rename(columns={replicate_col: 'OldReplicate'})
    sample_df = dataset.Sample.to_series().reset_index().rename(columns={replicate_col: 'OldReplicate'})

    # Get dataset and data array structures to recreate it later
    data_arrays = [(n, d.dims) for n, d in dataset.data_vars.items()]
    data_coords = dataset.coords
    del dataset

    # Make new "GroupName" field in Series 
    sample_df['GroupName'] = sample_df[group_cols].agg('_'.join, axis=1)
    sample_df = sample_df.sort_values('GroupName')

    # Automatically generate the replicate map: {0: 'A', 1: 'B', ...}
    unique_replicates = sorted(sample_df["OldReplicate"].unique())

    # use generate_replicate_labels() to get new replicate labels
    new_labels = generate_replicate_labels(
        sample_names=sample_df['GroupName'].to_numpy(),
        random_state=shuffle_seed,
        replicate_map=dict(enumerate(unique_replicates))
    )

    sample_df[replicate_col] = new_labels

    # map new replicate labels onto abundance data
    shuffle_abun_df = pd.merge(left=abun_df, right=sample_df, how='left', on=group_cols + ['OldReplicate'])
    del abun_df, sample_df

    # Rebuild the dataset with shuffled replicate labels
    restored_data_vars = {}

    for var_name, original_dims in data_arrays:
        dims = list(original_dims)  # DataArray.dims is a tuple; a list can be concatenated and passed to set_index
        restored_data_vars[var_name] = xr.DataArray.from_series(
            shuffle_abun_df[dims + [var_name]].drop_duplicates().set_index(dims)[var_name])

    del shuffle_abun_df
    
    # Create the new dataset with the restored DataArrays
    shuffle_ds = xr.Dataset(restored_data_vars, coords=data_coords)

    # Save random seed used for shuffling as dataset attribute
    shuffle_ds.attrs['shuffle_seed'] = shuffle_seed

    return shuffle_ds 


def generate_replicate_labels(sample_names, random_state=None, replicate_map=None):
    """Generates random replicate labels to align with an input vector of sample names.

    Parameters
    ----------
    sample_names : np.ndarray
        Input array of sample names. Must be sorted in ascending order.
    random_state : {None, int, numpy.random.RandomState}
        Seed or RandomState used for the shuffle. Default is None.
    replicate_map : {None, dict}
        Map of integer labels to preferred replicate labels.
        Example:
            {0:'A', 1:'B', 2:'C'}
        Default is None.

    Returns
    -------
    replicate_labels : np.ndarray
        Array of randomly generated replicate labels, to be aligned with input `sample_names`.

    """

    # check that input is a numpy array, and sample_names are sorted
    assert type(sample_names) is np.ndarray, "`sample_names` must be a numpy.ndarray"
    assert np.all(sample_names[:-1] <= sample_names[1:]), "`sample_names` must be sorted in ascending order"

    # Get counts of each sample name
    names, counts = np.unique(sample_names, return_counts=True)
    rns = check_random_state(random_state)

    # Generate replicate labels
    replicate_labels = [rns.choice(np.arange(counts.max()), size=c, replace=False) for c in counts]
    replicate_labels = np.concatenate(replicate_labels)

    # Map preferred replicate labels
    if replicate_map is not None:
        mapped_replicate_labels = [replicate_map[i] for i in replicate_labels]
        replicate_labels = np.array(mapped_replicate_labels)

    return replicate_labels

My issue is that this is quite slow and memory-hungry. Memory is my primary concern: the software runs into an out-of-memory (OOM) error even though it runs on a server with a large amount of RAM.
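To see where the memory goes, here is a small sketch comparing the dense Dataset with the long-format pandas copy that shuffle_dataset builds through `to_series().reset_index()`. It uses synthetic data with the same dimension layout as the minimal example (the position labels are made up). Exact byte counts depend on the pandas version and its string dtype, but the long format repeats every coordinate label on every row, so it is typically several times larger, and shuffle_dataset holds both `abun_df` and the merged `shuffle_abun_df` at the same time.

```python
import numpy as np
import xarray as xr

# Small synthetic Dataset with the same dimension layout as the minimal example
rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"Abundance": (("treatment", "replicate", "position", "variable_type"), rng.random((3, 3, 10, 4)))},
    coords={
        "treatment": ["bottom", "middle", "top"],
        "replicate": ["A", "B", "C"],
        "position": [f"pos{i}" for i in range(10)],
        "variable_type": ["type1", "type2", "type3", "type4"],
    },
)

# Dense form: one float64 per cell plus the (short) coordinate label arrays
dense_bytes = ds.nbytes

# Long form: one row per cell, with every coordinate label repeated on each row
long_df = ds.Abundance.to_series().reset_index()
long_bytes = int(long_df.memory_usage(deep=True).sum())

print(f"dense: {dense_bytes} B, long format: {long_bytes} B")
```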

Perhaps there is an efficient way to do the whole thing in polars and then build the Dataset? Here I made a pandas DataFrame directly, but in my real code I convert a polars DataFrame to a pandas DataFrame and then to an Xarray Dataset.
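One possible way to avoid the long-format round trip altogether is to stay in array form and permute positions along the replicate axis with `np.take_along_axis`, using one independent permutation per treatment. This is a sketch, not a drop-in replacement: `shuffle_replicates` is a hypothetical name, it uses `np.random.default_rng` instead of `check_random_state`, and it assumes every data variable with a `replicate` dimension also has a `treatment` dimension and fits in memory as a NumPy array (`.values` computes dask-backed data).

```python
import numpy as np
import xarray as xr


def shuffle_replicates(ds, seed, replicate_dim="replicate", group_dim="treatment"):
    """Hypothetical alternative to shuffle_dataset that never leaves array form.

    Assumes every data variable with `replicate_dim` also has `group_dim`.
    """
    rng = np.random.default_rng(seed)
    n_groups = ds.sizes[group_dim]
    n_reps = ds.sizes[replicate_dim]

    # perm[g, r] = index of the original replicate whose data ends up under
    # replicate position r for group g; one independent permutation per group
    perm = np.stack([rng.permutation(n_reps) for _ in range(n_groups)])

    new_vars = {}
    for name, da in ds.data_vars.items():
        if replicate_dim not in da.dims:
            new_vars[name] = da  # variables without replicates are kept as-is
            continue
        other_dims = [d for d in da.dims if d not in (group_dim, replicate_dim)]
        # Put (group, replicate) first so the permutation indexes axis 1
        da_t = da.transpose(group_dim, replicate_dim, *other_dims)
        # Trailing length-1 axes let the (group, replicate) indices broadcast
        idx = perm.reshape(perm.shape + (1,) * len(other_dims))
        shuffled = np.take_along_axis(da_t.values, idx, axis=1)
        # Same dims and coords as before; restore the original dimension order
        new_vars[name] = da_t.copy(data=shuffled).transpose(*da.dims)

    out = xr.Dataset(new_vars, coords=ds.coords, attrs=dict(ds.attrs))
    out.attrs["shuffle_seed"] = seed
    return out
```

Each data variable is copied once at its original size, instead of being expanded into long-format DataFrames and merged. The labels along `replicate` stay in their original order (A, B, C); what changes is which original replicate of the same treatment supplies the data under each label. Because `Sample` is permuted with the same indices, it records which original sample now sits under each label, e.g. after `shuffle_replicates(xr_abs, seed=0)`.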
