Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit08068a8

Browse files
committed
FEAT: Add method squeeze_into
This method can squeeze into a particular dimensionality.Squeezing means removing axes of length 1. When squeezing to aparticular dimensionality, we may have to still pad out the shape withextra 1-shape axes to fill the dimensionality.
1 parent9cea3c7 commit08068a8

File tree

1 file changed

+132
-0
lines changed

1 file changed

+132
-0
lines changed

‎src/dimension/mod.rs

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,53 @@ where
762762
*strides = new_strides;
763763
}
764764

765+
/// Remove axes with length one, except never removing the last axis.
766+
pub(crate)fnsqueeze_into<D,E>(dim:&D,strides:&D) ->Result<(E,E),ShapeError>
767+
where
768+
D:Dimension,
769+
E:Dimension,
770+
{
771+
debug_assert_eq!(dim.ndim(), strides.ndim());
772+
773+
// Count axes with dim == 1; we keep axes with d == 0 or d > 1
774+
letmut ndim_new =0;
775+
for&din dim.slice(){
776+
if d !=1{ ndim_new +=1;}
777+
}
778+
letmut fill_ones =0;
779+
ifletSome(e_ndim) =E::NDIM{
780+
if e_ndim < ndim_new{
781+
returnErr(ShapeError::from_kind(ErrorKind::IncompatibleShape));
782+
}
783+
fill_ones = e_ndim - ndim_new;
784+
ndim_new = e_ndim;
785+
}else{
786+
// dynamic-dimensional
787+
// use minimum one dimension unless input has less than one dim
788+
if dim.ndim() >0 && ndim_new ==0{
789+
ndim_new =1;
790+
fill_ones =1;
791+
}
792+
}
793+
794+
letmut new_dim =E::zeros(ndim_new);
795+
letmut new_strides =E::zeros(ndim_new);
796+
letmut i =0;
797+
while i < fill_ones{
798+
new_dim[i] =1;
799+
new_strides[i] =1;
800+
i +=1;
801+
}
802+
for(&d,&s)inizip!(dim.slice(), strides.slice()){
803+
if d !=1{
804+
new_dim[i] = d;
805+
new_strides[i] = s;
806+
i +=1;
807+
}
808+
}
809+
Ok((new_dim, new_strides))
810+
}
811+
765812

766813
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
767814
/// stride
@@ -1148,6 +1195,91 @@ mod test {
11481195
assert_eq!(s, sans);
11491196
}
11501197

1198+
#[test]
1199+
#[cfg(feature ="std")]
1200+
fntest_squeeze_into(){
1201+
usesuper::squeeze_into;
1202+
1203+
let dyndim =Dim::<&[usize]>;
1204+
1205+
// squeeze to ixdyn
1206+
let d =dyndim(&[1,2,1,1,3,1]);
1207+
let s =dyndim(&[!0, !0, !0,9,10, !0]);
1208+
let dans =dyndim(&[2,3]);
1209+
let sans =dyndim(&[!0,10]);
1210+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1211+
assert_eq!(d2, dans);
1212+
assert_eq!(s2, sans);
1213+
1214+
// squeeze to ixdyn does not go below 1D
1215+
let d =dyndim(&[1,1]);
1216+
let s =dyndim(&[3,4]);
1217+
let dans =dyndim(&[1]);
1218+
let sans =dyndim(&[1]);
1219+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1220+
assert_eq!(d2, dans);
1221+
assert_eq!(s2, sans);
1222+
1223+
let d =Dim([1,1]);
1224+
let s =Dim([3,4]);
1225+
let dans =Dim([1]);
1226+
let sans =Dim([1]);
1227+
let(d2, s2) =squeeze_into::<_,Ix1>(&d,&s).unwrap();
1228+
assert_eq!(d2, dans);
1229+
assert_eq!(s2, sans);
1230+
1231+
// squeeze to zero-dim
1232+
let(d2, s2) =squeeze_into::<_,Ix0>(&d,&s).unwrap();
1233+
assert_eq!(d2,Ix0());
1234+
assert_eq!(s2,Ix0());
1235+
1236+
let d =Dim([0,1,3,4]);
1237+
let s =Dim([2,3,4,5]);
1238+
let dans =Dim([0,3,4]);
1239+
let sans =Dim([2,4,5]);
1240+
let(d2, s2) =squeeze_into::<_,Ix3>(&d,&s).unwrap();
1241+
assert_eq!(d2, dans);
1242+
assert_eq!(s2, sans);
1243+
1244+
// Pad with ones
1245+
let d =Dim([0,1,3,1]);
1246+
let s =Dim([2,3,4,5]);
1247+
let dans =Dim([1,0,3]);
1248+
let sans =Dim([1,2,4]);
1249+
let(d2, s2) =squeeze_into::<_,Ix3>(&d,&s).unwrap();
1250+
assert_eq!(d2, dans);
1251+
assert_eq!(s2, sans);
1252+
1253+
// Try something that doesn't fit
1254+
let d =Dim([0,1,3,1]);
1255+
let s =Dim([2,3,4,5]);
1256+
let res =squeeze_into::<_,Ix1>(&d,&s);
1257+
assert!(res.is_err());
1258+
let res =squeeze_into::<_,Ix0>(&d,&s);
1259+
assert!(res.is_err());
1260+
1261+
// Squeeze 0d to 0d
1262+
let d =Dim([]);
1263+
let s =Dim([]);
1264+
let res =squeeze_into::<_,Ix0>(&d,&s);
1265+
assert!(res.is_ok());
1266+
// grow 0d to 2d
1267+
let dans =Dim([1,1]);
1268+
let sans =Dim([1,1]);
1269+
let(d2, s2) =squeeze_into::<_,Ix2>(&d,&s).unwrap();
1270+
assert_eq!(d2, dans);
1271+
assert_eq!(s2, sans);
1272+
1273+
// Squeeze 0d to 0d dynamic
1274+
let d =dyndim(&[]);
1275+
let s =dyndim(&[]);
1276+
let(d2, s2) =squeeze_into::<_,IxDyn>(&d,&s).unwrap();
1277+
let dans = d;
1278+
let sans = s;
1279+
assert_eq!(d2, dans);
1280+
assert_eq!(s2, sans);
1281+
}
1282+
11511283
#[test]
11521284
fntest_merge_axes_from_the_back(){
11531285
let dyndim =Dim::<&[usize]>;

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp