torch.segment_reduce

torch.segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) → Tensor

Perform a segment reduction operation on the input tensor along the specified axis.

Parameters
  • data (Tensor) – The input tensor on which the segment reduction operation will be performed.

  • reduce (str) – The type of reduction operation. Supported values are sum, mean, max, min, prod.

Keyword Arguments
  • lengths (Tensor, optional) – Length of each segment. Default: None.

  • offsets (Tensor, optional) – Offset of each segment. Default: None.

  • axis (int, optional) – The axis along which to perform the reduction. Default: 0.

  • unsafe (bool, optional) – Skip validation if True. Default: False.

  • initial (Number, optional) – The initial value for the reduction operation. Default: None.
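
The parameter list does not spell out how lengths and offsets relate. The following is a minimal sketch in plain Python, assuming that offsets are the cumulative sums of lengths with a leading zero, so that segment i covers positions offsets[i] up to but not including offsets[i + 1]. The helper name lengths_to_offsets is hypothetical and not part of torch.

```python
from itertools import accumulate


def lengths_to_offsets(lengths):
    """Hypothetical helper: convert segment lengths to boundary offsets.

    Assumption (not stated in the reference text): offsets has one more
    entry than lengths, starting at 0 and ending at sum(lengths).
    """
    return [0] + list(accumulate(lengths))


offsets = lengths_to_offsets([2, 1])
print(offsets)  # [0, 2, 3]
```

Under this assumption, lengths=[2, 1] and offsets=[0, 2, 3] would describe the same two segments.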

Example:

>>> data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32, device='cuda')
>>> lengths = torch.tensor([2, 1], device='cuda')
>>> torch.segment_reduce(data, 'max', lengths=lengths)
tensor([[ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], device='cuda:0')
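
To make the result above easier to follow, here is a plain-Python reference sketch of a segment reduction along axis 0 with lengths. It is only an illustration, not how torch implements the operation: the function name segment_reduce_rows and the REDUCERS table are hypothetical, and the sketch ignores indices, offsets, unsafe, initial, and empty segments.

```python
import math

# Hypothetical lookup table mapping the documented `reduce` names
# to plain-Python reducers over one column of a segment.
REDUCERS = {
    "sum": sum,
    "mean": lambda xs: sum(xs) / len(xs),
    "max": max,
    "min": min,
    "prod": math.prod,
}


def segment_reduce_rows(rows, reduce, lengths):
    """Reduce consecutive groups of rows, column by column.

    rows    : list of equal-length lists (a 2-D input, axis 0 = rows)
    reduce  : one of the keys in REDUCERS
    lengths : number of rows in each consecutive segment
    """
    # This sketch always validates; it has no counterpart to `unsafe`.
    if sum(lengths) != len(rows):
        raise ValueError("lengths must sum to the number of rows")
    out, start = [], 0
    for n in lengths:
        segment = rows[start:start + n]
        # zip(*segment) yields one tuple per column of the segment.
        out.append([REDUCERS[reduce](col) for col in zip(*segment)])
        start += n
    return out


data = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
result = segment_reduce_rows(data, "max", [2, 1])
print(result)  # [[5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]
```

The first segment spans rows 0 and 1, so its column-wise maximum is row 1; the second segment contains only row 2, which is returned unchanged. This matches the tensor shown in the example.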