torch.segment_reduce
- torch.segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) → Tensor
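Performs a segment reduction on the input tensor along the specified axis: data is split into consecutive segments along axis (described by lengths or offsets), and each segment is reduced to a single slice using the operation named by reduce.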
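- Parameters
data (Tensor) – The input tensor to reduce.
reduce (str) – The reduction to apply to each segment: 'sum', 'mean', 'max', 'min', or 'prod'.
- Keyword Arguments
lengths (Tensor, optional) – Length of each segment along axis. Default: None.
offsets (Tensor, optional) – Offset of each segment along axis. Default: None.
axis (int, optional) – The axis along which to perform the reduction. Default: 0.
unsafe (bool, optional) – If True, skip input validation. Default: False.
initial (Number, optional) – The initial value for the reduction operation. Default: None.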
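The lengths and offsets keyword arguments describe the same partition of the reduction axis in two forms. The plain-Python sketch below illustrates the relationship, assuming offsets follow the cumulative-sum convention with a leading 0 (so offsets has one more entry than lengths, and its last entry equals the size of the reduced axis). The helpers lengths_to_offsets and segment_bounds are hypothetical and not part of PyTorch.

```python
from itertools import accumulate


def lengths_to_offsets(lengths):
    """Convert per-segment lengths into boundary offsets (hypothetical helper).

    Assumed convention: segment i covers positions offsets[i] up to, but not
    including, offsets[i + 1], so there is one more offset than there are lengths.
    """
    # accumulate(..., initial=0) yields 0 followed by the running sums.
    return list(accumulate(lengths, initial=0))


def segment_bounds(offsets):
    """Pair consecutive offsets into half-open (start, stop) ranges."""
    return list(zip(offsets[:-1], offsets[1:]))


# The lengths from the example: a segment of 2 rows followed by a segment of 1 row.
offsets = lengths_to_offsets([2, 1])
print(offsets)                  # [0, 2, 3]
print(segment_bounds(offsets))  # [(0, 2), (2, 3)]
```

Under this assumed convention, a segment of length 0 simply produces two equal consecutive offsets.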
Example:
>>> data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32, device='cuda')
>>> lengths = torch.tensor([2, 1], device='cuda')
>>> torch.segment_reduce(data, 'max', lengths=lengths)
tensor([[ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], device='cuda:0')
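To make the semantics of the example concrete, here is a plain-Python reference sketch of a lengths-based segment reduction along axis 0. It is an illustration under stated assumptions, not PyTorch's implementation: segment_reduce_ref is a hypothetical helper that works on nested lists instead of tensors, handles only non-empty segments, and omits offsets, unsafe, and initial.

```python
from functools import reduce as fold

# Pairwise operators for reductions that can be folded; "mean" is handled separately.
_OPS = {
    "sum": lambda a, b: a + b,
    "prod": lambda a, b: a * b,
    "max": max,
    "min": min,
}


def segment_reduce_ref(rows, reduce, lengths):
    """Reduce consecutive groups of rows, column by column (hypothetical helper).

    rows: list of equal-length lists; the reduction runs along axis 0.
    reduce: one of "sum", "prod", "max", "min", "mean".
    lengths: number of rows in each segment.
    """
    # Simplified validation for this sketch only.
    if sum(lengths) != len(rows):
        raise ValueError("lengths must sum to the size of the reduced axis")
    if any(n <= 0 for n in lengths):
        raise ValueError("this sketch only handles non-empty segments")
    out, start = [], 0
    for n in lengths:
        segment = rows[start:start + n]
        start += n
        columns = zip(*segment)  # each item is one column of this segment
        if reduce == "mean":
            out.append([sum(col) / n for col in columns])
        else:
            op = _OPS[reduce]
            out.append([fold(op, col) for col in columns])
    return out


data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
print(segment_reduce_ref(data, "max", [2, 1]))
# [[5, 6, 7, 8], [9, 10, 11, 12]]
```

With the example's data and lengths, the 'max' call yields the same values as the CUDA example above, as Python ints rather than a float tensor: the first segment (rows 0 and 1) reduces to its column-wise maximum, and the second segment (row 2) passes through unchanged.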