@@ -265,12 +265,12 @@ namespace Rope {
265265int bs,
266266float theta,
267267const std::vector<int >& axes_dim,
268- bool yarn =false ,
269- std::vector<int > max_pe_len = {},
270- int ori_max_pe_len = 64 ,
271- bool dype =false ,
272- float current_timestep =1 .0f ,
273- std::vector<float > ntk_factors = {}) {
268+ bool yarn =false ,
269+ std::vector<int > max_pe_len = {},
270+ std::vector< int > ori_max_pe_len= { 64 , 64 , 64 } ,
271+ bool dype =false ,
272+ float current_timestep =1 .0f ,
273+ std::vector<float > ntk_factors = {}) {
274274 std::vector<std::vector<float >> trans_ids =transpose (ids);
275275size_t pos_len = ids.size () / bs;
276276int num_axes = axes_dim.size ();
@@ -292,7 +292,7 @@ namespace Rope {
292292
293293for (int i =0 ; i < num_axes; ++i) {
294294 std::vector<std::vector<float >> rope_emb =rope_ext (
295- trans_ids[i], axes_dim[i], theta,false ,1 .0f , ntk_factors[i],true , yarn, max_pe_len[i], ori_max_pe_len, dype, current_timestep);
295+ trans_ids[i], axes_dim[i], theta,false ,1 .0f , ntk_factors[i],true , yarn, max_pe_len[i], ori_max_pe_len[i] , dype, current_timestep);
296296
297297for (int b =0 ; b < bs; ++b) {
298298for (size_t j =0 ; j < pos_len; ++j) {
@@ -372,12 +372,31 @@ namespace Rope {
372372bool use_ntk =false ,
373373float current_timestep =1 .0f ) {
374374int base_resolution =1024 ;
375+ int base_patches_H = -1 ;
376+ int base_patches_W = -1 ;
377+
375378// set it via environment variable for now (TODO: arg)
379+ // could be either a single integer, or WxH
376380const char * env_base_resolution =getenv (" FLUX_DYPE_BASE_RESOLUTION" );
377381if (env_base_resolution !=nullptr ) {
378- base_resolution =atoi (env_base_resolution);
382+ if (strchr (env_base_resolution,' x' ) !=nullptr ) {
383+ const char * x_pos =strchr (env_base_resolution,' x' );
384+ base_patches_H =atoi (x_pos +1 ) /16 ;
385+ base_patches_W =atoi (env_base_resolution) /16 ;
386+ }else {
387+ base_resolution =atoi (env_base_resolution);
388+ }
379389 }
380- int base_patches = base_resolution /16 ;
390+ // preserve aspect ratio of the input image
391+ // base_patches_W = k*w, base_patches_H = k*h, base_patches_W*base_patches_H = base_resolution^2
392+ // => k = base_resolution / sqrt(w*h)
393+ if (base_patches_H == -1 )
394+ base_patches_H = (base_resolution * h *sqrt (1 .0f / (w * h))) /16 ;
395+ if (base_patches_W == -1 )
396+ base_patches_W = (base_resolution * w *sqrt (1 .0f / (w * h))) /16 ;
397+
398+ // First dim is ref image, should not need any weird rope modifications since the max pos should stay very low. 1024 is a lot
399+ std::vector<int > base_patches = {1024 , base_patches_H, base_patches_W};
381400 std::vector<std::vector<float >> ids =gen_flux_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
382401 std::vector<int > max_pos_vec = {};
383402 std::vector<float > ntk_factor_vec = {};
@@ -393,7 +412,7 @@ namespace Rope {
393412 max_pos_vec.push_back (max_pos);
394413float ntk_factor =1 .0f ;
395414if (use_ntk) {
396- float base_ntk =pow ((float )max_pos / base_patches, (float )axes_dim[i] / (axes_dim[i] -2 ));
415+ float base_ntk =pow ((float )max_pos / base_patches[i] , (float )axes_dim[i] / (axes_dim[i] -2 ));
397416 ntk_factor = use_dype ?pow (base_ntk,2 .0f * current_timestep * current_timestep) : base_ntk;
398417 ntk_factor =std::max (1 .0f , ntk_factor);
399418 }