Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5509f13

Browse files
authored
implement vector operations in rust/blas (#297)
1 parent7306dde commit5509f13

File tree

2 files changed

+266
-0
lines changed

2 files changed

+266
-0
lines changed

‎pgml-extension/pgml_rust/Cargo.toml‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ xgboost = { path = "rust-xgboost" }
2121
rustlearn ="0.5"
2222
once_cell ="1"
2323
rand ="0.8"
24+
blas = {version ="0.22.0" }
25+
blas-src = {version ="0.8",features = ["openblas"] }
26+
openblas-src = {version ="0.10",features = ["cblas","system"] }
2427

2528
[dev-dependencies]
2629
pgx-tests ="=0.4.5"

‎pgml-extension/pgml_rust/src/lib.rs‎

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
externcrate blas;
2+
externcrate openblas_src;
3+
14
use once_cell::sync::Lazy;// 1.3.1
25
use pgx::*;
36
use std::collections::HashMap;
@@ -370,6 +373,266 @@ mod pgml_rust {
370373
None =>error!("Model with id = {} does not exist", model_id),
371374
}
372375
}
376+
377+
#[pg_extern(immutable, parallel_safe, strict, name="add")]
378+
fnpgml_add_scalar_s(vector:Vec<f32>,addend:f32) ->Vec<f32>{
379+
vector.as_slice().iter().map(|a| a + addend).collect()
380+
}
381+
382+
#[pg_extern(immutable, parallel_safe, strict, name="add")]
383+
fnpgml_add_scalar_d(vector:Vec<f64>,addend:f64) ->Vec<f64>{
384+
vector.as_slice().iter().map(|a| a + addend).collect()
385+
}
386+
387+
#[pg_extern(immutable, parallel_safe, strict, name="subtract")]
388+
fnpgml_subtract_scalar_s(vector:Vec<f32>,subtahend:f32) ->Vec<f32>{
389+
vector.as_slice().iter().map(|a| a - subtahend).collect()
390+
}
391+
392+
#[pg_extern(immutable, parallel_safe, strict, name="subtract")]
393+
fnpgml_subtract_scalar_d(vector:Vec<f64>,subtahend:f64) ->Vec<f64>{
394+
vector.as_slice().iter().map(|a| a - subtahend).collect()
395+
}
396+
397+
#[pg_extern(immutable, parallel_safe, strict, name="multiply")]
398+
fnpgml_multiply_scalar_s(vector:Vec<f32>,multiplicand:f32) ->Vec<f32>{
399+
vector.as_slice().iter().map(|a| a* multiplicand).collect()
400+
}
401+
402+
#[pg_extern(immutable, parallel_safe, strict, name="multiply")]
403+
fnpgml_multiply_scalar_d(vector:Vec<f64>,multiplicand:f64) ->Vec<f64>{
404+
vector.as_slice().iter().map(|a| a* multiplicand).collect()
405+
}
406+
407+
#[pg_extern(immutable, parallel_safe, strict, name="divide")]
408+
fnpgml_divide_scalar_s(vector:Vec<f32>,dividend:f32) ->Vec<f32>{
409+
vector.as_slice().iter().map(|a| a / dividend).collect()
410+
}
411+
412+
#[pg_extern(immutable, parallel_safe, strict, name="divide")]
413+
fnpgml_divide_scalar_d(vector:Vec<f64>,dividend:f64) ->Vec<f64>{
414+
vector.as_slice().iter().map(|a| a / dividend).collect()
415+
}
416+
417+
#[pg_extern(immutable, parallel_safe, strict, name="add")]
418+
fnpgml_add_vector_s(vector:Vec<f32>,addend:Vec<f32>) ->Vec<f32>{
419+
vector.as_slice().iter()
420+
.zip(addend.as_slice().iter())
421+
.map(|(a, b)| a + b).collect()
422+
}
423+
424+
#[pg_extern(immutable, parallel_safe, strict, name="add")]
425+
fnpgml_add_vector_d(vector:Vec<f64>,addend:Vec<f64>) ->Vec<f64>{
426+
vector.as_slice().iter()
427+
.zip(addend.as_slice().iter())
428+
.map(|(a, b)| a + b).collect()
429+
}
430+
431+
#[pg_extern(immutable, parallel_safe, strict, name="subtract")]
432+
fnpgml_subtract_vector_s(vector:Vec<f32>,subtahend:Vec<f32>) ->Vec<f32>{
433+
vector.as_slice().iter()
434+
.zip(subtahend.as_slice().iter())
435+
.map(|(a, b)| a - b).collect()
436+
}
437+
438+
#[pg_extern(immutable, parallel_safe, strict, name="subtract")]
439+
fnpgml_subtract_vector_d(vector:Vec<f64>,subtahend:Vec<f64>) ->Vec<f64>{
440+
vector.as_slice().iter()
441+
.zip(subtahend.as_slice().iter())
442+
.map(|(a, b)| a - b).collect()
443+
}
444+
445+
#[pg_extern(immutable, parallel_safe, strict, name="multiply")]
446+
fnpgml_multiply_vector_s(vector:Vec<f32>,multiplicand:Vec<f32>) ->Vec<f32>{
447+
vector.as_slice().iter()
448+
.zip(multiplicand.as_slice().iter())
449+
.map(|(a, b)| a* b).collect()
450+
}
451+
452+
#[pg_extern(immutable, parallel_safe, strict, name="multiply")]
453+
fnpgml_multiply_vector_d(vector:Vec<f64>,multiplicand:Vec<f64>) ->Vec<f64>{
454+
vector.as_slice().iter()
455+
.zip(multiplicand.as_slice().iter())
456+
.map(|(a, b)| a* b).collect()
457+
}
458+
459+
#[pg_extern(immutable, parallel_safe, strict, name="divide")]
460+
fnpgml_divide_vector_s(vector:Vec<f32>,dividend:Vec<f32>) ->Vec<f32>{
461+
vector.as_slice().iter()
462+
.zip(dividend.as_slice().iter())
463+
.map(|(a, b)| a / b).collect()
464+
}
465+
466+
#[pg_extern(immutable, parallel_safe, strict, name="divide")]
467+
fnpgml_divide_vector_d(vector:Vec<f64>,dividend:Vec<f64>) ->Vec<f64>{
468+
vector.as_slice().iter()
469+
.zip(dividend.as_slice().iter())
470+
.map(|(a, b)| a / b).collect()
471+
}
472+
473+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l0")]
474+
fnpgml_norm_l0_s(vector:Vec<f32>) ->f32{
475+
vector.as_slice().iter().map(|a|if*a ==0.0{0.0}else{1.0}).sum()
476+
}
477+
478+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l0")]
479+
fnpgml_norm_l0_d(vector:Vec<f64>) ->f64{
480+
vector.as_slice().iter().map(|a|if*a ==0.0{0.0}else{1.0}).sum()
481+
}
482+
483+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l1")]
484+
fnpgml_norm_l1_s(vector:Vec<f32>) ->f32{
485+
unsafe{
486+
blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(),1)
487+
}
488+
}
489+
490+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l1")]
491+
fnpgml_norm_l1_d(vector:Vec<f64>) ->f64{
492+
unsafe{
493+
blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(),1)
494+
}
495+
}
496+
497+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l2")]
498+
fnpgml_norm_l2_s(vector:Vec<f32>) ->f32{
499+
unsafe{
500+
blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1)
501+
}
502+
}
503+
504+
#[pg_extern(immutable, parallel_safe, strict, name="norm_l2")]
505+
fnpgml_norm_l2_d(vector:Vec<f64>) ->f64{
506+
unsafe{
507+
blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1)
508+
}
509+
}
510+
511+
#[pg_extern(immutable, parallel_safe, strict, name="norm_max")]
512+
fnpgml_norm_max_s(vector:Vec<f32>) ->f32{
513+
unsafe{
514+
let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(),1);
515+
vector[index -1]
516+
}
517+
}
518+
519+
#[pg_extern(immutable, parallel_safe, strict, name="norm_max")]
520+
fnpgml_norm_max_d(vector:Vec<f64>) ->f64{
521+
unsafe{
522+
let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(),1);
523+
vector[index -1]
524+
}
525+
}
526+
527+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")]
528+
fnpgml_normalize_l1_s(vector:Vec<f32>) ->Vec<f32>{
529+
let norm:f32;
530+
unsafe{
531+
norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(),1);
532+
}
533+
pgml_divide_scalar_s(vector, norm)
534+
}
535+
536+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")]
537+
fnpgml_normalize_l1_d(vector:Vec<f64>) ->Vec<f64>{
538+
let norm:f64;
539+
unsafe{
540+
norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(),1);
541+
}
542+
pgml_divide_scalar_d(vector, norm)
543+
}
544+
545+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")]
546+
fnpgml_normalize_l2_s(vector:Vec<f32>) ->Vec<f32>{
547+
let norm:f32;
548+
unsafe{
549+
norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1);
550+
}
551+
pgml_divide_scalar_s(vector, norm)
552+
}
553+
554+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")]
555+
fnpgml_normalize_l2_d(vector:Vec<f64>) ->Vec<f64>{
556+
let norm:f64;
557+
unsafe{
558+
norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1);
559+
}
560+
pgml_divide_scalar_d(vector, norm)
561+
}
562+
563+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_max")]
564+
fnpgml_normalize_max_s(vector:Vec<f32>) ->Vec<f32>{
565+
let norm = vector.as_slice().iter().map(|a| a.abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
566+
pgml_divide_scalar_s(vector, norm)
567+
}
568+
569+
#[pg_extern(immutable, parallel_safe, strict, name="normalize_max")]
570+
fnpgml_normalize_max_d(vector:Vec<f64>) ->Vec<f64>{
571+
let norm = vector.as_slice().iter().map(|a| a.abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
572+
pgml_divide_scalar_d(vector, norm)
573+
}
574+
575+
#[pg_extern(immutable, parallel_safe, strict, name="distance_l1")]
576+
fnpgml_distance_l1_s(vector:Vec<f32>,other:Vec<f32>) ->f32{
577+
vector.as_slice().iter()
578+
.zip(other.as_slice().iter())
579+
.map(|(a, b)|(a - b).abs()).sum()
580+
}
581+
582+
#[pg_extern(immutable, parallel_safe, strict, name="distance_l1")]
583+
fnpgml_distance_l1_d(vector:Vec<f64>,other:Vec<f64>) ->f64{
584+
vector.as_slice().iter()
585+
.zip(other.as_slice().iter())
586+
.map(|(a, b)|(a - b).abs()).sum()
587+
}
588+
589+
#[pg_extern(immutable, parallel_safe, strict, name="distance_l2")]
590+
fnpgml_distance_l2_s(vector:Vec<f32>,other:Vec<f32>) ->f32{
591+
vector.as_slice().iter()
592+
.zip(other.as_slice().iter())
593+
.map(|(a, b)|(a - b).powf(2.0)).sum::<f32>().sqrt()
594+
}
595+
596+
#[pg_extern(immutable, parallel_safe, strict, name="distance_l2")]
597+
fnpgml_distance_l2_d(vector:Vec<f64>,other:Vec<f64>) ->f64{
598+
vector.as_slice().iter()
599+
.zip(other.as_slice().iter())
600+
.map(|(a, b)|(a - b).powf(2.0)).sum::<f64>().sqrt()
601+
}
602+
603+
#[pg_extern(immutable, parallel_safe, strict, name="dot_product")]
604+
fnpgml_dot_product_s(vector:Vec<f32>,other:Vec<f32>) ->f32{
605+
unsafe{
606+
blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1)
607+
}
608+
}
609+
610+
#[pg_extern(immutable, parallel_safe, strict, name="dot_product")]
611+
fnpgml_dot_product_d(vector:Vec<f64>,other:Vec<f64>) ->f64{
612+
unsafe{
613+
blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1)
614+
}
615+
}
616+
617+
#[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")]
618+
fnpgml_cosine_similarity_s(vector:Vec<f32>,other:Vec<f32>) ->f32{
619+
unsafe{
620+
let dot = blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1);
621+
let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1);
622+
let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(),1);
623+
dot /(a_norm* b_norm)
624+
}
625+
}
626+
627+
#[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")]
628+
fnpgml_cosine_similarity_d(vector:Vec<f64>,other:Vec<f64>) ->f64{
629+
unsafe{
630+
let dot = blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1);
631+
let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1);
632+
let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(),1);
633+
dot /(a_norm* b_norm)
634+
}
635+
}
373636
}
374637

375638
#[cfg(any(test, feature ="pg_test"))]

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp