# Writing memory format aware operators
Memory format aware operators are operators that satisfy two requirements:
- they produce output in the same memory format as their inputs
- they use the most efficient kernel for each memory format

Say we want to add or modify an operator to support the `torch.channels_last` memory format.
```python
in_tensor = x.contiguous(memory_format=torch.channels_last)
out_tensor = torch.operator(in_tensor)
print(out_tensor.is_contiguous(memory_format=torch.channels_last))  # True
```
To do so, we need to modify the operator's C++ code. An old version of the operator might look similar to this:
```cpp
auto output_tensor = at::empty_like(input_tensor);
// .... standard kernel for contiguous or strided tensors
return output_tensor;
```
The preferred way of writing memory format aware operators is to use a `switch` statement. This approach makes it easy to extend memory format support in the future.
```cpp
// ...
auto memory_format = input_tensor.suggest_memory_format();
auto output_tensor = at::empty(output_shape, memory_format);
switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous: {
    // .... standard kernel for contiguous or strided tensors
    break;
  }
  default:
    TORCH_CHECK(
        false,
        "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
// ...
```
It is important to understand that `suggest_memory_format` is not the same as `input_tensor.is_contiguous(...)`; see the function's comments for details.
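A quick way to see why `is_contiguous(...)` alone is ambiguous (a minimal Python sketch; the tensor shape is only illustrative):

```python
import torch

# A tensor with a size-1 channel dimension satisfies the stride pattern of
# BOTH memory formats at once, so is_contiguous(...) alone cannot tell the
# operator which format its output should have.
x = torch.randn(2, 1, 4, 4)
print(x.is_contiguous())                                   # True
print(x.is_contiguous(memory_format=torch.channels_last))  # True as well
```

`suggest_memory_format` exists to resolve exactly this kind of ambiguity.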
More memory format handling is required when you are writing an `_out` operator implementation.
```python
in_tensor = x.contiguous(memory_format=torch.channels_last)
out_tensor = o.contiguous(memory_format=torch.contiguous_format)
torch.operator(in_tensor, out=out_tensor)
print(out_tensor.is_contiguous(memory_format=torch.contiguous_format))  # True
```
Keeping the memory format of the output is essential. However, some performant algorithms require the input and output formats to match. In this case, it is possible to use a `copy_` trick.
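The effect of the `copy_` trick can be sketched from Python (using `torch.relu` as a stand-in for the kernel; the operator choice is only illustrative):

```python
import torch

x = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
# The user asked for a contiguous output tensor.
out = torch.empty(2, 3, 4, 4).contiguous(memory_format=torch.contiguous_format)

# The kernel computes in its preferred (channels_last) format...
tmp = torch.relu(x)
# ...and copy_ writes the values back while preserving the output
# tensor's own memory format.
out.copy_(tmp)
print(out.is_contiguous(memory_format=torch.contiguous_format))  # True
```

This is what the C++ helper below automates: compute into a temporary tensor in the kernel's format, then copy into the user-provided output.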
```cpp
Tensor self_or_new_memory_format(Tensor& self, MemoryFormat memory_format) {
  if (self.is_contiguous(memory_format)) {
    return self;
  }
  return at::empty_like(self, self.options(), memory_format);
}
```
```cpp
// ...
auto memory_format = input_tensor.suggest_memory_format();
assert_no_internal_overlap(output);
if (output_shape != output.sizes()) {
  output.resize_(output_shape, memory_format);
}
auto temporary_output_tensor = self_or_new_memory_format(output, memory_format);
switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous: {
    // .... standard kernel
    break;
  }
  default:
    TORCH_CHECK(
        false,
        "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
if (!output.is_same(temporary_output_tensor)) {
  output.copy_(temporary_output_tensor);
}
// ...
```
In some cases, a performant algorithm exists for only one of the memory formats (e.g. only for channels last inputs); the same trick with a temporary tensor and `copy_` can be applied.
```cpp
// ...
auto memory_format = input_tensor.suggest_memory_format();
assert_no_internal_overlap(output);
if (output_shape != output.sizes()) {
  output.resize_(output_shape, memory_format);
}
auto temporary_output_tensor =
    self_or_new_memory_format(output, MemoryFormat::ChannelsLast);
auto input_cl_contiguous = input_tensor.contiguous(MemoryFormat::ChannelsLast);
// .... channels last kernel code
if (!output.is_same(temporary_output_tensor)) {
  output.copy_(temporary_output_tensor);
}
// ...
```
Alternatively, you can fail hard with an unsupported memory format message (this is the least preferred way, and we consider such operators incomplete).
```cpp
// ...
switch (memory_format) {
  case MemoryFormat::ChannelsLast: {
    auto input_cl_contiguous = input_tensor.contiguous(
        MemoryFormat::ChannelsLast); // if kernel requires memory dense tensor
    // .... kernel code
    break;
  }
  case MemoryFormat::Contiguous:
  default:
    TORCH_CHECK(
        false, "Unsupported memory format. Supports only ChannelsLast");
}
// ...
```
Please do not forget to cover all scenarios with unit tests. We have seen countless cases where a simple test saved hours of debugging.
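As a sketch, a memory format test for a unary operator might look like this (we use `torch.relu`, which is already memory format aware; the helper name is ours):

```python
import torch

def check_memory_format_preserved(op, shape=(2, 3, 8, 8)):
    """Check that `op` keeps channels_last inputs channels_last."""
    x = torch.randn(shape).contiguous(memory_format=torch.channels_last)
    y = op(x)
    assert y.is_contiguous(memory_format=torch.channels_last), (
        f"{op.__name__} lost the channels_last memory format")

check_memory_format_preserved(torch.relu)
```

A real test suite would also cover the contiguous path, the `_out` variants, and mixed input/output formats.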