Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2a9ba60

Browse files
committed
FEAT: Add method squeeze_into
This method can squeeze into a particular dimensionality.Squeezing means removing axes of length 1. When squeezing to aparticular dimensionality, we may have to still pad out the shape withextra 1-shape axes to fill the dimensionality.
1 parentf31add8 commit2a9ba60

File tree

1 file changed

+134
-9
lines changed

1 file changed

+134
-9
lines changed

‎src/dimension/mod.rs

Lines changed: 134 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -785,37 +785,77 @@ where
785785
}
786786

787787
/// Remove axes with length one, except never removing the last axis.
788+
///
789+
/// This function is a no-op for const dim.
788790
pub(crate)fnsqueeze<D>(dim:&mutD,strides:&mutD)
789791
where
790792
D:Dimension,
791793
{
792794
ifletSome(_) =D::NDIM{
793795
return;
794796
}
797+
798+
// infallible for dyn dim
799+
let(d, s) =squeeze_into(dim, strides).unwrap();
800+
*dim = d;
801+
*strides = s;
802+
}
803+
804+
/// Remove axes with length one, except never removing the last axis.
805+
///
806+
/// Return an error if there are more non-unitary dimensions than can be stored
807+
/// in `E`. Infallible for dyn dim.
808+
///
809+
/// Squeeze does not shrink dyn dim down to smaller than 1D, but if the input is
810+
/// dynamic 0D, the output can be too.
811+
///
812+
/// For const dim, this may instead pad the dimensionality with ones if it needs
813+
/// to grow to fill the target dimensionality; the dimension is padded in the
814+
/// start.
815+
pub(crate)fnsqueeze_into<D,E>(dim:&D,strides:&D) ->Result<(E,E),ShapeError>
816+
where
817+
D:Dimension,
818+
E:Dimension,
819+
{
795820
debug_assert_eq!(dim.ndim(), strides.ndim());
796821

797822
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
798823
letmut ndim_new =0;
799824
for&din dim.slice(){
800825
if d !=1{ ndim_new +=1;}
801826
}
802-
ndim_new =Ord::max(1, ndim_new);
803-
letmut new_dim =D::zeros(ndim_new);
804-
letmut new_strides =D::zeros(ndim_new);
827+
letmut fill_ones =0;
828+
ifletSome(e_ndim) =E::NDIM{
829+
if e_ndim < ndim_new{
830+
returnErr(ShapeError::from_kind(ErrorKind::IncompatibleShape));
831+
}
832+
fill_ones = e_ndim - ndim_new;
833+
ndim_new = e_ndim;
834+
}else{
835+
// dynamic-dimensional
836+
// use minimum one dimension unless input has less than one dim
837+
if dim.ndim() >0 && ndim_new ==0{
838+
ndim_new =1;
839+
fill_ones =1;
840+
}
841+
}
842+
843+
letmut new_dim =E::zeros(ndim_new);
844+
letmut new_strides =E::zeros(ndim_new);
805845
letmut i =0;
846+
while i < fill_ones{
847+
new_dim[i] =1;
848+
new_strides[i] =1;
849+
i +=1;
850+
}
806851
for(&d,&s)inizip!(dim.slice(), strides.slice()){
807852
if d !=1{
808853
new_dim[i] = d;
809854
new_strides[i] = s;
810855
i +=1;
811856
}
812857
}
813-
if i ==0{
814-
new_dim[i] =1;
815-
new_strides[i] =1;
816-
}
817-
*dim = new_dim;
818-
*strides = new_strides;
858+
Ok((new_dim, new_strides))
819859
}
820860

821861

@@ -1220,6 +1260,91 @@ mod test {
12201260
assert_eq!(s, sans);
12211261
}
12221262

1263+
#[test]
1264+
#[cfg(feature ="std")]
1265+
fntest_squeeze_into(){
1266+
usesuper::squeeze_into;
1267+
1268+
let dyndim =Dim::<&[usize]>;
1269+
1270+
// squeeze to ixdyn
1271+
let d =dyndim(&[1,2,1,1,3,1]);
1272+
let s =dyndim(&[!0, !0, !0,9,10, !0]);
1273+
let dans =dyndim(&[2,3]);
1274+
let sans =dyndim(&[!0,10]);
1275+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1276+
assert_eq!(d2, dans);
1277+
assert_eq!(s2, sans);
1278+
1279+
// squeeze to ixdyn does not go below 1D
1280+
let d =dyndim(&[1,1]);
1281+
let s =dyndim(&[3,4]);
1282+
let dans =dyndim(&[1]);
1283+
let sans =dyndim(&[1]);
1284+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1285+
assert_eq!(d2, dans);
1286+
assert_eq!(s2, sans);
1287+
1288+
let d =Dim([1,1]);
1289+
let s =Dim([3,4]);
1290+
let dans =Dim([1]);
1291+
let sans =Dim([1]);
1292+
let(d2, s2) =squeeze_into::<_,Ix1>(&d,&s).unwrap();
1293+
assert_eq!(d2, dans);
1294+
assert_eq!(s2, sans);
1295+
1296+
// squeeze to zero-dim
1297+
let(d2, s2) =squeeze_into::<_,Ix0>(&d,&s).unwrap();
1298+
assert_eq!(d2,Ix0());
1299+
assert_eq!(s2,Ix0());
1300+
1301+
let d =Dim([0,1,3,4]);
1302+
let s =Dim([2,3,4,5]);
1303+
let dans =Dim([0,3,4]);
1304+
let sans =Dim([2,4,5]);
1305+
let(d2, s2) =squeeze_into::<_,Ix3>(&d,&s).unwrap();
1306+
assert_eq!(d2, dans);
1307+
assert_eq!(s2, sans);
1308+
1309+
// Pad with ones
1310+
let d =Dim([0,1,3,1]);
1311+
let s =Dim([2,3,4,5]);
1312+
let dans =Dim([1,0,3]);
1313+
let sans =Dim([1,2,4]);
1314+
let(d2, s2) =squeeze_into::<_,Ix3>(&d,&s).unwrap();
1315+
assert_eq!(d2, dans);
1316+
assert_eq!(s2, sans);
1317+
1318+
// Try something that doesn't fit
1319+
let d =Dim([0,1,3,1]);
1320+
let s =Dim([2,3,4,5]);
1321+
let res =squeeze_into::<_,Ix1>(&d,&s);
1322+
assert!(res.is_err());
1323+
let res =squeeze_into::<_,Ix0>(&d,&s);
1324+
assert!(res.is_err());
1325+
1326+
// Squeeze 0d to 0d
1327+
let d =Dim([]);
1328+
let s =Dim([]);
1329+
let res =squeeze_into::<_,Ix0>(&d,&s);
1330+
assert!(res.is_ok());
1331+
// grow 0d to 2d
1332+
let dans =Dim([1,1]);
1333+
let sans =Dim([1,1]);
1334+
let(d2, s2) =squeeze_into::<_,Ix2>(&d,&s).unwrap();
1335+
assert_eq!(d2, dans);
1336+
assert_eq!(s2, sans);
1337+
1338+
// Squeeze 0d to 0d dynamic
1339+
let d =dyndim(&[]);
1340+
let s =dyndim(&[]);
1341+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1342+
let dans = d;
1343+
let sans = s;
1344+
assert_eq!(d2, dans);
1345+
assert_eq!(s2, sans);
1346+
}
1347+
12231348
#[test]
12241349
fntest_merge_axes_from_the_back(){
12251350
let dyndim =Dim::<&[usize]>;

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp