Have you ever written a function that already processes PyTorch tensors in batches, but didn’t want to write a whole loop just to run it over the full dataset (for example, because the full dataset doesn’t fit in GPU memory at once)?

Me too! One fix is to rewrite your function to process a single sample and vectorize it with torch.vmap. But your function might not be supported by torch.vmap (e.g., if it relies on data-dependent conditionals or in-place operations).

Another option is to write a wrapper function that feeds your input to the function in chunks. I found myself writing such a wrapper many times, so I prompted ChatGPT for a generic one.

import torch
from torch.utils.data import DataLoader, TensorDataset

def chunked_function_wrapper(function, chunk_size, device='cuda', **kwargs):
    """
    Wrap a batched function so it is applied chunk-by-chunk along the first dimension.

    The wrapped function accepts tensors positionally and/or by keyword. Keyword
    tensors whose value is ``None`` are dropped before the call, so optional
    tensor arguments can be passed as ``None``. Every remaining tensor is sliced
    into chunks of at most ``chunk_size`` rows along dim 0. Each chunk is copied
    to ``device`` and ``function`` is called on it under ``torch.no_grad()``.
    The per-chunk outputs are moved to the CPU and concatenated along dim 0.

    Args:
        function (callable): Batched function to apply. It must return a Tensor
            or a tuple/list of Tensors, with the same structure for every chunk.
        chunk_size (int): Maximum number of rows per chunk; must be a positive int.
        device (str or torch.device): Device each chunk is copied to before
            calling ``function`` (default: 'cuda').
        **kwargs: Extra (non-chunked) keyword arguments forwarded unchanged to
            every call of ``function``.

    Returns:
        callable: ``wrapped_function(*tensor_args, **tensor_kwargs)``. The return
        value depends on what ``function`` returns:

        - a Tensor: one CPU Tensor;
        - a tuple/list of Tensors: a list of CPU Tensors.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer. The wrapped
            function raises ValueError in these cases:

            - no tensor inputs are given;
            - the inputs disagree on their first dimension;
            - ``function`` returns something other than a Tensor or a
              tuple/list;
            - ``function`` returns a different output structure for
              different chunks.
    """
    # Same acceptance rule DataLoader's batch_size uses; bool is an int
    # subclass but never a meaningful chunk size.
    if isinstance(chunk_size, bool) or not isinstance(chunk_size, int) or chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}.")

    def wrapped_function(*tensor_args, **tensor_kwargs):
        # Optional keyword tensors passed as None are simply not forwarded.
        tensor_kwargs = {key: val for key, val in tensor_kwargs.items() if val is not None}
        all_tensors = list(tensor_args) + list(tensor_kwargs.values())

        if not all_tensors:
            raise ValueError("At least one tensor input is required.")

        # Ensure all tensors have the same size along the first dimension
        num_samples = all_tensors[0].size(0)
        for tensor in all_tensors:
            if tensor.size(0) != num_samples:
                raise ValueError("All input tensors must have the same first dimension size.")

        num_args = len(tensor_args)
        kwarg_names = list(tensor_kwargs)
        # Zero-length inputs still get one (empty) call. That call reveals the
        # output structure; otherwise there would be nothing to concatenate.
        starts = range(0, num_samples, chunk_size) or range(1)

        collected = None  # one list of CPU chunk outputs per function output
        returns_sequence = None
        for start in starts:
            # copy=True gives every call its own copy of the chunk, even when
            # the input already lives on `device`. As with the previous
            # DataLoader collation, in-place ops inside `function` therefore
            # never modify the caller's tensors.
            chunk = [t[start:start + chunk_size].to(device, copy=True) for t in all_tensors]
            chunk_args = chunk[:num_args]
            chunk_kwargs = dict(zip(kwarg_names, chunk[num_args:]))

            with torch.no_grad():
                result = function(*chunk_args, **chunk_kwargs, **kwargs)

            if isinstance(result, torch.Tensor):
                outputs, is_sequence = (result,), False
            elif isinstance(result, (list, tuple)):
                outputs, is_sequence = result, True
            else:
                raise ValueError("The function must return a Tensor or a tuple/list of Tensors.")

            if collected is None:
                collected = [[] for _ in outputs]
                returns_sequence = is_sequence
            elif is_sequence != returns_sequence or len(outputs) != len(collected):
                # Otherwise outputs would be silently truncated or misaligned.
                raise ValueError("The function must return the same output structure for every chunk.")

            for bucket, output in zip(collected, outputs):
                bucket.append(output.cpu())

        # Concatenate results along the first dimension
        concatenated = [torch.cat(bucket, dim=0) for bucket in collected]
        return concatenated if returns_sequence else concatenated[0]

    return wrapped_function

# Example usage
def example_function(tensor1, tensor2, tensor3=None):
    """Return the element-wise ``(sum, product)`` of the given tensors.

    ``tensor3`` is optional: when it is ``None`` only ``tensor1`` and
    ``tensor2`` are combined, otherwise it is folded into both results.
    """
    total = tensor1 + tensor2
    product = tensor1 * tensor2
    if tensor3 is None:
        return total, product
    return total + tensor3, product * tensor3

# Wrap the function. Fall back to the CPU when no CUDA device is available so
# the example also runs on CPU-only machines.
chunk_size = 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'
wrapped_example_function = chunked_function_wrapper(example_function, chunk_size, device=device)

# Input tensors (10000 samples of 128 features each)
inputs1 = torch.randn(10000, 128)
inputs2 = torch.randn(10000, 128)
inputs3 = None  # Optional input; None keyword tensors are dropped by the wrapper

# Use the wrapped function with both positional and named arguments
outputs = wrapped_example_function(inputs1, inputs2, tensor3=inputs3)

# Results: example_function returns a (sum, product) tuple, so the wrapper
# returns a list of two CPU tensors.
output1, output2 = outputs
print(output1.shape)  # torch.Size([10000, 128]), same as inputs1/inputs2
print(output2.shape)  # torch.Size([10000, 128]), same as inputs1/inputs2