J. Sav
J. Sav

Reputation: 31

Reducing RAM usage when computing a very large tensor

I am working on a small algorithm that takes a 3D array of data points with shape (batch, num_points, 2), where the last axis holds (x, y). It computes a 3D tensor of shape (batch, res, res) containing, for each pixel, the distance to the nearest data point.

I have it working, but it needs to handle resolutions of up to 65536 (with 4 data points and a batch size of 3), and at that size I run into RAM problems.

I have implemented chunking: I split the large tensor into square chunks of a given size, compute each chunk, and write it to an HDF5 file (via h5py) to free up RAM. This works, but it is still very slow at a resolution of 65536. With a batch size of 3, I run out of RAM if I use a chunk size much larger than 4096, so the algorithm has to compute 256 individual chunks (a 16 × 16 grid).

Here is my code so far:

import torch
import numpy as np
import matplotlib.pyplot as plt
import h5py

def create_points(batch_size, num_points):
    """Return random 2-D points with shape ``(batch_size, num_points, 2)``.

    Every coordinate is drawn uniformly from [0, 1) with NumPy's global
    random generator, so results can be reproduced via ``np.random.seed``.
    """
    shape = (batch_size, num_points, 2)
    return np.random.rand(*shape)


def _squared_offsets(start, end, centers):
    """Return squared offsets between pixel coordinates ``start..end-1`` and ``centers``.

    ``centers`` has shape (B, P); the result has shape (B, P, end - start).
    """
    grid = torch.arange(start, end, dtype=torch.float32)
    return (grid - centers.unsqueeze(-1)) ** 2


def _nearest_distance_tile(dx2, dy2):
    """Return the distance to the nearest point for every pixel of one tile.

    ``dx2`` is (B, P, nx) and ``dy2`` is (B, P, ny), holding squared per-axis
    offsets. The result is a float32 tensor of shape (B, ny, nx), indexed
    ``[batch, y, x]``. Points are folded in one at a time with a running
    minimum, so peak working memory is two (B, ny, nx) buffers instead of a
    (B, ny * nx, P) distance matrix.
    """
    best = dy2[:, 0, :, None] + dx2[:, 0, None, :]
    candidate = torch.empty_like(best)
    for p in range(1, dx2.shape[1]):
        torch.add(dy2[:, p, :, None], dx2[:, p, None, :], out=candidate)
        torch.minimum(best, candidate, out=best)
    # sqrt is monotonic, so one sqrt after the minimum equals min over sqrt'ed distances.
    return best.sqrt_()


def min_dist(points, res, chunk_size=4096, filename='min_dist.h5'):
    """Write a per-pixel nearest-point distance map to an HDF5 file.

    For each batch element ``b`` and pixel ``(x, y)`` of a ``res x res`` grid,
    stores the Euclidean distance from ``(x, y)`` to the nearest point of
    ``points[b]``. Point coordinates are scaled by ``res - 1`` into pixel units.

    The grid is processed in tiles of at most ``chunk_size`` pixels per side.
    Only one tile is held in memory at a time (about two float32 buffers of
    ``B * chunk_size**2`` elements).

    Args:
        points: array-like of shape (B, P, 2) holding (x, y) pairs;
            presumably in [0, 1].
        res: side length of the output grid in pixels.
        chunk_size: maximum tile side used for computation and as the HDF5
            chunk shape.
        filename: path of the HDF5 file to create (overwritten if present).

    Returns:
        ``filename``. The file holds a float32 dataset ``'min_dist'`` of shape
        (B, res, res), indexed ``[batch, y, x]``.

    Raises:
        ValueError: if ``points`` is not (B, P, 2) with P >= 1, or if ``res``
            or ``chunk_size`` is not positive.
    """
    data_coords = torch.as_tensor(points, dtype=torch.float32) * (res - 1)
    if data_coords.ndim != 3 or data_coords.shape[-1] != 2:
        raise ValueError(
            f"points must have shape (B, P, 2), got {tuple(data_coords.shape)}"
        )
    B, P, _ = data_coords.shape
    if P == 0:
        raise ValueError("points must contain at least one point per batch")
    if res < 1 or chunk_size < 1:
        raise ValueError(
            f"res and chunk_size must be positive, got res={res}, "
            f"chunk_size={chunk_size}"
        )

    px = data_coords[..., 0]  # (B, P) x coordinates in pixel units
    py = data_coords[..., 1]  # (B, P) y coordinates in pixel units

    # h5py rejects chunk shapes larger than the dataset, so cap the tile at res.
    tile = min(chunk_size, res)
    num_chunks = (res + tile - 1) // tile

    with h5py.File(filename, 'w') as hf:
        min_dist_tensor_full = hf.create_dataset(
            'min_dist',
            (B, res, res),
            dtype='f',
            chunks=(B, tile, tile),
        )

        for i in range(num_chunks):
            x_start = i * tile
            # Edge tiles may be narrower than `tile` (and non-square); all
            # shapes below come from the actual ranges, never from `tile`.
            x_end = min(x_start + tile, res)
            # Shared by every tile in this column strip, so compute it once.
            dx2 = _squared_offsets(x_start, x_end, px)
            for j in range(num_chunks):
                y_start = j * tile
                y_end = min(y_start + tile, res)
                dy2 = _squared_offsets(y_start, y_end, py)

                # Explicit per-axis differences instead of torch.cdist: cdist's
                # matrix-multiplication mode (used for > 25 rows) loses
                # precision to cancellation when coordinates reach ~65535.
                tile_dist = _nearest_distance_tile(dx2, dy2)

                min_dist_tensor_full[:, y_start:y_end, x_start:x_end] = tile_dist.numpy()

                print(f"Processed chunk ({i}, {j}): ({x_start}, {x_end}), ({y_start}, {y_end})")

    return filename

points = create_points(3, 4)
file_dist_data = min_dist(points, 65536)

batch_size = points.shape[0]

# Largest image side handed to matplotlib. Reading f['min_dist'][i] whole would
# load a res x res float32 array (16 GiB at res=65536) into RAM, defeating the
# point of storing the result in HDF5; a strided read keeps the plot small.
MAX_PLOT_SIDE = 2048

with h5py.File(file_dist_data, 'r') as f:
    min_dist_tensor_full = f['min_dist']
    res = min_dist_tensor_full.shape[-1]
    # Smallest stride that brings each side down to at most MAX_PLOT_SIDE (>= 1).
    step = (res + MAX_PLOT_SIDE - 1) // MAX_PLOT_SIDE
    for i in range(batch_size):
        plt.figure()
        plt.imshow(
            min_dist_tensor_full[i, ::step, ::step],
            cmap='viridis',
            # Label axes in full-resolution pixel coordinates despite subsampling.
            extent=(0, res, res, 0),
        )
        plt.colorbar()
        plt.title(f'Batch {i + 1}')
        plt.show()

Upvotes: 0

Views: 25

Answers (0)

Related Questions