I am working on a small algorithm that takes as input a 3D list of data points (Batch, num_points, (x, y)) and calculates a 3D tensor of a given resolution (Batch, res, res) holding the minimum distance from each pixel to the nearest data point.
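For context, the unchunked version of this computation looks roughly like the sketch below (min_dist_naive is just an illustrative name). At res = 65536 its (B, res, res) float32 result alone is ~48 GiB, which is why I need chunking at all:

import torch

def min_dist_naive(points, res):
    # points: (B, P, 2) in [0, 1]; scale into pixel coordinates
    data_coords = torch.tensor(points, dtype=torch.float32) * (res - 1)
    B = data_coords.shape[0]
    grid_x, grid_y = torch.meshgrid(
        torch.arange(res, dtype=torch.float32),
        torch.arange(res, dtype=torch.float32),
        indexing='xy'
    )
    grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
    grid = grid.unsqueeze(0).expand(B, -1, -1)  # (B, res*res, 2)
    # (B, res*res, P) pairwise distances, reduced over the point axis
    return torch.cdist(grid, data_coords).min(dim=2).values.view(B, res, res)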
I have it working, but it needs to handle resolutions up to 65536 (with 4 data points and 3 batches), and at that size I run into RAM usage problems.
I have implemented chunking: I split the large tensor into chunks of a given size, compute each chunk, and store it in an H5PY file to free up RAM. This works, but it is still very slow at a resolution of 65536 (with 3 batches my RAM crashes if I use a chunk size much above 4096, which means the algorithm has to compute 256 individual chunks).
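A quick back-of-the-envelope check (assuming float32 everywhere) shows where the memory goes:

res, B, P, chunk = 65536, 3, 4, 4096
full_gib = B * res * res * 4 / 2**30           # full result: 48 GiB
# per chunk, cdist materializes a (B, chunk*chunk, P) distance tensor
cdist_gib = B * chunk * chunk * P * 4 / 2**30  # ~0.75 GiB per chunk
print(full_gib, cdist_gib)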
Here is my code so far:
import torch
import numpy as np
import matplotlib.pyplot as plt
import h5py

def create_points(batch_size, num_points):
    coords = np.random.rand(batch_size, num_points, 2)
    return coords

def min_dist(points, res, chunk_size=4096, filename='min_dist.h5'):
    # Scale the unit-square points into pixel coordinates
    data_coords = torch.tensor(points, dtype=torch.float32) * (res - 1)
    B, P, _ = data_coords.shape
    # Calculate number of chunks per axis
    num_chunks = (res + chunk_size - 1) // chunk_size
    with h5py.File(filename, 'w') as hf:
        min_dist_tensor_full = hf.create_dataset(
            'min_dist',
            (B, res, res),
            dtype='f',
            chunks=(B, chunk_size, chunk_size)
        )
        for i in range(num_chunks):
            for j in range(num_chunks):
                x_start = i * chunk_size
                x_end = min((i + 1) * chunk_size, res)
                y_start = j * chunk_size
                y_end = min((j + 1) * chunk_size, res)
                grid_x, grid_y = torch.meshgrid(
                    torch.arange(x_start, x_end, dtype=torch.float32),
                    torch.arange(y_start, y_end, dtype=torch.float32),
                    indexing='xy'
                )
                grid_coords = torch.stack(
                    [grid_x.flatten(), grid_y.flatten()], dim=-1
                )
                grid_coords_batch = grid_coords.unsqueeze(0).expand(B, -1, -1)
                # (B, n_pixels, P) distances; keep the nearest point per pixel
                dists = torch.cdist(grid_coords_batch, data_coords)
                min_dists, _ = dists.min(dim=2)
                # With indexing='xy' the grid has shape (y extent, x extent).
                # Use the edge sizes here instead of overwriting chunk_size,
                # which would corrupt the offsets of every later chunk.
                min_dist_chunk = min_dists.view(B, y_end - y_start, x_end - x_start)
                min_dist_tensor_full[:, y_start:y_end, x_start:x_end] = min_dist_chunk.numpy()
                print(f"Processed chunk ({i}, {j}): ({x_start}, {x_end}), ({y_start}, {y_end})")
    return filename
points = create_points(3, 4)
file_dist_data = min_dist(points, 65536)
batch_size = points.shape[0]

with h5py.File(file_dist_data, 'r') as f:
    min_dist_tensor_full = f['min_dist']
    for i in range(batch_size):
        plt.figure()
        # Read a strided, downsampled view for display; a full
        # (65536, 65536) float32 slice alone would be ~16 GiB
        plt.imshow(min_dist_tensor_full[i, ::256, ::256], cmap='viridis')
        plt.colorbar()
        plt.title(f'Batch {i + 1}')
        plt.show()