[tests] unbloat tests/lora/utils.py #11845


Draft
sayakpaul wants to merge 24 commits into main from tests/unbloat-lora-utils

Conversation

@sayakpaul (Member) commented Jul 1, 2025 (edited)

What does this PR do?

We take the following approach:

  • Use parameterized to combine similar-flavored tests (see the sketch after this list).
  • Modularize repeated blocks of code into functions and use them as much as possible.
  • Remove redundant tests.
  • Make peft>=0.15.0 a hard requirement. So, I removed the @require_peft_version_greater decorator.
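A minimal toy sketch of the parameterized idea; the scenario names and the assertion are illustrative, not the PR's actual test code:

import unittest
from parameterized import parameterized

class ExampleParameterizedTests(unittest.TestCase):
    # One test body covers what used to be several near-identical test methods.
    @parameterized.expand([("fused",), ("unfused",), ("unloaded",)])
    def test_lora_scenario(self, scenario):
        # Placeholder assertion; the real tests would run the pipeline for each scenario.
        self.assertIn(scenario, {"fused", "unfused", "unloaded"})

if __name__ == "__main__":
    unittest.main()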

In a follow-up PR, I will attempt to improve the tests from the LoRA test suite that take the most time.

@@ -103,34 +103,6 @@ def get_dummy_inputs(self, with_generator=True):

return noise, input_ids, pipeline_inputs

@unittest.skip("Not supported in AuraFlow.")
@sayakpaul (Member, Author) commented Jul 1, 2025 (edited)

These are skipped appropriately from the parent method. I think it's okay in this case, because it eases things a bit.

@sayakpaul changed the title from "[wip][tests] unbloat tests/lora/utils.py" to "[tests] unbloat tests/lora/utils.py" on Jul 3, 2025
@sayakpaul marked this pull request as ready for review on July 3, 2025 09:32
@BenjaminBossan (Member)

I haven't checked the PR yet, but I was wondering: When I do bigger refactors of tests, I always check the line coverage before and after the refactor and ensure that they're the same (in PEFT we use pytest-cov). IMO, having some kind of validation that the refactor maintains the same coverage (or increases it) is quite important, as it's hard to notice when suddenly, some lines of code are not tested anymore.


@sayakpaul (Member, Author)

Indeed, it's important. Do you have any more guidelines for me to do that?

@BenjaminBossan (Member)

Do you have any more guidelines for me to do that?

So you can install pytest-cov if not already done and then add --cov=src/diffusers --cov-report=term-missing to your pytest invocation. Do this once for main and once for this branch. You get a report with something like this:

Name                                                  Stmts   Miss  Cover   Missing
-----------------------------------------------------------------------------------
src/peft/__init__.py                                     10      0   100%
src/peft/auto.py                                         71     34    52%   61, 82-142
src/peft/config.py                                      132     29    78%   39-43, 67, 89, 144, 148-172, 199-204, 225-228, 240, 242, 261-268, 344
src/peft/helpers.py                                      72     72     0%   15-251
src/peft/import_utils.py                                 94     56    40%   25, 30-35, 40-46, 54-73, 81, 87-98, 108, 128-147, 156-167, 172
src/peft/mapping.py                                      21     10    52%   25-26, 44, 65-78
src/peft/mapping_func.py                                 36     11    69%   73, 79, 85-89, 96, 103, 110, 123-125
[...]
src/peft/utils/peft_types.py                             58      6    90%   137, 140, 143, 156, 159, 163
-----------------------------------------------------------------------------------
TOTAL                                                 16679  12939    22%

It can be a bit difficult to parse what has changed, but basically, you want the Missing column after the refactor to be 1) the same as before or 2) a subset (i.e. strictly better coverage). You could write a small tool to compare before and after (I know I did it once but can't find the snippet anymore), or just ask an LLM; they are quite good at finding the differences.


@sayakpaul (Member, Author) commented Jul 4, 2025 (edited)

@BenjaminBossan I reran with:

CUDA_VISIBLE_DEVICES="" pytest \
  -n 24 --max-worker-restart=0 --dist=loadfile \
  --cov=src/diffusers/ \
  --cov-report=term-missing \
  --cov-report=json:unbloat.json \
  tests/lora/

on this branch and main.

I then used Gemini to report the file locations where this branch has reduced coverage:

Coverage Comparison Summary
==============================
📊 Coverage Changes:
  - utils/testing_utils.py: 33.53% -> 33.38% (-0.15%)
  - schedulers/scheduling_k_dpm_2_ancestral_discrete.py: 16.40% -> 0.00% (-16.40%)
  - utils/typing_utils.py: 64.86% -> 8.11% (-56.75%)
  - models/unets/unet_2d_condition.py: 53.81% -> 53.36% (-0.45%)
  - configuration_utils.py: 66.24% -> 46.50% (-19.74%)
  - schedulers/scheduling_dpmsolver_singlestep.py: 8.67% -> 0.00% (-8.67%)
  - models/model_loading_utils.py: 33.18% -> 17.73% (-15.45%)
  - pipelines/pipeline_loading_utils.py: 28.78% -> 16.77% (-12.01%)
  - utils/hub_utils.py: 30.26% -> 22.56% (-7.70%)
  - schedulers/scheduling_euler_discrete.py: 22.22% -> 14.14% (-8.08%)
  - utils/logging.py: 52.27% -> 50.76% (-1.51%)
  - utils/peft_utils.py: 82.74% -> 82.23% (-0.51%)
  - schedulers/scheduling_deis_multistep.py: 10.63% -> 0.00% (-10.63%)
  - schedulers/scheduling_unipc_multistep.py: 7.60% -> 0.00% (-7.60%)
  - schedulers/scheduling_utils.py: 98.00% -> 84.00% (-14.00%)
  - pipelines/pipeline_utils.py: 52.76% -> 33.55% (-19.21%)
  - schedulers/scheduling_edm_euler.py: 22.22% -> 0.00% (-22.22%)
  - models/modeling_utils.py: 51.14% -> 23.00% (-28.14%)
  - schedulers/scheduling_k_dpm_2_discrete.py: 17.17% -> 0.00% (-17.17%)
==============================
Code
import json
from decimal import Decimal, ROUND_HALF_UP


def get_coverage_data(report_path):
    """Loads a JSON coverage report and extracts file coverage data."""
    with open(report_path, 'r') as f:
        data = json.load(f)

    file_coverage = {}
    for filename, stats in data['files'].items():
        # Clean up the filename for better readability
        clean_filename = filename.replace('src/diffusers/', '')

        # Calculate coverage percentage with two decimal places
        covered_lines = stats['summary']['covered_lines']
        num_statements = stats['summary']['num_statements']
        if num_statements > 0:
            coverage_percent = (Decimal(covered_lines) / Decimal(num_statements)) * 100
            file_coverage[clean_filename] = coverage_percent.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
        else:
            file_coverage[clean_filename] = Decimal('0.0')
    return file_coverage


def compare_coverage(main_report, feature_report):
    """Compares two coverage reports and prints a summary of the differences."""
    main_coverage = get_coverage_data(main_report)
    feature_coverage = get_coverage_data(feature_report)

    main_files = set(main_coverage.keys())
    feature_files = set(feature_coverage.keys())

    # --- Report Summary ---
    print("Coverage Comparison Summary\n" + "=" * 30)

    # Files with changed coverage
    common_files = main_files.intersection(feature_files)
    changed_coverage_files = {
        file: (main_coverage[file], feature_coverage[file])
        for file in common_files
        if main_coverage[file] != feature_coverage[file]
    }
    if changed_coverage_files:
        print("\n📊 Coverage Changes:")
        for file, (main_cov, feature_cov) in changed_coverage_files.items():
            change = feature_cov - main_cov
            print(f"  - {file}: {main_cov}% -> {feature_cov}% ({'+' if change > 0 else ''}{change.quantize(Decimal('0.01'))}%)")
    else:
        print("\nNo change in coverage for existing files.")

    # New files in the feature branch
    new_files = feature_files - main_files
    if new_files:
        print("\n✨ New Files in Feature Branch:")
        for file in new_files:
            print(f"  - {file} (Coverage: {feature_coverage[file]}%)")

    # Removed files from the feature branch
    removed_files = main_files - feature_files
    if removed_files:
        print("\n🗑️ Removed Files from Feature Branch:")
        for file in removed_files:
            print(f"  - {file}")

    print("\n" + "=" * 30)


if __name__ == "__main__":
    compare_coverage('coverage_main.json', 'coverage_feature.json')

Will try to improve it / see what is going on. I think coverage reductions in files like hub_utils and peft_utils might be okay to ignore. But will see.

@BenjaminBossan (Member)

Nice, there seem to be some big drops in a couple of files; definitely worth investigating.

I skimmed the script and I think it's not quite correct. If, say, before, foo.py was covered on lines 0-10 of 20 total lines, and after, lines 10-20 are covered, the difference would be reported as 0. But in reality, 10 lines are now being missed. So the more accurate way would be to check the Missing column line intervals for overlap. You could also try just asking directly: "Given these 2 coverage reports, please list every line in every file that is no longer covered".
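A minimal sketch of that comparison, assuming coverage.py JSON reports (which expose a per-file missing_lines list); the report file names are placeholders:

import json

def newly_missed(before_path, after_path):
    # Per file, lines uncovered after the refactor that were covered before.
    before = json.load(open(before_path))["files"]
    after = json.load(open(after_path))["files"]
    report = {}
    for filename, stats in after.items():
        before_missing = set(before.get(filename, {}).get("missing_lines", []))
        after_missing = set(stats.get("missing_lines", []))
        lost = sorted(after_missing - before_missing)
        if lost:
            report[filename] = lost
    return report

# Example usage with placeholder report names:
# print(newly_missed("coverage_main.json", "coverage_feature.json"))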


@sayakpaul (Member, Author) commented Jul 4, 2025 (edited)

@BenjaminBossan here are my findings.

First, here's the updated comparison script:

Code
import json
from decimal import Decimal, ROUND_HALF_UP


def parse_coverage_report(report_path: str) -> dict:
    """
    Loads a JSON coverage report and extracts detailed data for each file,
    including missing lines and coverage percentage.
    """
    try:
        with open(report_path, 'r') as f:
            data = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        print(f"Error loading {report_path}: {e}")
        return {}

    coverage_data = {}
    for filename, stats in data.get('files', {}).items():
        summary = stats.get('summary', {})
        covered = summary.get('covered_lines', 0)
        total = summary.get('num_statements', 0)

        # Calculate coverage percentage
        if total > 0:
            percentage = (Decimal(covered) / Decimal(total)) * 100
        else:
            percentage = Decimal('100.0')  # No statements means 100% covered

        coverage_data[filename] = {
            'missing_lines': set(stats.get('missing_lines', [])),
            'coverage_pct': percentage.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
        }
    return coverage_data


def format_lines_as_ranges(lines: list[int]) -> str:
    """Converts a list of line numbers into a compact string of ranges."""
    if not lines:
        return ""
    ranges = []
    start = end = lines[0]
    for i in range(1, len(lines)):
        if lines[i] == end + 1:
            end = lines[i]
        else:
            ranges.append(f"{start}-{end}" if start != end else f"{start}")
            start = end = lines[i]
    ranges.append(f"{start}-{end}" if start != end else f"{start}")
    return ", ".join(ranges)


def find_and_report_coverage_changes(main_report_path: str, feature_report_path: str):
    """
    Compares two coverage reports and prints a detailed report on any
    lost coverage, including percentages and specific line numbers.
    """
    main_data = parse_coverage_report(main_report_path)
    feature_data = parse_coverage_report(feature_report_path)

    lost_coverage_report = {}

    # Find files with lost line coverage
    for filename, main_file_data in main_data.items():
        if filename in feature_data:
            feature_file_data = feature_data[filename]
            # Find lines that are missing now but were NOT missing before
            newly_missed_lines = sorted(list(
                feature_file_data['missing_lines'] - main_file_data['missing_lines']
            ))
            # Record if there are newly missed lines OR if the percentage has dropped
            # (e.g., due to new uncovered lines being added)
            if newly_missed_lines or feature_file_data['coverage_pct'] < main_file_data['coverage_pct']:
                lost_coverage_report[filename] = {
                    'lines': newly_missed_lines,
                    'main_pct': main_file_data['coverage_pct'],
                    'feature_pct': feature_file_data['coverage_pct']
                }

    # --- Print the Final Report ---
    print("📊❌ Coverage Change Report")
    print("=" * 30)

    if not lost_coverage_report:
        print("\n✅ No coverage degradation detected. Great job!")
        return

    print("\nThe following files have reduced coverage:\n")
    for filename, changes in lost_coverage_report.items():
        clean_filename = filename.replace('src/diffusers/', '')
        main_pct = changes['main_pct']
        feature_pct = changes['feature_pct']
        diff = (feature_pct - main_pct).quantize(Decimal('0.01'))

        print(f"📄 File: {clean_filename}")
        print(f"   Percentage: {main_pct}% → {feature_pct}% ({diff}%)")
        if changes['lines']:
            print(f"   Newly Missed Lines: {format_lines_as_ranges(changes['lines'])}")
        print("-" * 25)


if __name__ == "__main__":
    find_and_report_coverage_changes('coverage_main.json', 'unbloat.json')

The JSON files were obtained by running the following command once on main and once on this PR branch:

CUDA_VISIBLE_DEVICES="" pytest \
  -n 24 --max-worker-restart=0 --dist=loadfile \
  --cov=src/diffusers/ \
  --cov-report=json:<CHANGE_ME>.json \
  tests/lora/

Here is the first report, before fixes:

Unroll
📊❌ Coverage Change Report
==============================
The following files have reduced coverage:

📄 File: configuration_utils.py
   Percentage: 66.24% → 46.50% (-19.74%)
   Newly Missed Lines: 161, 164, 167, 169-170, 172, 268, 342-353, 355-356, 358, 360, 366, 373, 375-376, 380, 382, 440-441, 443, 447-448, 450, 452-453, 455-456, 458, 493, 499-500, 567, 570-572, 595-597, 599-600, 602, 604, 606, 613, 615-616, 618, 620, 630-631
-------------------------
📄 File: loaders/lora_base.py
   Percentage: 78.16% → 78.16% (0.00%)
   Newly Missed Lines: 732, 760
-------------------------
📄 File: models/model_loading_utils.py
   Percentage: 33.18% → 17.73% (-15.45%)
   Newly Missed Lines: 67, 114, 141-142, 165, 169, 173, 176, 231-232, 234-235, 238, 242, 258-261, 263, 266-268, 270-271, 273, 277, 291, 293, 295, 302, 304, 350-351, 381
-------------------------
📄 File: models/modeling_utils.py
   Percentage: 51.14% → 23.00% (-28.14%)
   Newly Missed Lines: 82-83, 86-87, 90, 238-239, 241-244, 247-248, 649, 653-654, 666-668, 672, 674, 683, 687-688, 691, 694, 699-701, 703-704, 706-708, 710, 716-719, 722, 726, 743-744, 746, 907-927, 929, 935-938, 940, 949, 956, 962, 968, 975, 977, 985, 993, 999-1000, 1004, 1009, 1012, 1015, 1031, 1035-1036, 1047, 1049, 1065, 1069-1071, 1074, 1077, 1080, 1082-1083, 1086-1089, 1105, 1108, 1110, 1113, 1117, 1139, 1152-1154, 1177, 1193-1194, 1199-1200, 1207, 1209-1210, 1212-1213, 1215, 1218-1219, 1221, 1223, 1225, 1228, 1230, 1236, 1239, 1242, 1265, 1273, 1281, 1285, 1293, 1298, 1301, 1303, 1306, 1460-1463, 1465, 1468-1470, 1472, 1474-1475, 1478, 1490-1491, 1495-1496, 1498, 1501, 1503, 1506-1507, 1509, 1515-1516, 1531, 1533, 1540-1541, 1560, 1568, 1576, 1581, 1583, 1589-1590, 1596, 1610, 1800, 1802-1804, 1806-1808, 1810, 1816, 1820, 1822, 1826, 1828, 1832, 1834, 1838, 1840, 1842
-------------------------
📄 File: models/unets/unet_2d_condition.py
   Percentage: 53.81% → 53.36% (-0.45%)
   Newly Missed Lines: 536-537
-------------------------
📄 File: pipelines/pipeline_loading_utils.py
   Percentage: 28.78% → 16.77% (-12.01%)
   Newly Missed Lines: 378, 380, 385, 393, 395-396, 398, 445, 455-456, 721, 725, 735, 737-739, 742, 756, 759-763, 768, 770-771, 775, 784-790, 792, 797, 805-806, 809-810, 814, 826, 829-830, 835, 851, 862, 867, 901-902, 909, 913-914, 916, 922, 926, 1137-1138
-------------------------
📄 File: pipelines/pipeline_utils.py
   Percentage: 52.76% → 33.55% (-19.21%)
   Newly Missed Lines: 272-276, 278, 286, 288-293, 295-298, 302, 306, 308-310, 316-318, 320-323, 325, 333, 336-339, 341-346, 350, 353, 355, 739, 741-764, 766, 772, 781, 784, 790, 796, 801, 804, 809, 813, 819, 827, 852, 856, 869-870, 876, 878, 882, 888-889, 895, 898, 908, 914, 926-930, 933, 938, 941-944, 946, 948, 951, 959, 965, 968-969, 987-989, 991, 999, 1002-1004, 1007, 1017, 1022, 1047, 1051, 1054, 1064-1070, 1077, 1079-1080, 1083-1085, 1090, 1093, 1096-1097, 1099, 1699-1700
-------------------------
📄 File: schedulers/scheduling_deis_multistep.py
   Percentage: 10.63% → 0.00% (-10.63%)
   Newly Missed Lines: 18-19, 21-22, 24-26, 29-30, 34, 78, 130-131, 133-134, 210-211, 217-218, 225, 235, 314, 348, 372, 383, 409, 431, 462, 522, 580, 649, 738, 758, 770, 835, 851, 885
-------------------------
📄 File: schedulers/scheduling_dpmsolver_singlestep.py
   Percentage: 8.67% → 0.00% (-8.67%)
   Newly Missed Lines: 17-18, 20-21, 23-26, 29-30, 32, 36, 80, 145-146, 148-149, 235, 275-276, 282-283, 290, 300, 405, 439, 463, 474, 500, 522, 553, 653, 717, 828, 950, 1014, 1034, 1046, 1117, 1133, 1167
-------------------------
📄 File: schedulers/scheduling_edm_euler.py
   Percentage: 22.22% → 0.00% (-22.22%)
   Newly Missed Lines: 15-17, 19, 21-24, 27, 30, 32, 45-46, 49, 85-86, 88-89, 133-134, 138-139, 145-146, 153, 163, 168, 176, 191, 215, 265, 276, 287, 302, 310, 410, 443, 447
-------------------------
📄 File: schedulers/scheduling_euler_discrete.py
   Percentage: 22.22% → 14.14% (-8.08%)
   Newly Missed Lines: 207, 209, 213, 215, 217, 219, 226, 229-230, 232, 237-239, 242, 245, 248, 250, 252-255, 257-259
-------------------------
📄 File: schedulers/scheduling_k_dpm_2_ancestral_discrete.py
   Percentage: 16.40% → 0.00% (-16.40%)
   Newly Missed Lines: 15-17, 19-20, 22-25, 28-29, 32, 34, 47-48, 52, 96, 135-136, 138-139, 181-182, 189-190, 196-197, 204, 214, 244, 344, 368, 394, 416, 447-448, 452, 467, 475, 583, 616
-------------------------
📄 File: schedulers/scheduling_k_dpm_2_discrete.py
   Percentage: 17.17% → 0.00% (-17.17%)
   Newly Missed Lines: 15-17, 19-20, 22-24, 27-28, 31, 33, 46-47, 51, 95, 134-135, 137-138, 181-182, 189-190, 196-197, 204, 214, 244, 328-329, 333, 348, 357, 381, 407, 429, 460, 555, 588
-------------------------
📄 File: schedulers/scheduling_unipc_multistep.py
   Percentage: 7.60% → 0.00% (-7.60%)
   Newly Missed Lines: 18-19, 21-22, 24-26, 29-30, 34, 79, 115, 185-186, 188-189, 276-277, 283-284, 291, 301, 424, 458, 482, 493, 519, 541, 572, 645, 774, 912, 932, 944, 1025, 1041, 1075
-------------------------
📄 File: schedulers/scheduling_utils.py
   Percentage: 98.00% → 84.00% (-14.00%)
   Newly Missed Lines: 151, 158, 175, 189-191, 194
-------------------------
📄 File: utils/hub_utils.py
   Percentage: 30.26% → 22.56% (-7.70%)
   Newly Missed Lines: 79-80, 82-84, 87, 90, 92-93, 96, 189, 191-194, 200, 205
-------------------------
📄 File: utils/logging.py
   Percentage: 52.27% → 50.76% (-1.51%)
   Newly Missed Lines: 307-308
-------------------------
📄 File: utils/peft_utils.py
   Percentage: 82.74% → 82.23% (-0.51%)
   Newly Missed Lines: 222
-------------------------
📄 File: utils/testing_utils.py
   Percentage: 33.53% → 33.38% (-0.15%)
   Newly Missed Lines: 543-544, 547, 551
-------------------------
📄 File: utils/typing_utils.py
   Percentage: 64.86% → 8.11% (-56.75%)
   Newly Missed Lines: 26, 30-32, 35-36, 38, 41, 43, 47-49, 51, 54, 64, 71, 78, 80, 84, 86, 91
-------------------------

Then I added this test back in: test_simple_inference_save_pretrained_with_text_lora(). The coverage improved immediately, and rightfully so (because of the call to save_pretrained() here). This way, through the LoRA tests, we are touching the code in the sections where coverage was lagging behind.

After that, when I added

def test_simple_inference(self):

the coverage was no longer lagging behind, except for utils/typing_utils.py (24.32%).

Here are my two cents:

  • test_simple_inference() isn't really a meaningful test. So, even if we reduce the coverage a bit, it's alright, IMO.
  • And then we only have type-checking related things in utils/typing_utils.py, which, again, are likely not that important to check here, IMO. I have also investigated the affected lines reported by the comparison, and they didn't seem particularly important to me either.

LMK if these make sense or if anything is unclear.

@BenjaminBossan (Member)

Nice, so the coverage is basically back to what it was previously. I'm not sure if

test_simple_inference() isn't really a meaningful test.

as it seems to hit lines that would otherwise remain untested. If the same coverage can be achieved with a better test, then that should be added; otherwise, I don't really see the harm in keeping this simple test.

@sayakpaul (Member, Author)

I will add it back in, but I think the current state of the PR should not be a blocker for reviews now.


@BenjaminBossan (Member) left a review comment

First of all, thanks a lot for taking on this big task of refactoring and simplifying those tests. Seeing a net 1000 lines removed is always great.

As for reviewing the change, I have to admit that it's quite hard. The overall amount of changes is quite large (also, some changes appear to be unrelated, like renaming variables). Although I suspect most new code is just old code moved to a different location (with some small changes), the diff view does not reveal that. Therefore, I haven't done a line-by-line review at this point.

Generally, refactoring tests is a delicate task. It is easy to accidentally cover fewer scenarios, since the tests will pass just fine. That is why I really wanted to see the change in test coverage. Of course, this is not a magic bullet to ensure that everything that was tested before is still tested, but I think it's the best boost in confidence we can get.

As for the gist of the refactor, I think it is a nice improvement, as witnessed by the lower line count combined with the test coverage being kept. That said, if I could dream of a "perfect" design, it would look more like this to me (using test_lora_set_adapters_scenarios as an example):

# specific model class
@parametrize("scheduler_cls", <scheduler-classes>)
def test_lora_set_adapters_simple(self, scheduler_cls):
    super()._test_lora_set_adapters_simple(scheduler_cls)

@parametrize("scheduler_cls", <scheduler-classes>)
def test_lora_set_adapters_weighted(self, scheduler_cls):
    super()._test_lora_set_adapters_weighted(scheduler_cls)

...

# base class
def _test_lora_set_adapters_simple(self, scheduler_cls):
    # maybe even consider parametrized pytest fixtures
    pipe, inputs, output_no_lora, _ = self._setup_multi_adapter_pipeline(scheduler_cls)
    # test for simple scenario

def _test_lora_set_adapters_weighted(self, scheduler_cls):
    # maybe even consider parametrized pytest fixtures
    pipe, inputs, output_no_lora, _ = self._setup_multi_adapter_pipeline(scheduler_cls)
    # test for weighted scenario

That way, each test would be testing for "one thing" instead of multiple things, which is preferable most of the time. The checks that precede the scenario-specific logic could be moved out into a separate test; that way we avoid duplicated checks.

I'm not asking to do another refactor according to what I described. As I wrote, I think this is already an improvement, and whether my suggestion is really better can be debated; I just wanted to present my opinion on it.


def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
@parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)])


Why do the scenarios have to be single-element tuples? And is "simple" a valid scenario here?

@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
# TODO: skip them properly


I think this comment, on its own, is not very helpful in explaining what needs to be done here.

@DN6 (Collaborator) commented Jul 7, 2025 (edited)

A few things to consider when refactoring this Mixin. I think we should try to address speed, composability, control flow and readability.

Regarding Speed:

We run a very large combination of tests here, and the base output is computed for each of them. We can speed up test time by ~2x by caching the base output and reusing it across test cases, since it doesn't change.
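A minimal sketch of that caching idea, assuming a unittest-style mixin; the cache attribute is illustrative and get_dummy_inputs is the existing test helper:

import torch

class CachedBaseOutputMixin:
    # The no-LoRA base output never changes, so compute it once and share it.
    # Stored on the class because unittest creates a fresh instance per test method.
    _base_output = None

    def get_base_output(self, pipe):
        if type(self)._base_output is None:
            inputs = self.get_dummy_inputs()  # assumed helper from the test mixin
            type(self)._base_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
        return type(self)._base_output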

Another thing to help with speed is to only test with a single scheduler. The default is to use two schedulers and manually override in the inheriting class. I don't think the additional scheduler test is giving us much signal in terms of LoRA functionality.

Regarding Control Flow / Readability

When we refactored LoRABaseMixin, we did so in a way that made it agnostic to the underlying pipeline (it didn't matter how many text encoders were LoRA-compatible; the Mixin was able to account for any number of them). In the current testing Mixin we have lots of manual checks for text_encoder, text_encoder_2, etc. We can apply a similar approach to make the testing Mixin agnostic to this.

I think get_dummy_components() is doing quite a bit here. A Mixin class shouldn't be expected to account for every case in the inheriting class. It would be better to move this method out of the Mixin class and into the inheriting classes, like we do for other test cases. You have to define the class and repo id in the inheriting class anyway, so there isn't much benefit in trying to account for it in the Mixin, and you can then remove a bunch of the conditional checks in the tests as well. Similarly, get_dummy_inputs should also be moved into the individual test instances. It makes reading the tests easier.
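A minimal sketch of that split; the pipeline class and dummy values are placeholders, and only the structure (hooks declared in the mixin, concrete definitions in the inheriting class) is the point:

import unittest

class PeftLoraLoaderMixinTests:
    pipeline_class = None

    def get_dummy_components(self):
        # Each pipeline-specific test class provides its own components.
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        raise NotImplementedError

class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    pipeline_class = None  # e.g. StableDiffusionPipeline in the real test file

    def get_dummy_components(self):
        # Placeholder: build the tiny UNet / text encoder / VAE / scheduler used for testing.
        return {}

    def get_dummy_inputs(self, with_generator=True):
        return {"prompt": "A painting of a squirrel", "num_inference_steps": 2, "output_type": "np"}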

There are utility functions that are probably better off being broken up into individual functions, e.g. self._setup_pipeline_and_get_base_output is used to initialise a pipeline, compute a baseline output, and set up text and denoiser LoRA configs. This is quite a lot to do, and as a result different combinations of the return values are ignored in numerous places:

pipe, inputs, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(
pipe, inputs, _, text_lora_config, _ = self._setup_pipeline_and_get_base_output(scheduler_cls)
pipe, inputs, output_no_lora, _, _ = self._setup_pipeline_and_get_base_output(scheduler_cls)
pipe, _, _, text_lora_config, denoiser_lora_config = self._setup_pipeline_and_get_base_output(

This is a sign that the function should be broken up into smaller pieces and called individually.

Regarding Composability

It's good to use parameterized.expand if code can be reused across different combinations of LoRA actions. But I don't think having a single function handle all cases will scale well as we add additional testing conditions. There is a risk of ending up with a very large function with lots of conditional paths.

It's better to make a lookup table for the components that need to be tested (text encoder, denoiser, text + denoiser) and the actions to be tested (fuse, unfuse, load, unload) and then compose those together to run the test.

Proposed new Mixin and some pseudo-examples:

class PeftLoraLoaderMixinTests:
    pipeline_class = None
    scheduler_class = None
    scheduler_kwargs = None

    lora_supported_text_encoders = []
    denoiser_name = ""

    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

    COMPONENT_SETUP_MAP = {
        "text_encoder_only": ["_setup_text_encoder", ["text_lora_config"]],
        "denoiser_only": ["_setup_denoiser", ["denoiser_lora_config"]],
        "text_and_denoiser": ["_setup_text_and_denoiser", ["text_lora_config", "denoiser_lora_config"]],
    }

    ACTION_MAP = {
        "fuse": "_action_fuse",
        "unfuse": "_action_unfuse",
        "save_load": "_action_save_load",
        "unload": "_action_unload",
        "scale": "_action_scale",
    }

    _base_output = None
    rank = 4

    def get_lora_config(self, rank, target_modules, lora_alpha=None, use_dora=False):
        return LoraConfig(
            r=rank, target_modules=target_modules, lora_alpha=lora_alpha, init_lora_weights=False, use_dora=use_dora
        )

    def get_dummy_components(self):
        raise NotImplementedError

    def get_dummy_inputs(self, with_generator=True):
        raise NotImplementedError

    @property
    def output_shape(self):
        raise NotImplementedError

    def setup_pipeline(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def get_base_output(self, pipe, with_generator=True):
        if self._base_output is None:
            inputs = self.get_dummy_inputs(with_generator)
            self._base_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
        return self._base_output

    def _setup_lora_text_encoders(self, pipe, text_lora_config):
        for component_name in self.lora_supported_text_encoders:
            component = getattr(pipe, component_name)
            component.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(component), f"Lora not correctly set in {component_name}")
        return self.lora_supported_text_encoders

    def _setup_lora_denoiser(self, pipe, denoiser_lora_config):
        component = getattr(pipe, self.denoiser_name)
        component.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(component), f"Lora not correctly set in {component}.")
        return self.denoiser_name

    def _setup_text_and_denoiser(self, pipe, text_lora_config, denoiser_lora_config):
        text_encoders = self._setup_lora_text_encoders(pipe, text_lora_config)
        denoiser = self._setup_lora_denoiser(pipe, denoiser_lora_config)
        return text_encoders.append(denoiser)

    def _action_fuse(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
        inputs = self.get_dummy_inputs(with_generator=False)
        output = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(base_output, output, atol=expected_atol),
            f"Output after fuse should differ from base output",
        )

    def _action_save_load(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
        with tempfile.TemporaryDirectory() as tmpdir:
            modules_to_save = {}
            for component_name in lora_components:
                if not hasattr(pipe, component_name):
                    continue
                modules_to_save[component_name] = getattr(pipe, component_name)

            # Save
            state_dicts = {}
            metadatas = {}
            for module_name, module in modules_to_save.items():
                if module is not None and getattr(module, "peft_config", None) is not None:
                    state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
                    metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()

            pipe.save_lora_weights(tmpdir, weight_name="lora.safetensors", **state_dicts, **metadatas)
            pipe.unload_lora_weights()
            pipe.load_lora_weights(tmpdir, weight_name="lora.safetensors")
            # remaining assertions

    def _action_unload(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
        pipe.unload_lora_weights()
        for component_name in lora_components:
            self.assertFalse(
                check_if_lora_correctly_set(getattr(pipe, component_name)),
                f"Lora layers should not be present in {component_name} after unloading",
            )
        inputs = self.get_dummy_inputs(with_generator=False)
        outputs = pipe(**inputs, generator=torch.manual_seed(0))[0]
        # remaining assertions

    def _should_skip_test(self, components):
        if components in ["text_encoder_only", "text_and_denoiser"]:
            return "text_encoder" not in self.pipeline_class._lora_loadable_modules
        return False

    def _setup_lora_components(self, pipe, components, text_lora_config, denoiser_lora_config):
        method_name, config_names = self.COMPONENT_SETUP_MAP[components]
        setup_method = getattr(self, method_name)

        kwargs = {"pipe": pipe}
        config_map = {"text_lora_config": text_lora_config, "denoiser_lora_config": denoiser_lora_config}
        for config_name in config_names:
            kwargs.update({config_name: config_map[config_name]})

        components = setup_method(**kwargs)
        return components

    def _execute_lora_action(self, action, pipe, base_output, lora_output, lora_components, expected_atol):
        """Execute a specific LoRA action and return the output"""
        action_method = getattr(self, self.ACTION_MAP[action])
        return action_method(pipe, base_output, lora_output, lora_components, expected_atol)

    def _test_lora_action(self, action, components, expected_atol=1e-4):
        # Skip if not supported
        if self._should_skip_test(components):
            self.skipTest(f"{components} LoRA is not supported")

        pipe = self.setup_pipeline()
        base_output = self.get_base_output(pipe)

        lora_components = self._setup_lora_components(pipe, components)
        lora_output = pipe(**get_inputs())

        self._execute_lora_action(action, pipe, base_output, lora_output, lora_components, expected_atol)

    @parameterized.expand(
        [
            # Test actions on text_encoder LoRA only
            ("fused", "text_encoder_only"),
            ("unloaded", "text_encoder_only"),
            ("save_load", "text_encoder_only"),
            # Test actions on both text_encoder and denoiser LoRA
            ("fused", "text_and_denoiser"),
            ("unloaded", "text_and_denoiser"),
            ("unfused", "text_and_denoiser"),
            ("save_load", "text_and_denoiser"),
            ("disable", "text_and_denoiser"),
        ],
        name_func=lambda func, num, p: f"{func.__name__}_{p[0]}_{p[1]}",  # so that test logs give us a nice test name and not an index
    )
    def test_lora_actions(self, action, components):
        """Test various LoRA actions with different component combinations"""
        self._test_lora_action(action, components)

    def test_low_cpu_mem_usage_with_injection(self):
        pipe = self.setup_pipeline()
        text_lora_config = self.get_config(...)
        denoiser_lora_config = self.get_lora_config(...)

        for component_name in self.lora_supported_text_encoders:
            if component_name not in self.pipeline_class._lora_loadable_modules:
                continue
            inject_adapter_in_model(text_lora_config, getattr(pipe, component_name), low_cpu_mem_usage=True)
            self.assertTrue(
                check_if_lora_correctly_set(getattr(pipe, component_name)),
                f"Lora not correctly set in {component_name}."
            )
            # remaining assertions
        # denoiser tests

@sayakpaul (Member, Author)

Thanks, both, for the candid feedback. I will try to address as much of it as possible.

@sayakpaul marked this pull request as draft on July 7, 2025 15:26
Reviewers: @BenjaminBossan (left review comments). At least 1 approving review is required to merge this pull request.
Assignees: none. Projects: none. Milestone: none.
3 participants: @sayakpaul, @BenjaminBossan, @DN6
