Source code for shmpipeline.kernels.gpu.custom_operation

"""GPU custom operation kernel backed by a restricted Python expression."""

from __future__ import annotations

from typing import Any, Mapping

import torch

from shmpipeline.config import KernelConfig, SharedMemoryConfig
from shmpipeline.errors import ConfigValidationError
from shmpipeline.kernel import KernelContext
from shmpipeline.kernels.cpu._expression import compile_custom_operation
from shmpipeline.kernels.gpu.base import (
    GpuKernel,
    as_gpu_tensor,
    torch_dtype_from_numpy,
)


[docs] class CustomOperationGpuKernel(GpuKernel): """Fuse a restricted arithmetic expression into one GPU kernel.""" kind = "gpu.custom_operation" auxiliary_arity = None
[docs] @classmethod def validate_config( cls, config: KernelConfig, shared_memory: Mapping[str, SharedMemoryConfig], ) -> None: super().validate_config(config, shared_memory) if config.operation is None: raise ConfigValidationError( f"kernel {config.name!r} of kind {cls.kind!r} requires an 'operation' field" ) if "input" in config.auxiliary_by_alias: raise ConfigValidationError( f"kernel {config.name!r} cannot bind auxiliary alias 'input'" ) input_spec = shared_memory[config.input] output_spec = shared_memory[config.output] compile_custom_operation( expression=config.operation, input_shape=input_spec.shape, input_dtype=input_spec.dtype, auxiliary_specs={ binding.alias: ( shared_memory[binding.name].shape, shared_memory[binding.name].dtype, ) for binding in config.auxiliary }, output_shape=output_spec.shape, output_dtype=output_spec.dtype, kernel_name=config.name, )
def __init__(self, context: KernelContext) -> None: super().__init__(context) input_spec = context.trigger_input_spec output_spec = context.output_spec self.plan = compile_custom_operation( expression=context.config.operation or "", input_shape=input_spec.shape, input_dtype=input_spec.dtype, auxiliary_specs={ binding.alias: ( context.shared_memory[binding.name].shape, context.shared_memory[binding.name].dtype, ) for binding in context.config.auxiliary }, output_shape=output_spec.shape, output_dtype=output_spec.dtype, kernel_name=context.config.name, ) self.temporaries = tuple( torch.empty( shape, dtype=torch_dtype_from_numpy(dtype), device=self.device ) for shape, dtype in zip( self.plan.temp_shapes, self.plan.temp_dtypes ) )
[docs] def compute_into( self, trigger_input: Any, output: Any, auxiliary_inputs: Mapping[str, Any], ) -> None: values = {"input": as_gpu_tensor(trigger_input, device=self.device)} values.update( { name: as_gpu_tensor(value, device=self.device) for name, value in auxiliary_inputs.items() } ) for instruction in self.plan.instructions: destination = self._resolve_destination( instruction.destination, output ) operands = [ self._resolve_operand(operand, values, output) for operand in instruction.operands ] if instruction.operation == "copy": destination.copy_(operands[0]) elif instruction.operation == "neg": torch.neg(operands[0], out=destination) elif instruction.operation == "pos": destination.copy_(operands[0]) elif instruction.operation == "add": torch.add(operands[0], operands[1], out=destination) elif instruction.operation == "sub": torch.sub(operands[0], operands[1], out=destination) elif instruction.operation == "mul": torch.mul(operands[0], operands[1], out=destination) elif instruction.operation == "div": torch.div(operands[0], operands[1], out=destination) elif instruction.operation == "matmul": torch.matmul(operands[0], operands[1], out=destination) elif instruction.operation == "abs": torch.abs(operands[0], out=destination) elif instruction.operation == "minimum": torch.minimum(operands[0], operands[1], out=destination) elif instruction.operation == "maximum": torch.maximum(operands[0], operands[1], out=destination) elif instruction.operation == "clip": torch.clamp( operands[0], min=operands[1], max=operands[2], out=destination, ) else: raise RuntimeError( f"unknown operation {instruction.operation!r}" ) torch.cuda.synchronize(output.device)
def _resolve_destination( self, operand, output: torch.Tensor ) -> torch.Tensor: if operand.kind == "output": return output if operand.kind == "temp": return self.temporaries[operand.value] raise RuntimeError( f"invalid destination operand kind: {operand.kind!r}" ) def _resolve_operand(self, operand, values, output: torch.Tensor): if operand.kind == "value": return values[operand.value] if operand.kind == "temp": return self.temporaries[operand.value] if operand.kind == "output": return output if operand.kind == "const": return operand.value raise RuntimeError(f"invalid operand kind: {operand.kind!r}")