|
| 1 | +externcrate blas; |
| 2 | +externcrate openblas_src; |
| 3 | + |
1 | 4 | use once_cell::sync::Lazy;// 1.3.1 |
2 | 5 | use pgx::*; |
3 | 6 | use std::collections::HashMap; |
@@ -370,6 +373,266 @@ mod pgml_rust { |
370 | 373 | None =>error!("Model with id = {} does not exist", model_id), |
371 | 374 | } |
372 | 375 | } |
| 376 | + |
| 377 | +#[pg_extern(immutable, parallel_safe, strict, name="add")] |
| 378 | +fnpgml_add_scalar_s(vector:Vec<f32>,addend:f32) ->Vec<f32>{ |
| 379 | + vector.as_slice().iter().map(|a| a + addend).collect() |
| 380 | +} |
| 381 | + |
| 382 | +#[pg_extern(immutable, parallel_safe, strict, name="add")] |
| 383 | +fnpgml_add_scalar_d(vector:Vec<f64>,addend:f64) ->Vec<f64>{ |
| 384 | + vector.as_slice().iter().map(|a| a + addend).collect() |
| 385 | +} |
| 386 | + |
| 387 | +#[pg_extern(immutable, parallel_safe, strict, name="subtract")] |
| 388 | +fnpgml_subtract_scalar_s(vector:Vec<f32>,subtahend:f32) ->Vec<f32>{ |
| 389 | + vector.as_slice().iter().map(|a| a - subtahend).collect() |
| 390 | +} |
| 391 | + |
| 392 | +#[pg_extern(immutable, parallel_safe, strict, name="subtract")] |
| 393 | +fnpgml_subtract_scalar_d(vector:Vec<f64>,subtahend:f64) ->Vec<f64>{ |
| 394 | + vector.as_slice().iter().map(|a| a - subtahend).collect() |
| 395 | +} |
| 396 | + |
| 397 | +#[pg_extern(immutable, parallel_safe, strict, name="multiply")] |
| 398 | +fnpgml_multiply_scalar_s(vector:Vec<f32>,multiplicand:f32) ->Vec<f32>{ |
| 399 | + vector.as_slice().iter().map(|a| a* multiplicand).collect() |
| 400 | +} |
| 401 | + |
| 402 | +#[pg_extern(immutable, parallel_safe, strict, name="multiply")] |
| 403 | +fnpgml_multiply_scalar_d(vector:Vec<f64>,multiplicand:f64) ->Vec<f64>{ |
| 404 | + vector.as_slice().iter().map(|a| a* multiplicand).collect() |
| 405 | +} |
| 406 | + |
| 407 | +#[pg_extern(immutable, parallel_safe, strict, name="divide")] |
| 408 | +fnpgml_divide_scalar_s(vector:Vec<f32>,dividend:f32) ->Vec<f32>{ |
| 409 | + vector.as_slice().iter().map(|a| a / dividend).collect() |
| 410 | +} |
| 411 | + |
| 412 | +#[pg_extern(immutable, parallel_safe, strict, name="divide")] |
| 413 | +fnpgml_divide_scalar_d(vector:Vec<f64>,dividend:f64) ->Vec<f64>{ |
| 414 | + vector.as_slice().iter().map(|a| a / dividend).collect() |
| 415 | +} |
| 416 | + |
| 417 | +#[pg_extern(immutable, parallel_safe, strict, name="add")] |
| 418 | +fnpgml_add_vector_s(vector:Vec<f32>,addend:Vec<f32>) ->Vec<f32>{ |
| 419 | + vector.as_slice().iter() |
| 420 | +.zip(addend.as_slice().iter()) |
| 421 | +.map(|(a, b)| a + b).collect() |
| 422 | +} |
| 423 | + |
| 424 | +#[pg_extern(immutable, parallel_safe, strict, name="add")] |
| 425 | +fnpgml_add_vector_d(vector:Vec<f64>,addend:Vec<f64>) ->Vec<f64>{ |
| 426 | + vector.as_slice().iter() |
| 427 | +.zip(addend.as_slice().iter()) |
| 428 | +.map(|(a, b)| a + b).collect() |
| 429 | +} |
| 430 | + |
| 431 | +#[pg_extern(immutable, parallel_safe, strict, name="subtract")] |
| 432 | +fnpgml_subtract_vector_s(vector:Vec<f32>,subtahend:Vec<f32>) ->Vec<f32>{ |
| 433 | + vector.as_slice().iter() |
| 434 | +.zip(subtahend.as_slice().iter()) |
| 435 | +.map(|(a, b)| a - b).collect() |
| 436 | +} |
| 437 | + |
| 438 | +#[pg_extern(immutable, parallel_safe, strict, name="subtract")] |
| 439 | +fnpgml_subtract_vector_d(vector:Vec<f64>,subtahend:Vec<f64>) ->Vec<f64>{ |
| 440 | + vector.as_slice().iter() |
| 441 | +.zip(subtahend.as_slice().iter()) |
| 442 | +.map(|(a, b)| a - b).collect() |
| 443 | +} |
| 444 | + |
| 445 | +#[pg_extern(immutable, parallel_safe, strict, name="multiply")] |
| 446 | +fnpgml_multiply_vector_s(vector:Vec<f32>,multiplicand:Vec<f32>) ->Vec<f32>{ |
| 447 | + vector.as_slice().iter() |
| 448 | +.zip(multiplicand.as_slice().iter()) |
| 449 | +.map(|(a, b)| a* b).collect() |
| 450 | +} |
| 451 | + |
| 452 | +#[pg_extern(immutable, parallel_safe, strict, name="multiply")] |
| 453 | +fnpgml_multiply_vector_d(vector:Vec<f64>,multiplicand:Vec<f64>) ->Vec<f64>{ |
| 454 | + vector.as_slice().iter() |
| 455 | +.zip(multiplicand.as_slice().iter()) |
| 456 | +.map(|(a, b)| a* b).collect() |
| 457 | +} |
| 458 | + |
| 459 | +#[pg_extern(immutable, parallel_safe, strict, name="divide")] |
| 460 | +fnpgml_divide_vector_s(vector:Vec<f32>,dividend:Vec<f32>) ->Vec<f32>{ |
| 461 | + vector.as_slice().iter() |
| 462 | +.zip(dividend.as_slice().iter()) |
| 463 | +.map(|(a, b)| a / b).collect() |
| 464 | +} |
| 465 | + |
| 466 | +#[pg_extern(immutable, parallel_safe, strict, name="divide")] |
| 467 | +fnpgml_divide_vector_d(vector:Vec<f64>,dividend:Vec<f64>) ->Vec<f64>{ |
| 468 | + vector.as_slice().iter() |
| 469 | +.zip(dividend.as_slice().iter()) |
| 470 | +.map(|(a, b)| a / b).collect() |
| 471 | +} |
| 472 | + |
| 473 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l0")] |
| 474 | +fnpgml_norm_l0_s(vector:Vec<f32>) ->f32{ |
| 475 | + vector.as_slice().iter().map(|a|if*a ==0.0{0.0}else{1.0}).sum() |
| 476 | +} |
| 477 | + |
| 478 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l0")] |
| 479 | +fnpgml_norm_l0_d(vector:Vec<f64>) ->f64{ |
| 480 | + vector.as_slice().iter().map(|a|if*a ==0.0{0.0}else{1.0}).sum() |
| 481 | +} |
| 482 | + |
| 483 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l1")] |
| 484 | +fnpgml_norm_l1_s(vector:Vec<f32>) ->f32{ |
| 485 | +unsafe{ |
| 486 | + blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(),1) |
| 487 | +} |
| 488 | +} |
| 489 | + |
| 490 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l1")] |
| 491 | +fnpgml_norm_l1_d(vector:Vec<f64>) ->f64{ |
| 492 | +unsafe{ |
| 493 | + blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(),1) |
| 494 | +} |
| 495 | +} |
| 496 | + |
| 497 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l2")] |
| 498 | +fnpgml_norm_l2_s(vector:Vec<f32>) ->f32{ |
| 499 | +unsafe{ |
| 500 | + blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1) |
| 501 | +} |
| 502 | +} |
| 503 | + |
| 504 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_l2")] |
| 505 | +fnpgml_norm_l2_d(vector:Vec<f64>) ->f64{ |
| 506 | +unsafe{ |
| 507 | + blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1) |
| 508 | +} |
| 509 | +} |
| 510 | + |
| 511 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_max")] |
| 512 | +fnpgml_norm_max_s(vector:Vec<f32>) ->f32{ |
| 513 | +unsafe{ |
| 514 | +let index = blas::isamax(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 515 | + vector[index -1] |
| 516 | +} |
| 517 | +} |
| 518 | + |
| 519 | +#[pg_extern(immutable, parallel_safe, strict, name="norm_max")] |
| 520 | +fnpgml_norm_max_d(vector:Vec<f64>) ->f64{ |
| 521 | +unsafe{ |
| 522 | +let index = blas::idamax(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 523 | + vector[index -1] |
| 524 | +} |
| 525 | +} |
| 526 | + |
| 527 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")] |
| 528 | +fnpgml_normalize_l1_s(vector:Vec<f32>) ->Vec<f32>{ |
| 529 | +let norm:f32; |
| 530 | +unsafe{ |
| 531 | + norm = blas::sasum(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 532 | +} |
| 533 | +pgml_divide_scalar_s(vector, norm) |
| 534 | +} |
| 535 | + |
| 536 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_l1")] |
| 537 | +fnpgml_normalize_l1_d(vector:Vec<f64>) ->Vec<f64>{ |
| 538 | +let norm:f64; |
| 539 | +unsafe{ |
| 540 | + norm = blas::dasum(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 541 | +} |
| 542 | +pgml_divide_scalar_d(vector, norm) |
| 543 | +} |
| 544 | + |
| 545 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")] |
| 546 | +fnpgml_normalize_l2_s(vector:Vec<f32>) ->Vec<f32>{ |
| 547 | +let norm:f32; |
| 548 | +unsafe{ |
| 549 | + norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 550 | +} |
| 551 | +pgml_divide_scalar_s(vector, norm) |
| 552 | +} |
| 553 | + |
| 554 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_l2")] |
| 555 | +fnpgml_normalize_l2_d(vector:Vec<f64>) ->Vec<f64>{ |
| 556 | +let norm:f64; |
| 557 | +unsafe{ |
| 558 | + norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 559 | +} |
| 560 | +pgml_divide_scalar_d(vector, norm) |
| 561 | +} |
| 562 | + |
| 563 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_max")] |
| 564 | +fnpgml_normalize_max_s(vector:Vec<f32>) ->Vec<f32>{ |
| 565 | +let norm = vector.as_slice().iter().map(|a| a.abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); |
| 566 | +pgml_divide_scalar_s(vector, norm) |
| 567 | +} |
| 568 | + |
| 569 | +#[pg_extern(immutable, parallel_safe, strict, name="normalize_max")] |
| 570 | +fnpgml_normalize_max_d(vector:Vec<f64>) ->Vec<f64>{ |
| 571 | +let norm = vector.as_slice().iter().map(|a| a.abs()).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); |
| 572 | +pgml_divide_scalar_d(vector, norm) |
| 573 | +} |
| 574 | + |
| 575 | +#[pg_extern(immutable, parallel_safe, strict, name="distance_l1")] |
| 576 | +fnpgml_distance_l1_s(vector:Vec<f32>,other:Vec<f32>) ->f32{ |
| 577 | + vector.as_slice().iter() |
| 578 | +.zip(other.as_slice().iter()) |
| 579 | +.map(|(a, b)|(a - b).abs()).sum() |
| 580 | +} |
| 581 | + |
| 582 | +#[pg_extern(immutable, parallel_safe, strict, name="distance_l1")] |
| 583 | +fnpgml_distance_l1_d(vector:Vec<f64>,other:Vec<f64>) ->f64{ |
| 584 | + vector.as_slice().iter() |
| 585 | +.zip(other.as_slice().iter()) |
| 586 | +.map(|(a, b)|(a - b).abs()).sum() |
| 587 | +} |
| 588 | + |
| 589 | +#[pg_extern(immutable, parallel_safe, strict, name="distance_l2")] |
| 590 | +fnpgml_distance_l2_s(vector:Vec<f32>,other:Vec<f32>) ->f32{ |
| 591 | + vector.as_slice().iter() |
| 592 | +.zip(other.as_slice().iter()) |
| 593 | +.map(|(a, b)|(a - b).powf(2.0)).sum::<f32>().sqrt() |
| 594 | +} |
| 595 | + |
| 596 | +#[pg_extern(immutable, parallel_safe, strict, name="distance_l2")] |
| 597 | +fnpgml_distance_l2_d(vector:Vec<f64>,other:Vec<f64>) ->f64{ |
| 598 | + vector.as_slice().iter() |
| 599 | +.zip(other.as_slice().iter()) |
| 600 | +.map(|(a, b)|(a - b).powf(2.0)).sum::<f64>().sqrt() |
| 601 | +} |
| 602 | + |
| 603 | +#[pg_extern(immutable, parallel_safe, strict, name="dot_product")] |
| 604 | +fnpgml_dot_product_s(vector:Vec<f32>,other:Vec<f32>) ->f32{ |
| 605 | +unsafe{ |
| 606 | + blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1) |
| 607 | +} |
| 608 | +} |
| 609 | + |
| 610 | +#[pg_extern(immutable, parallel_safe, strict, name="dot_product")] |
| 611 | +fnpgml_dot_product_d(vector:Vec<f64>,other:Vec<f64>) ->f64{ |
| 612 | +unsafe{ |
| 613 | + blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1) |
| 614 | +} |
| 615 | +} |
| 616 | + |
| 617 | +#[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")] |
| 618 | +fnpgml_cosine_similarity_s(vector:Vec<f32>,other:Vec<f32>) ->f32{ |
| 619 | +unsafe{ |
| 620 | +let dot = blas::sdot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1); |
| 621 | +let a_norm = blas::snrm2(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 622 | +let b_norm = blas::snrm2(other.len().try_into().unwrap(), other.as_slice(),1); |
| 623 | + dot /(a_norm* b_norm) |
| 624 | +} |
| 625 | +} |
| 626 | + |
| 627 | +#[pg_extern(immutable, parallel_safe, strict, name="cosine_similarity")] |
| 628 | +fnpgml_cosine_similarity_d(vector:Vec<f64>,other:Vec<f64>) ->f64{ |
| 629 | +unsafe{ |
| 630 | +let dot = blas::ddot(vector.len().try_into().unwrap(), vector.as_slice(),1, other.as_slice(),1); |
| 631 | +let a_norm = blas::dnrm2(vector.len().try_into().unwrap(), vector.as_slice(),1); |
| 632 | +let b_norm = blas::dnrm2(other.len().try_into().unwrap(), other.as_slice(),1); |
| 633 | + dot /(a_norm* b_norm) |
| 634 | +} |
| 635 | +} |
373 | 636 | } |
374 | 637 |
|
375 | 638 | #[cfg(any(test, feature ="pg_test"))] |
|