
Add model arcinstitute state #39480


Draft

drbh wants to merge 6 commits into main from add-model-arcinstitute-state

Conversation

drbh

This PR adds the Arc Institute state model.

Run the embedding model via transformers:

git clone https://github.com/huggingface/transformers
cd transformers
git checkout add-model-arcinstitute-state
uv run sanity.py

sanity.py

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "torch",
#     "transformers"
# ]
#
# [tool.uv.sources]
# transformers = { path = ".", editable = true }
# ///
import torch
from transformers import StateEmbeddingModel

model_name = "arcinstitute/SE-600M"
model = StateEmbeddingModel.from_pretrained(model_name)

torch.manual_seed(0)

input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
mask = torch.ones((1, 1, 5120), dtype=torch.bool)
mask[:, :, 2560:] = False  # simulate half masking

print("Input sum:\t", input_ids.sum())
print("Mask sum:\t", mask.sum())

outputs = model(input_ids, mask)
print("Output sum:\t", outputs["gene_output"].sum())

outputs

Input sum: tensor(-38.6611)
Mask sum: tensor(2560)
Output sum: tensor(-19.6819, grad_fn=<SumBackward0>)
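Note the grad_fn in the output sum: the forward pass above runs with autograd enabled and dropout in training mode. A minimal variant of the same check in eval mode under torch.no_grad(), assuming the same environment as sanity.py and standard torch.nn.Module semantics for StateEmbeddingModel, would look like this (not part of the PR, just a sketch):

# variant of sanity.py: same fixed input, eval mode, no autograd
# (assumes StateEmbeddingModel follows standard torch.nn.Module semantics)
import torch
from transformers import StateEmbeddingModel

model = StateEmbeddingModel.from_pretrained("arcinstitute/SE-600M")
model.eval()  # disable any dropout so repeated runs match

torch.manual_seed(0)
input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
mask = torch.ones((1, 1, 5120), dtype=torch.bool)
mask[:, :, 2560:] = False  # same half masking as above

with torch.no_grad():
    outputs = model(input_ids, mask)

print("Output sum:\t", outputs["gene_output"].sum())  # no grad_fn attached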

Compare to reference

git clone https://github.com/ArcInstitute/state.git
cd state
curl -OL https://huggingface.co/arcinstitute/SE-600M/resolve/main/se600m_epoch16.ckpt

Next, apply this small patch so we can run the model file directly with a fixed input and compare it against the implementation above.

file `compare.patch`
diff --git a/src/state/emb/nn/model.py b/src/state/emb/nn/model.py
index dbbefb3..42167a1 100644
--- a/src/state/emb/nn/model.py
+++ b/src/state/emb/nn/model.py
@@ -23,20 +23,20 @@ from torch.nn import TransformerEncoder, TransformerEncoderLayer, BCEWithLogitsL
 from tqdm.auto import tqdm
 from torch.optim.lr_scheduler import ChainedScheduler, LinearLR, CosineAnnealingLR, ReduceLROnPlateau
-from ..data import create_dataloader
-from ..utils import (
+from state.emb.data import create_dataloader
+from state.emb.utils import (
     compute_gene_overlap_cross_pert,
     get_embedding_cfg,
     get_dataset_cfg,
     compute_pearson_delta,
     compute_perturbation_ranking_score,
 )
-from ..eval.emb import cluster_embedding
-from .loss import WassersteinLoss, KLDivergenceLoss, MMDLoss, TabularLoss
+from state.emb.eval.emb import cluster_embedding
+from loss import WassersteinLoss, KLDivergenceLoss, MMDLoss, TabularLoss
-from .flash_transformer import FlashTransformerEncoderLayer
-from .flash_transformer import FlashTransformerEncoder
+from flash_transformer import FlashTransformerEncoderLayer
+from flash_transformer import FlashTransformerEncoder
 
 
 class SkipBlock(nn.Module):
@@ -196,7 +196,8 @@ class StateEmbeddingModel(L.LightningModule):
             self.dataset_embedder = nn.Linear(output_dim, 10)
 
             # Assume self.cfg.model.num_datasets is set to the number of unique datasets.
-            num_dataset = get_dataset_cfg(self.cfg).num_datasets
+            # num_dataset = get_dataset_cfg(self.cfg).num_datasets
+            num_dataset = 14420
             self.dataset_encoder = nn.Sequential(
                 nn.Linear(output_dim, d_model),
                 nn.SiLU(),
@@ -686,3 +687,18 @@ class StateEmbeddingModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss", "interval": "step", "frequency": 1},
         }
+
+if __name__ == "__main__":
+    checkpoint = "/Users/drbh/Projects/state/se600m_epoch16.ckpt"
+    model = StateEmbeddingModel.load_from_checkpoint(checkpoint, dropout=0.0, strict=False)
+
+    torch.manual_seed(0)
+
+    input_ids = torch.randn((1, 1, 5120), dtype=torch.float32)
+    mask = torch.ones((1, 1, 5120), dtype=torch.bool)
+    mask[:, :, 2560:] = False
+    print("Input sum:\t", input_ids.sum())
+    print("Mask sum:\t", mask.sum())
+
+    output, embedding, dataset_emb = model(input_ids, mask)
+    print("Output shape:\t", output.sum())

It can be applied like this:

# save above as compare.patch
git apply compare.patch

run the model

.venv/bin/python src/state/emb/nn/model.py

output

!!! Using Flash Attention !!!
Input sum: tensor(-38.6611)
Mask sum: tensor(2560)
Output shape: tensor(-19.6819, grad_fn=<SumBackward0>)
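The sums from the two runs match. A tiny sanity check for this parity, with the values copied from the printed outputs above and an arbitrary illustrative tolerance (not part of the PR), could be:

# compare the printed output sums from the two runs above
# (values copied from the logs; abs_tol is an arbitrary illustrative choice)
import math

transformers_sum = -19.6819  # from sanity.py on the add-model-arcinstitute-state branch
reference_sum = -19.6819     # from the patched ArcInstitute/state checkpoint run

assert math.isclose(transformers_sum, reference_sum, abs_tol=1e-3), (
    f"outputs diverge: {transformers_sum} vs {reference_sum}"
)
print("parity check passed")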


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions (GitHub Actions)
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@FL33TW00D

@abhinadduri for ref


Reviewers

Awaiting requested review from @cyrilzakka

Awaiting requested review from @ArthurZucker

Awaiting requested review from @FL33TW00D

At least 1 approving review is required to merge this pull request.

4 participants
@drbh, @FL33TW00D, @HuggingFaceDocBuilderDev, @ArthurZucker
