Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit7bffa33

Browse files
authored
Make seed independent of num.threads and add legacy option (#1447)
1 parent8b08d83 commit7bffa33

File tree

42 files changed

+2245
-2104
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2245
-2104
lines changed

‎REFERENCE.md‎

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,13 +403,11 @@ While the algorithm in `regression_forest` is very similar to that of classic ra
403403

404404
Overall, GRF is designed to produce the same estimates across platforms when using a consistent value for the random seed through the training option seed. However, there are still some cases where GRF can produce different estimates across platforms. When it comes to cross-platform predictions, the output of GRF will depend on a few factors beyond the forest seed.
405405

406-
One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-pointrounding, and these could lead to slightly different forest splits if the data requires numerical precision.Another factor is how theforest construction is distributed across different threads. Right now, our forest splitting algorithm can give different results depending on the number of threads that were used to build the forest.
406+
One such factor is the compiler that was used to build GRF. Different compilers may have different default behavior around floating-pointbehavior and instruction optimizations, and these could lead to slightly different forest splits if the data requires numerical precision.In addition to setting theseed argument, rounding all input data to at most 8 significant digits may help.
407407

408-
Therefore, in order to ensure consistent results, we provide the following recommendations.
409-
- Make sure arguments`seed` and`num.threads` are the same across platforms
410-
- Round data to 8 significant digits
408+
Even though the compiler is the same, different CPU architectures may produce slightly different output. One such example is GRF compiled with clang and run on x86 (Intel) vs. ARM (Apple Silicon).
411409

412-
Also, please note that we have not done extensive testing on Windows platforms, although we do not expect randomnumbergeneration issues there tobe different from Linux/Mac. Regardless ofthe platform, if results are still not consistent please help us by submitting a Github issue.
410+
Prior to GRF version 2.4.0, another factor was how the forest construction was distributed across different threads. In these versions, our forest splitting algorithm can give different results depending on thenumberof threads used tobuild the forest, meaning that the num.threads argument had to bethesame for cross-platform reproducibility. To restore this behavior in current versions of GRF, you can set the global R option`options(grf.legacy.seed=TRUE)` and exactly recover results produced with past versions of the package.
413411

414412

415413
##References

‎core/src/forest/ForestOptions.cpp‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ ForestOptions::ForestOptions(uint num_trees,
3737
double imbalance_penalty,
3838
uint num_threads,
3939
uint random_seed,
40+
bool legacy_seed,
4041
const std::vector<size_t>& sample_clusters,
4142
uint samples_per_cluster):
4243
ci_group_size(ci_group_size),
4344
sample_fraction(sample_fraction),
4445
tree_options(mtry, min_node_size, honesty, honesty_fraction, honesty_prune_leaves, alpha, imbalance_penalty),
4546
sampling_options(samples_per_cluster, sample_clusters),
46-
random_seed(random_seed) {
47+
random_seed(random_seed),
48+
legacy_seed(legacy_seed) {
4749

4850
this->num_threads =validate_num_threads(num_threads);
4951

@@ -85,6 +87,10 @@ uint ForestOptions::get_random_seed() const {
8587
return random_seed;
8688
}
8789

90+
boolForestOptions::get_legacy_seed()const {
91+
return legacy_seed;
92+
}
93+
8894
uintForestOptions::validate_num_threads(uint num_threads) {
8995
if (num_threads == DEFAULT_NUM_THREADS) {
9096
returnstd::thread::hardware_concurrency();

‎core/src/forest/ForestOptions.h‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class ForestOptions {
4141
double imbalance_penalty,
4242
uint num_threads,
4343
uint random_seed,
44+
bool legacy_seed,
4445
const std::vector<size_t>& sample_clusters,
4546
uint samples_per_cluster);
4647

@@ -55,6 +56,8 @@ class ForestOptions {
5556

5657
uintget_num_threads()const;
5758
uintget_random_seed()const;
59+
// Toggle between seed and num_threads dependence to reproduce behavior prior to grf 2.4.0.
60+
boolget_legacy_seed()const;
5861

5962
private:
6063
uint num_trees;
@@ -66,6 +69,7 @@ class ForestOptions {
6669

6770
uint num_threads;
6871
uint random_seed;
72+
bool legacy_seed;
6973
};
7074

7175
}// namespace grf

‎core/src/forest/ForestTrainer.cpp‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ std::vector<std::unique_ptr<Tree>> ForestTrainer::train_batch(
107107
trees.reserve(num_trees * ci_group_size);
108108

109109
for (size_t i =0; i < num_trees; i++) {
110-
uint tree_seed =udist(random_number_generator);
110+
uint tree_seed;
111+
if (options.get_legacy_seed()) {
112+
tree_seed =udist(random_number_generator);
113+
}else {
114+
tree_seed =static_cast<uint>(options.get_random_seed() + start + i);
115+
}
111116
RandomSamplersampler(tree_seed, options.get_sampling_options());
112117

113118
if (ci_group_size ==1) {

‎core/test/forest/ForestSmokeTest.cpp‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST_CASE("forests don't crash when there are fewer trees than threads", "[fores
5252
uint samples_per_cluster =0;
5353

5454
ForestOptionsoptions(num_trees, ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
55-
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
55+
prune, alpha, imbalance_penalty, num_threads, seed,true,empty_clusters, samples_per_cluster);
5656

5757
Forest forest = trainer.train(data, options);
5858
ForestPredictor predictor =regression_predictor(4);

‎core/test/forest/LocalLinearForestTest.cpp‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ TEST_CASE("LLF gives reasonable prediction on friedman data", "[local linear], [
4949
ForestOptionsoptions (
5050
num_trees, ci_group_size, sample_fraction,
5151
mtry, min_node_size, honesty, honesty_fraction, prune,
52-
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
52+
alpha, imbalance_penalty, num_threads, seed,true,empty_clusters, samples_per_cluster);
5353
ForestTrainer trainer =regression_trainer();
5454
Forest forest = trainer.train(data, options);
5555

@@ -136,7 +136,7 @@ TEST_CASE("local linear forests give reasonable variance estimates", "[regressio
136136
ForestOptionsoptions (
137137
num_trees, ci_group_size, sample_fraction,
138138
mtry, min_node_size, honesty, honesty_fraction, prune,
139-
alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
139+
alpha, imbalance_penalty, num_threads, seed,true,empty_clusters, samples_per_cluster);
140140
ForestTrainer trainer =regression_trainer();
141141
Forest forest = trainer.train(data, options);
142142

‎core/test/utilities/ForestTestUtilities.cpp‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@ ForestOptions ForestTestUtilities::default_options(bool honesty,
4242
uint samples_per_cluster =0;
4343
uint num_threads =4;
4444
uint seed =42;
45+
bool legacy_seed =true;
4546

4647
returnForestOptions(num_trees,
4748
ci_group_size, sample_fraction, mtry, min_node_size, honesty, honesty_fraction,
48-
prune, alpha, imbalance_penalty, num_threads, seed, empty_clusters, samples_per_cluster);
49+
prune, alpha, imbalance_penalty, num_threads, seed,legacy_seed,empty_clusters, samples_per_cluster);
4950
}

‎r-package/grf/DESCRIPTION‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Imports:
2929
methods,
3030
Rcpp (>= 0.12.15),
3131
sandwich (>= 2.4-0)
32-
RoxygenNote: 7.2.3
32+
RoxygenNote: 7.3.2
3333
Suggests:
3434
DiagrammeR,
3535
MASS,

‎r-package/grf/NAMESPACE‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ export(get_leaf_node)
3838
export(get_sample_weights)
3939
export(get_scores)
4040
export(get_tree)
41+
export(grf_options)
4142
export(instrumental_forest)
4243
export(ll_regression_forest)
4344
export(lm_forest)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp