Batching wrapper for PyTorch Tensors
Ever written a function that processes a batch of PyTorch tensors, only to realize you still need to write a whole loop to run it over the full dataset? Me too! One fix is to rewrite your function to process a single sample and use torch.vmap. But your function might not be supported by torch.vmap (e.g. it relies on lots of conditionals or in-place operations).
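For reference, the vmap route looks roughly like this (a minimal sketch, assuming PyTorch 2.x where torch.vmap is available; the per-sample function here is made up for illustration):

import torch

# A function written for a single sample: x and y are (128,) vectors.
def per_sample(x, y):
    return x + y, x * y

# torch.vmap maps the per-sample function over the leading batch dimension.
batched = torch.vmap(per_sample)

x = torch.randn(10000, 128)
y = torch.randn(10000, 128)
sums, prods = batched(x, y)  # both have shape (10000, 128)

Note that this vectorizes the whole batch in one go; the chunked wrapper below is for the case where that is not feasible or vmap cannot handle your function.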
Another way is to write a wrapper that feeds your input to the function in chunks. I found myself writing such a wrapper so many times that I prompted ChatGPT for a generic one.
import torch
from torch.utils.data import DataLoader, TensorDataset

def chunked_function_wrapper(function, chunk_size, device='cuda', **kwargs):
    """
    Wraps a function to process inputs in chunks using TensorDataset and supports multiple tensors,
    allowing both positional and named arguments while handling None values.

    Args:
        function (callable): The function to apply to the inputs.
        chunk_size (int): The size of each chunk to process.
        device (str): The device to process tensors on (default: 'cuda').
        **kwargs: Additional keyword arguments for the function.

    Returns:
        callable: A wrapped function that processes inputs in chunks.
    """
    def wrapped_function(*tensor_args, **tensor_kwargs):
        # Filter out None values from tensor_kwargs
        tensor_kwargs = {key: val for key, val in tensor_kwargs.items() if val is not None}
        all_tensors = list(tensor_args) + list(tensor_kwargs.values())

        # Ensure all tensors have the same size along the first dimension
        if not all_tensors:
            raise ValueError("At least one tensor input is required.")
        num_samples = all_tensors[0].size(0)
        for tensor in all_tensors:
            if tensor.size(0) != num_samples:
                raise ValueError("All input tensors must have the same first dimension size.")

        # Create a TensorDataset and DataLoader (no shuffling, so output order matches input order)
        dataset = TensorDataset(*all_tensors)
        dataloader = DataLoader(dataset, batch_size=chunk_size)

        results = []
        for batch in dataloader:
            batch_args = batch[:len(tensor_args)]
            batch_kwargs = {key: batch[i + len(tensor_args)] for i, key in enumerate(tensor_kwargs.keys())}

            # Transfer the batch to the specified device
            batch_args = [tensor.to(device) for tensor in batch_args]
            batch_kwargs = {key: tensor.to(device) for key, tensor in batch_kwargs.items()}

            # Apply the function to the batch
            with torch.no_grad():
                result = function(*batch_args, **batch_kwargs, **kwargs)

            # Collect the results (move to CPU if needed)
            if isinstance(result, torch.Tensor):
                results.append(result.cpu())
            elif isinstance(result, (list, tuple)):
                if not results:
                    # Initialize lists for multiple outputs
                    results = [[] for _ in range(len(result))]
                for i, res in enumerate(result):
                    results[i].append(res.cpu())
            else:
                raise ValueError("The function must return a Tensor or a tuple/list of Tensors.")

        # Concatenate results along the first dimension
        if isinstance(results[0], list):
            return [torch.cat(res, dim=0) for res in results]
        else:
            return torch.cat(results, dim=0)

    return wrapped_function
# Example usage
def example_function(tensor1, tensor2, tensor3=None):
    if tensor3 is not None:
        return tensor1 + tensor2 + tensor3, tensor1 * tensor2 * tensor3  # Example: sum and product
    return tensor1 + tensor2, tensor1 * tensor2

# Wrap the function
chunk_size = 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
wrapped_example_function = chunked_function_wrapper(example_function, chunk_size, device=device)

# Input tensors
inputs1 = torch.randn(10000, 128)
inputs2 = torch.randn(10000, 128)
inputs3 = None  # Optional input

# Use the wrapped function with both positional and named arguments
outputs = wrapped_example_function(inputs1, inputs2, tensor3=inputs3)

# Results
output1, output2 = outputs
print(output1.shape)  # Should match inputs1/inputs2 shape: torch.Size([10000, 128])
print(output2.shape)  # Should match inputs1/inputs2 shape: torch.Size([10000, 128])
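The same wrapper also works for running a model over a dataset that is too large for a single forward pass. A quick sketch (the model, its dimensions, and the random data here are placeholders, not part of the original example):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Any callable that maps batched tensors to tensors works, including an nn.Module.
model = torch.nn.Linear(128, 10).to(device)
model.eval()

wrapped_forward = chunked_function_wrapper(model, chunk_size=1024, device=device)

features = torch.randn(100_000, 128)  # stays on the CPU
logits = wrapped_forward(features)    # processed in chunks of 1024, assembled on the CPU
print(logits.shape)                   # torch.Size([100000, 10])

Since the wrapper calls the function under torch.no_grad(), this is inference-only; drop that context manager if you need gradients.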