Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitedb64d6

Browse files
authored
Reduce the scope of the unsafe block across the code base (#319)
1 parentc70413b commitedb64d6

File tree

19 files changed

+1649
-1831
lines changed

19 files changed

+1649
-1831
lines changed

‎src/algorithm/mod.rs

Lines changed: 135 additions & 137 deletions
Large diffs are not rendered by default.

‎src/blas/mod.rs

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,20 @@ pub fn gemm<T>(
130130
)where
131131
T:HasAfEnum +FloatingPoint,
132132
{
133-
unsafe{
134-
letmut out = output.get();
135-
let err_val =af_gemm(
133+
letmut out =unsafe{ output.get()};
134+
leterr_val =unsafe{
135+
af_gemm(
136136
&mut outas*mutaf_array,
137137
optlhsasc_uint,
138138
optrhsasc_uint,
139139
alpha.as_ptr()as*constc_void,
140140
lhs.get(),
141141
rhs.get(),
142142
beta.as_ptr()as*constc_void,
143-
);
144-
HANDLE_ERROR(AfError::from(err_val));
145-
output.set(out);
146-
}
143+
)
144+
};
145+
HANDLE_ERROR(AfError::from(err_val));
146+
output.set(out);
147147
}
148148

149149
/// Matrix multiple of two Arrays
@@ -162,18 +162,18 @@ pub fn matmul<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatPro
162162
where
163163
T:HasAfEnum +FloatingPoint,
164164
{
165-
unsafe{
166-
letmut temp:af_array = std::ptr::null_mut();
167-
let err_val =af_matmul(
165+
letmut temp:af_array = std::ptr::null_mut();
166+
leterr_val =unsafe{
167+
af_matmul(
168168
&mut tempas*mutaf_array,
169169
lhs.get(),
170170
rhs.get(),
171171
optlhsasc_uint,
172172
optrhsasc_uint,
173-
);
174-
HANDLE_ERROR(AfError::from(err_val));
175-
temp.into()
176-
}
173+
)
174+
};
175+
HANDLE_ERROR(AfError::from(err_val));
176+
temp.into()
177177
}
178178

179179
/// Calculate the dot product of vectors.
@@ -194,18 +194,18 @@ pub fn dot<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp)
194194
where
195195
T:HasAfEnum +FloatingPoint,
196196
{
197-
unsafe{
198-
letmut temp:af_array = std::ptr::null_mut();
199-
let err_val =af_dot(
197+
letmut temp:af_array = std::ptr::null_mut();
198+
leterr_val =unsafe{
199+
af_dot(
200200
&mut tempas*mutaf_array,
201201
lhs.get(),
202202
rhs.get(),
203203
optlhsasc_uint,
204204
optrhsasc_uint,
205-
);
206-
HANDLE_ERROR(AfError::from(err_val));
207-
temp.into()
208-
}
205+
)
206+
};
207+
HANDLE_ERROR(AfError::from(err_val));
208+
temp.into()
209209
}
210210

211211
/// Transpose of a matrix.
@@ -220,12 +220,10 @@ where
220220
///
221221
/// Transposed Array.
222222
pubfntranspose<T:HasAfEnum>(arr:&Array<T>,conjugate:bool) ->Array<T>{
223-
unsafe{
224-
letmut temp:af_array = std::ptr::null_mut();
225-
let err_val =af_transpose(&mut tempas*mutaf_array, arr.get(), conjugate);
226-
HANDLE_ERROR(AfError::from(err_val));
227-
temp.into()
228-
}
223+
letmut temp:af_array = std::ptr::null_mut();
224+
let err_val =unsafe{af_transpose(&mut tempas*mutaf_array, arr.get(), conjugate)};
225+
HANDLE_ERROR(AfError::from(err_val));
226+
temp.into()
229227
}
230228

231229
/// Inplace transpose of a matrix.
@@ -236,10 +234,8 @@ pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
236234
/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
237235
/// transpose
238236
pubfntranspose_inplace<T:HasAfEnum>(arr:&mutArray<T>,conjugate:bool){
239-
unsafe{
240-
let err_val =af_transpose_inplace(arr.get(), conjugate);
241-
HANDLE_ERROR(AfError::from(err_val));
242-
}
237+
let err_val =unsafe{af_transpose_inplace(arr.get(), conjugate)};
238+
HANDLE_ERROR(AfError::from(err_val));
243239
}
244240

245241
/// Sets the cuBLAS math mode for the internal handle.

‎src/core/arith.rs

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,10 @@ where
108108
typeOutput =Array<T>;
109109

110110
fnnot(self) ->Self::Output{
111-
unsafe{
112-
letmut temp:af_array = std::ptr::null_mut();
113-
let err_val =af_not(&mut tempas*mutaf_array,self.get());
114-
HANDLE_ERROR(AfError::from(err_val));
115-
temp.into()
116-
}
111+
letmut temp:af_array = std::ptr::null_mut();
112+
let err_val =unsafe{af_not(&mut tempas*mutaf_array,self.get())};
113+
HANDLE_ERROR(AfError::from(err_val));
114+
temp.into()
117115
}
118116
}
119117

