Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita843a93

Browse files
committed
Fix logic to sequence based indexing in row/col/slice functions
Prior to this change, the following functions were not checking for additionaldimensional data beyond the dimension concerned with the particularfunction.- row- col- slice- rows- cols- slices- set_row- set_col- slice- set_rows- set_cols- set_slicesSimilar logic was missing in one particular matching pattern of viewmacro which is also fixed in this change.Few additional unit tests are added in macro and index module checkingfor the pitfalls this change addresses
1 parent6be6c54 commita843a93

File tree

2 files changed

+148
-32
lines changed

2 files changed

+148
-32
lines changed

‎src/core/index.rs

Lines changed: 93 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,11 @@ pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
293293
where
294294
T:HasAfEnum,
295295
{
296-
index(
297-
input,
298-
&[
299-
Seq::new(row_numasf64, row_numasf64,1.0),
300-
Seq::default(),
301-
],
302-
)
296+
letmut seqs =vec![Seq::new(row_numasf64, row_numasf64,1.0)];
297+
for _din1..input.dims().ndims(){
298+
seqs.push(Seq::default());
299+
}
300+
index(input,&seqs)
303301
}
304302

305303
/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
@@ -308,7 +306,7 @@ where
308306
T:HasAfEnum,
309307
{
310308
letmut seqs =vec![Seq::new(row_numasf64, row_numasf64,1.0)];
311-
ifinout.dims().ndims() >1{
309+
for _din1..inout.dims().ndims(){
312310
seqs.push(Seq::default());
313311
}
314312
assign_seq(inout,&seqs, new_row)
@@ -320,10 +318,11 @@ where
320318
T:HasAfEnum,
321319
{
322320
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
323-
index(
324-
input,
325-
&[Seq::new(firstasf64, lastasf64, step),Seq::default()],
326-
)
321+
letmut seqs =vec![Seq::new(firstasf64, lastasf64, step)];
322+
for _din1..input.dims().ndims(){
323+
seqs.push(Seq::default());
324+
}
325+
index(input,&seqs)
327326
}
328327

329328
/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
@@ -332,7 +331,10 @@ where
332331
T:HasAfEnum,
333332
{
334333
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
335-
let seqs =[Seq::new(firstasf64, lastasf64, step),Seq::default()];
334+
letmut seqs =vec![Seq::new(firstasf64, lastasf64, step)];
335+
for _din1..inout.dims().ndims(){
336+
seqs.push(Seq::default());
337+
}
336338
assign_seq(inout,&seqs, new_rows)
337339
}
338340

@@ -352,24 +354,28 @@ pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
352354
where
353355
T:HasAfEnum,
354356
{
355-
index(
356-
input,
357-
&[
358-
Seq::default(),
359-
Seq::new(col_numasf64, col_numasf64,1.0),
360-
],
361-
)
357+
letmut seqs =vec![
358+
Seq::default(),
359+
Seq::new(col_numasf64, col_numasf64,1.0),
360+
];
361+
for _din2..input.dims().ndims(){
362+
seqs.push(Seq::default());
363+
}
364+
index(input,&seqs)
362365
}
363366

364367
/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
365368
pubfnset_col<T>(inout:&mutArray<T>,new_col:&Array<T>,col_num:i64)
366369
where
367370
T:HasAfEnum,
368371
{
369-
let seqs =[
372+
letmutseqs =vec![
370373
Seq::default(),
371374
Seq::new(col_numasf64, col_numasf64,1.0),
372375
];
376+
for _din2..inout.dims().ndims(){
377+
seqs.push(Seq::default());
378+
}
373379
assign_seq(inout,&seqs, new_col)
374380
}
375381

@@ -379,10 +385,11 @@ where
379385
T:HasAfEnum,
380386
{
381387
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
382-
index(
383-
input,
384-
&[Seq::default(),Seq::new(firstasf64, lastasf64, step)],
385-
)
388+
letmut seqs =vec![Seq::default(),Seq::new(firstasf64, lastasf64, step)];
389+
for _din2..input.dims().ndims(){
390+
seqs.push(Seq::default());
391+
}
392+
index(input,&seqs)
386393
}
387394

388395
/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
@@ -391,7 +398,10 @@ where
391398
T:HasAfEnum,
392399
{
393400
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
394-
let seqs =[Seq::default(),Seq::new(firstasf64, lastasf64, step)];
401+
letmut seqs =vec![Seq::default(),Seq::new(firstasf64, lastasf64, step)];
402+
for _din2..inout.dims().ndims(){
403+
seqs.push(Seq::default());
404+
}
395405
assign_seq(inout,&seqs, new_cols)
396406
}
397407

@@ -402,11 +412,14 @@ pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
402412
where
403413
T:HasAfEnum,
404414
{
405-
let seqs =[
415+
letmutseqs =vec![
406416
Seq::default(),
407417
Seq::default(),
408418
Seq::new(slice_numasf64, slice_numasf64,1.0),
409419
];
420+
for _din3..input.dims().ndims(){
421+
seqs.push(Seq::default());
422+
}
410423
index(input,&seqs)
411424
}
412425

@@ -417,11 +430,14 @@ pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
417430
where
418431
T:HasAfEnum,
419432
{
420-
let seqs =[
433+
letmutseqs =vec![
421434
Seq::default(),
422435
Seq::default(),
423436
Seq::new(slice_numasf64, slice_numasf64,1.0),
424437
];
438+
for _din3..inout.dims().ndims(){
439+
seqs.push(Seq::default());
440+
}
425441
assign_seq(inout,&seqs, new_slice)
426442
}
427443

@@ -433,11 +449,14 @@ where
433449
T:HasAfEnum,
434450
{
435451
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
436-
let seqs =[
452+
letmutseqs =vec![
437453
Seq::default(),
438454
Seq::default(),
439455
Seq::new(firstasf64, lastasf64, step),
440456
];
457+
for _din3..input.dims().ndims(){
458+
seqs.push(Seq::default());
459+
}
441460
index(input,&seqs)
442461
}
443462

@@ -449,11 +468,14 @@ where
449468
T:HasAfEnum,
450469
{
451470
let step:f64 =if first > last && last <0{ -1.0}else{1.0};
452-
let seqs =[
471+
letmutseqs =vec![
453472
Seq::default(),
454473
Seq::default(),
455474
Seq::new(firstasf64, lastasf64, step),
456475
];
476+
for _din3..inout.dims().ndims(){
477+
seqs.push(Seq::default());
478+
}
457479
assign_seq(inout,&seqs, new_slices)
458480
}
459481

@@ -655,7 +677,7 @@ mod tests {
655677
usesuper::super::device::set_device;
656678
usesuper::super::dim4::Dim4;
657679
usesuper::super::index::{assign_gen, assign_seq, col, index, index_gen, row,Indexer};
658-
usesuper::super::index::{cols, rows};
680+
usesuper::super::index::{cols, rows, set_row, set_rows};
659681
usesuper::super::random::randu;
660682
usesuper::super::seq::Seq;
661683

@@ -868,4 +890,44 @@ mod tests {
868890
// 0.9675 0.3712 0.7896
869891
// ANCHOR_END: get_rows
870892
}
893+
894+
#[test]
895+
fnchange_row(){
896+
set_device(0);
897+
898+
let v0:Vec<bool> =vec![true,true,true,true,true,true];
899+
letmut a0 =Array::new(&v0,dim4!(v0.len()asu64));
900+
901+
let v1:Vec<bool> =vec![false];
902+
let a1 =Array::new(&v1,dim4!(v1.len()asu64));
903+
904+
set_row(&mut a0,&a1,2);
905+
906+
letmut res =vec![true; a0.elements()];
907+
a0.host(&mut res);
908+
909+
let gold =vec![true,true,false,true,true,true];
910+
911+
assert_eq!(gold, res);
912+
}
913+
914+
#[test]
915+
fnchange_rows(){
916+
set_device(0);
917+
918+
let v0:Vec<bool> =vec![true,true,true,true,true,true];
919+
letmut a0 =Array::new(&v0,dim4!(v0.len()asu64));
920+
921+
let v1:Vec<bool> =vec![false,false];
922+
let a1 =Array::new(&v1,dim4!(v1.len()asu64));
923+
924+
set_rows(&mut a0,&a1,2,3);
925+
926+
letmut res =vec![true; a0.elements()];
927+
a0.host(&mut res);
928+
929+
let gold =vec![true,true,false,false,true,true];
930+
931+
assert_eq!(gold, res);
932+
}
871933
}

‎src/core/macros.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ macro_rules! view {
190190
$(
191191
seq_vec.push($crate::seq!($start:$end:$step));
192192
)*
193+
for _d in seq_vec.len()..$array_ident.dims().ndims(){
194+
seq_vec.push($crate::seq!());
195+
}
193196
$crate::index(&$array_ident,&seq_vec)
194197
}
195198
};
@@ -354,7 +357,7 @@ mod tests {
354357
usesuper::super::array::Array;
355358
usesuper::super::data::constant;
356359
usesuper::super::device::set_device;
357-
usesuper::super::index::index;
360+
usesuper::super::index::{index, rows, set_rows};
358361
usesuper::super::random::randu;
359362

360363
#[test]
@@ -505,4 +508,55 @@ mod tests {
505508
let _ruu32_5x5 =randu!(u32;5,5);
506509
let _ruu8_5x5 =randu!(u8;5,5);
507510
}
511+
512+
#[test]
513+
fnmatch_eval_macro_with_set_rows(){
514+
set_device(0);
515+
516+
let inpt =vec![true,true,true,true,true,true,true,true,true,true];
517+
let gold =vec![
518+
true,true,false,false,true,true,true,false,false,true,
519+
];
520+
521+
letmut orig_arr =Array::new(&inpt,dim4!(5,2));
522+
letmut orig_cln = orig_arr.clone();
523+
524+
let new_vals =vec![false,false,false,false];
525+
let new_arr =Array::new(&new_vals,dim4!(2,2));
526+
527+
eval!( orig_arr[2:3:1,1:1:0] = new_arr);
528+
letmut res1 =vec![true; orig_arr.elements()];
529+
orig_arr.host(&mut res1);
530+
531+
set_rows(&mut orig_cln,&new_arr,2,3);
532+
letmut res2 =vec![true; orig_cln.elements()];
533+
orig_cln.host(&mut res2);
534+
535+
assert_eq!(gold, res1);
536+
assert_eq!(res1, res2);
537+
}
538+
539+
#[test]
540+
fnmatch_view_macro_with_get_rows(){
541+
set_device(0);
542+
543+
let inpt:Vec<i32> =(0..10).collect();
544+
let gold:Vec<i32> =vec![2,3,7,8];
545+
546+
println!("input {:?}", inpt);
547+
println!("gold {:?}", gold);
548+
549+
let orig_arr =Array::new(&inpt,dim4!(5,2));
550+
551+
let view_out =view!( orig_arr[2:3:1]);
552+
letmut res1 =vec![0i32; view_out.elements()];
553+
view_out.host(&mut res1);
554+
555+
let rows_out =rows(&orig_arr,2,3);
556+
letmut res2 =vec![0i32; rows_out.elements()];
557+
rows_out.host(&mut res2);
558+
559+
assert_eq!(gold, res1);
560+
assert_eq!(res1, res2);
561+
}
508562
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp