@@ -87,53 +87,75 @@ def test_basic_unet_gradient_flow(sample_input_3d):
8787
8888
8989def test_volshape_to_ndgrid_ij_indexing_2d ():
90- """Test volshape_to_ndgrid produces grids with ij indexing."""
90+ """
91+ Test volshape_to_ndgrid produces grids with ij indexing.
92+
93+ With (ndim, *spatial) format, grid[0] is the first coordinate channel (row index),
94+ and grid[1] is the second coordinate channel (col index).
95+ """
9196shape = (64 ,64 )
92- grid = ne .volshape_to_ndgrid (shape ,indexing = "ij" ,stack = True ). unsqueeze ( 0 )
97+ grid = ne .volshape_to_ndgrid (shape ,indexing = "ij" ,stack = True )
9398
94- assert grid .shape == (1 ,* shape , 2 )
99+ assert grid .shape == (2 ,* shape )
95100
96- # ij indexing: first coord varies with first index
97- assert grid [0 ,0 ,0 , 0 ]< grid [0 ,1 ,0 , 0 ]# first coord increases with i
98- assert grid [0 ,0 ,0 , 1 ]== grid [0 ,1 , 0 ,1 ]#second coord constant alongi
99- assert grid [0 ,0 ,0 , 0 ]== grid [0 , 0 , 1 ,0 ]#first coord constant alongj
100- assert grid [0 ,0 ,0 , 1 ]< grid [0 ,0 ,1 , 1 ]# second coord increases with j
101+ # ij indexing: first coord varies with firstspatial index (row)
102+ assert grid [0 ,0 ,0 ]< grid [0 ,1 ,0 ]# first coord increases with i (row)
103+ assert grid [0 ,0 ,0 ]== grid [0 ,0 ,1 ]#first coord constant alongj (col)
104+ assert grid [1 ,0 ,0 ]== grid [1 , 1 ,0 ]#second coord constant alongi (row)
105+ assert grid [1 ,0 ,0 ]< grid [1 ,0 ,1 ]# second coord increases with j (col)
101106
102107
103108def test_volshape_to_ndgrid_xy_indexing_2d ():
104- """Test volshape_to_ndgrid produces grids with xy indexing."""
109+ """
110+ Test volshape_to_ndgrid produces grids with xy indexing.
111+
112+ With (ndim, *spatial) format and xy indexing, grid[0] is x (varies with col),
113+ and grid[1] is y (varies with row).
114+ """
105115shape = (64 ,64 )
106- grid = ne .volshape_to_ndgrid (shape ,indexing = "xy" ,stack = True ).unsqueeze (0 )
116+ grid = ne .volshape_to_ndgrid (shape ,indexing = "xy" ,stack = True )
117+
118+ assert grid .shape == (2 ,* shape )
107119
108- # xy indexing: first coord (x) varies with second index (columns)
109- assert grid [0 ,0 ,0 , 0 ]== grid [0 ,1 , 0 ,0 ]# x constant along rows (i)
110- assert grid [0 ,0 ,0 , 0 ]< grid [0 ,0 ,1 , 0 ]# x increases along columns (j)
111- assert grid [0 ,0 ,0 , 1 ]< grid [0 ,1 ,0 , 1 ]# y increases along rows (i)
112- assert grid [0 ,0 ,0 , 1 ]== grid [0 ,0 , 1 ,1 ]# y constant along columns (j)
120+ # xy indexing: first coord (x) varies with secondspatial index (columns)
121+ assert grid [0 ,0 ,0 ]== grid [0 ,1 ,0 ]# x constant along rows (i)
122+ assert grid [0 ,0 ,0 ]< grid [0 ,0 ,1 ]# x increases along columns (j)
123+ assert grid [1 ,0 ,0 ]< grid [1 ,1 ,0 ]# y increases along rows (i)
124+ assert grid [1 ,0 ,0 ]== grid [1 ,0 ,1 ]# y constant along columns (j)
113125
114126
115127def test_volshape_to_ndgrid_ij_indexing_3d ():
116- """Test volshape_to_ndgrid produces grids with ij indexing for 3D."""
128+ """
129+ Test volshape_to_ndgrid produces grids with ij indexing for 3D.
130+
131+ With (ndim, *spatial) format, grid[d] is the coordinate channel for dimension d.
132+ """
117133shape = (32 ,32 ,32 )
118- grid = ne .volshape_to_ndgrid (shape ,indexing = "ij" ,stack = True ). unsqueeze ( 0 )
134+ grid = ne .volshape_to_ndgrid (shape ,indexing = "ij" ,stack = True )
119135
120- assert grid .shape == (1 ,* shape , 3 )
136+ assert grid .shape == (3 ,* shape )
121137
122138# ij indexing: coords align with indices
123- assert grid [0 ,0 ,0 ,0 , 0 ]< grid [0 ,1 ,0 ,0 , 0 ]
124- assert grid [0 ,0 ,0 ,0 , 1 ]< grid [0 ,0 ,1 ,0 , 1 ]
125- assert grid [0 ,0 ,0 ,0 , 2 ]< grid [0 ,0 ,0 ,1 , 2 ]
139+ assert grid [0 ,0 ,0 ,0 ]< grid [0 ,1 ,0 ,0 ] # coord 0 increases with dim 0
140+ assert grid [1 ,0 ,0 ,0 ]< grid [1 ,0 ,1 ,0 ] # coord 1 increases with dim 1
141+ assert grid [2 ,0 ,0 ,0 ]< grid [2 ,0 ,0 ,1 ] # coord 2 increases with dim 2
126142
127143
128144def test_volshape_to_ndgrid_xy_indexing_3d ():
129- """Test volshape_to_ndgrid produces grids with xy indexing for 3D."""
145+ """
146+ Test volshape_to_ndgrid produces grids with xy indexing for 3D.
147+
148+ With (ndim, *spatial) format and xy indexing, coordinates are reordered.
149+ """
130150shape = (32 ,32 ,32 )
131- grid = ne .volshape_to_ndgrid (shape ,indexing = "xy" ,stack = True ).unsqueeze (0 )
151+ grid = ne .volshape_to_ndgrid (shape ,indexing = "xy" ,stack = True )
152+
153+ assert grid .shape == (3 ,* shape )
132154
133- # xy indexing:x varies with j, y varies with i
134- assert grid [0 ,0 ,0 ,0 , 0 ]== grid [0 ,1 ,0 ,0 , 0 ]# x constant alongi
135- assert grid [0 ,0 ,0 ,0 , 0 ]< grid [0 ,0 ,1 ,0 , 0 ]# x increases withj
136- assert grid [0 ,0 ,0 ,0 , 1 ]< grid [0 ,1 ,0 ,0 , 1 ]# y increases withi
155+ # xy indexing:coords are reordered relative to spatial dims
156+ assert grid [0 ,0 ,0 ,0 ]== grid [0 ,1 ,0 ,0 ]# x constant alongdim 0
157+ assert grid [0 ,0 ,0 ,0 ]< grid [0 ,0 ,1 ,0 ]# x increases withdim 1
158+ assert grid [1 ,0 ,0 ,0 ]< grid [1 ,1 ,0 ,0 ]# y increases withdim 0
137159
138160
139161def test_conv_block_forward_2d (shape_2d ):