@@ -124,12 +122,12 @@ macro_rules! unary_func {
124122
/// This is an element wise unary operation.
125123
pubfn $fn_name<T:HasAfEnum>(input:&Array<T>) ->Array<T::$out_type >
126124
whereT::$out_type:HasAfEnum{
127-
unsafe{
125+
128126
letmut temp: af_array = std::ptr::null_mut();
129-
let err_val = $ffi_fn(&mut tempas*mut af_array, input.get());
127+
let err_val =unsafe{$ffi_fn(&mut tempas*mut af_array, input.get())};
130128
HANDLE_ERROR(AfError::from(err_val));
131129
temp.into()
132-
}
130+
133131
}
134132
)
135133
}
@@ -256,12 +254,12 @@ macro_rules! unary_boolean_func {
256254
///
257255
/// This is an element wise unary operation.
258256
pubfn $fn_name<T:HasAfEnum>(input:&Array<T>) ->Array<bool>{
259-
unsafe{
257+
260258
letmut temp: af_array = std::ptr::null_mut();
261-
let err_val = $ffi_fn(&mut tempas*mut af_array, input.get());
259+
let err_val =unsafe{$ffi_fn(&mut tempas*mut af_array, input.get())};
262260
HANDLE_ERROR(AfError::from(err_val));
263261
temp.into()
264-
}
262+
265263
}
266264
)
267265
}
@@ -291,12 +289,11 @@ macro_rules! binary_func {
291289
A:ImplicitPromote<B>,
292290
B:ImplicitPromote<A>,
293291
{
294-
unsafe{
295-
letmut temp: af_array = std::ptr::null_mut();
296-
let err_val = $ffi_fn(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch);
297-
HANDLE_ERROR(AfError::from(err_val));
298-
Into::<Array<A::Output>>::into(temp)
299-
}
292+
letmut temp: af_array = std::ptr::null_mut();
293+
let err_val =
294+
unsafe{ $ffi_fn(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch)};
295+
HANDLE_ERROR(AfError::from(err_val));
296+
Into::<Array<A::Output>>::into(temp)
300297
}
301298
};
302299
}
@@ -389,12 +386,11 @@ macro_rules! overloaded_binary_func {
389386
A:ImplicitPromote<B>,
390387
B:ImplicitPromote<A>,
391388
{
392-
unsafe{
393-
letmut temp: af_array = std::ptr::null_mut();
394-
let err_val = $ffi_name(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch);
395-
HANDLE_ERROR(AfError::from(err_val));
396-
temp.into()
397-
}
389+
letmut temp: af_array = std::ptr::null_mut();
390+
let err_val =
391+
unsafe{ $ffi_name(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch)};
392+
HANDLE_ERROR(AfError::from(err_val));
393+
temp.into()
398394
}
399395

400396
#[doc=$doc_str]
@@ -491,12 +487,11 @@ macro_rules! overloaded_logic_func {
491487
A:ImplicitPromote<B>,
492488
B:ImplicitPromote<A>,
493489
{
494-
unsafe{
495-
letmut temp: af_array = std::ptr::null_mut();
496-
let err_val = $ffi_name(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch);
497-
HANDLE_ERROR(AfError::from(err_val));
498-
temp.into()
499-
}
490+
letmut temp: af_array = std::ptr::null_mut();
491+
let err_val =
492+
unsafe{ $ffi_name(&mut tempas*mut af_array, lhs.get(), rhs.get(), batch)};
493+
HANDLE_ERROR(AfError::from(err_val));
494+
temp.into()
500495
}
501496

502497
#[doc=$doc_str]
@@ -611,18 +606,18 @@ where
611606
X:ImplicitPromote<Y>,
612607
Y:ImplicitPromote<X>,
613608
{
614-
unsafe{
615-
letmut temp:af_array = std::ptr::null_mut();
616-
let err_val =af_clamp(
609+
letmut temp:af_array = std::ptr::null_mut();
610+
leterr_val =unsafe{
611+
af_clamp(
617612
&mut tempas*mutaf_array,
618613
inp.get(),
619614
lo.get(),
620615
hi.get(),
621616
batch,
622-
);
623-
HANDLE_ERROR(AfError::from(err_val));
624-
temp.into()
625-
}
617+
)
618+
};
619+
HANDLE_ERROR(AfError::from(err_val));
620+
temp.into()
626621
}
627622

628623
/// Clamp the values of Array
@@ -979,10 +974,8 @@ pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
979974
where
980975
T:HasAfEnum +IntegralType,
981976
{
982-
unsafe{
983-
letmut temp:af_array = std::ptr::null_mut();
984-
let err_val =af_bitnot(&mut tempas*mutaf_array, input.get());
985-
HANDLE_ERROR(AfError::from(err_val));
986-
temp.into()
987-
}
977+
letmut temp:af_array = std::ptr::null_mut();
978+
let err_val =unsafe{af_bitnot(&mut tempas*mutaf_array, input.get())};
979+
HANDLE_ERROR(AfError::from(err_val));
980+
temp.into()
988981
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp