Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit262c1bf

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Join points as batch
Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes.Reviewed By: bottlerDifferential Revision: D33145906fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
1 parenteb2bbf8 commit262c1bf

File tree

2 files changed

+102
-1
lines changed

2 files changed

+102
-1
lines changed

‎pytorch3d/structures/pointclouds.py‎

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,3 +1178,40 @@ def inside_box(self, box):
11781178

11791179
coord_inside= (points_packed>=box[:,0])* (points_packed<=box[:,1])
11801180
returncoord_inside.all(dim=-1)
1181+
1182+
1183+
defjoin_pointclouds_as_batch(pointclouds:Sequence[Pointclouds]):
1184+
"""
1185+
Merge a list of Pointclouds objects into a single batched Pointclouds
1186+
object. All pointclouds must be on the same device.
1187+
1188+
Args:
1189+
batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN]
1190+
Returns:
1191+
pointcloud: Poinclouds object with all input pointclouds collated into
1192+
a single object with batch dim = sum(b1, b2, ..., bN)
1193+
"""
1194+
ifisinstance(pointclouds,Pointclouds)ornotisinstance(pointclouds,Sequence):
1195+
raiseValueError("Wrong first argument to join_points_as_batch.")
1196+
1197+
device=pointclouds[0].device
1198+
ifnotall(p.device==deviceforpinpointclouds):
1199+
raiseValueError("Pointclouds must all be on the same device")
1200+
1201+
kwargs= {}
1202+
forfieldin ("points","normals","features"):
1203+
field_list= [getattr(p,field+"_list")()forpinpointclouds]
1204+
ifNoneinfield_list:
1205+
iffield=="points":
1206+
raiseValueError("Pointclouds cannot have their points set to None!")
1207+
ifnotall(fisNoneforfinfield_list):
1208+
raiseValueError(
1209+
f"Pointclouds in the batch have some fields '{field}'"
1210+
+" defined and some set to None."
1211+
)
1212+
field_list=None
1213+
else:
1214+
field_list= [pforpointsinfield_listforpinpoints]
1215+
kwargs[field]=field_list
1216+
1217+
returnPointclouds(**kwargs)

‎tests/test_pointclouds.py‎

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
importtorch
1313
fromcommon_testingimportTestCaseMixin
1414
frompytorch3d.structuresimportutilsasstruct_utils
15-
frompytorch3d.structures.pointcloudsimportPointclouds
15+
frompytorch3d.structures.pointcloudsimportPointclouds,join_pointclouds_as_batch
1616

1717

1818
classTestPointclouds(TestCaseMixin,unittest.TestCase):
@@ -1098,6 +1098,70 @@ def test_subsample(self):
10981098
forlength,points_inzip(lengths_max_4,pcl_copy2.points_list()):
10991099
self.assertEqual(points_.shape, (length,3))
11001100

1101+
deftest_join_pointclouds_as_batch(self):
1102+
"""
1103+
Test join_pointclouds_as_batch
1104+
"""
1105+
1106+
defcheck_item(x,y):
1107+
self.assertEqual(xisNone,yisNone)
1108+
ifxisnotNone:
1109+
self.assertClose(torch.cat([x,x,x]),y)
1110+
1111+
defcheck_triple(points,points3):
1112+
"""
1113+
Verify that points3 is three copies of points.
1114+
"""
1115+
check_item(points.points_padded(),points3.points_padded())
1116+
check_item(points.normals_padded(),points3.normals_padded())
1117+
check_item(points.features_padded(),points3.features_padded())
1118+
1119+
lengths= [4,5,13,3]
1120+
points= [torch.rand(length,3)forlengthinlengths]
1121+
features= [torch.rand(length,5)forlengthinlengths]
1122+
normals= [torch.rand(length,3)forlengthinlengths]
1123+
1124+
# Test with normals and features present
1125+
pcl=Pointclouds(points=points,features=features,normals=normals)
1126+
pcl3=join_pointclouds_as_batch([pcl]*3)
1127+
check_triple(pcl,pcl3)
1128+
1129+
# Test with normals and features present for tensor backed pointclouds
1130+
N,P,D=5,30,4
1131+
pcl=Pointclouds(
1132+
points=torch.rand(N,P,3),
1133+
features=torch.rand(N,P,D),
1134+
normals=torch.rand(N,P,3),
1135+
)
1136+
pcl3=join_pointclouds_as_batch([pcl]*3)
1137+
check_triple(pcl,pcl3)
1138+
1139+
# Test without normals
1140+
pcl_nonormals=Pointclouds(points=points,features=features)
1141+
pcl3=join_pointclouds_as_batch([pcl_nonormals]*3)
1142+
check_triple(pcl_nonormals,pcl3)
1143+
1144+
# Test without features
1145+
pcl_nofeats=Pointclouds(points=points,normals=normals)
1146+
pcl3=join_pointclouds_as_batch([pcl_nofeats]*3)
1147+
check_triple(pcl_nofeats,pcl3)
1148+
1149+
# Check error raised if all pointclouds in the batch
1150+
# are not consistent in including normals/features
1151+
withself.assertRaisesRegex(ValueError,"some set to None"):
1152+
join_pointclouds_as_batch([pcl,pcl_nonormals,pcl_nonormals])
1153+
withself.assertRaisesRegex(ValueError,"some set to None"):
1154+
join_pointclouds_as_batch([pcl,pcl_nofeats,pcl_nofeats])
1155+
1156+
# Check error if first input is a single pointclouds object
1157+
# instead of a list
1158+
withself.assertRaisesRegex(ValueError,"Wrong first argument"):
1159+
join_pointclouds_as_batch(pcl)
1160+
1161+
# Check error if all pointclouds are not on the same device
1162+
withself.assertRaisesRegex(ValueError,"same device"):
1163+
join_pointclouds_as_batch([pcl,pcl.to("cuda:0")])
1164+
11011165
@staticmethod
11021166
defcompute_packed_with_init(
11031167
num_clouds:int=10,max_p:int=100,features:int=300

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp