Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit6fa66f5

Browse files
bottlerfacebook-github-bot
authored andcommitted
PLY load normals
Summary: Add ability to load normals when they are present in a PLY file.Reviewed By: nikhilaraviDifferential Revision: D26458971fbshipit-source-id: 658270b611f7624eab4f5f62ff438038e1d25723
1 parentb314bee commit6fa66f5

File tree

2 files changed

+105
-21
lines changed

2 files changed

+105
-21
lines changed

‎pytorch3d/io/ply_io.py‎

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
780780

781781
def_get_verts_column_indices(
782782
vertex_head:_PlyElementType,
783-
)->Tuple[List[int],Optional[List[int]],float]:
783+
)->Tuple[List[int],Optional[List[int]],float,Optional[List[int]]]:
784784
"""
785-
Get the columns of vertsandverts_colors in the vertex
785+
Get the columns of verts, verts_colors,andverts_normals in the vertex
786786
element of a parsed ply file, together with a color scale factor.
787787
When the colors are in byte format, they are scaled from 0..255 to [0,1].
788788
Otherwise they are not scaled.
@@ -793,11 +793,14 @@ def _get_verts_column_indices(
793793
property double x
794794
property double y
795795
property double z
796+
property double nx
797+
property double ny
798+
property double nz
796799
property uchar red
797800
property uchar green
798801
property uchar blue
799802
800-
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
803+
then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5])
801804
802805
Args:
803806
vertex_head: as returned from load_ply_raw.
@@ -807,9 +810,12 @@ def _get_verts_column_indices(
807810
color_idxs: List[int] of 3 color columns if they are present,
808811
otherwise None.
809812
color_scale: value to scale colors by.
813+
normal_idxs: List[int] of 3 normals columns if they are present,
814+
otherwise None.
810815
"""
811816
point_idxs:List[Optional[int]]= [None,None,None]
812817
color_idxs:List[Optional[int]]= [None,None,None]
818+
normal_idxs:List[Optional[int]]= [None,None,None]
813819
fori,propinenumerate(vertex_head.properties):
814820
ifprop.list_size_typeisnotNone:
815821
raiseValueError("Invalid vertices in file: did not expect list.")
@@ -819,6 +825,9 @@ def _get_verts_column_indices(
819825
forj,nameinenumerate(["red","green","blue"]):
820826
ifprop.name==name:
821827
color_idxs[j]=i
828+
forj,nameinenumerate(["nx","ny","nz"]):
829+
ifprop.name==name:
830+
normal_idxs[j]=i
822831
ifNoneinpoint_idxs:
823832
raiseValueError("Invalid vertices in file.")
824833
color_scale=1.0
@@ -831,21 +840,23 @@ def _get_verts_column_indices(
831840
point_idxs,
832841
NoneifNoneincolor_idxselsecast(List[int],color_idxs),
833842
color_scale,
843+
NoneifNoneinnormal_idxselsecast(List[int],normal_idxs),
834844
)
835845

836846

837847
def_get_verts(
838848
header:_PlyHeader,elements:dict
839-
)->Tuple[torch.Tensor,Optional[torch.Tensor]]:
849+
)->Tuple[torch.Tensor,Optional[torch.Tensor],Optional[torch.Tensor]]:
840850
"""
841-
Get the vertex locationsandcolors from a parsed ply file.
851+
Get the vertex locations, colorsandnormals from a parsed ply file.
842852
843853
Args:
844854
header, elements: as returned from load_ply_raw.
845855
846856
Returns:
847857
verts: FloatTensor of shape (V, 3).
848858
vertex_colors: None or FloatTensor of shape (V, 3).
859+
vertex_normals: None or FloatTensor of shape (V, 3).
849860
"""
850861

851862
vertex=elements.get("vertex",None)
@@ -854,14 +865,16 @@ def _get_verts(
854865
ifnotisinstance(vertex,list):
855866
raiseValueError("Invalid vertices in file.")
856867
vertex_head=next(headforheadinheader.elementsifhead.name=="vertex")
857-
point_idxs,color_idxs,color_scale=_get_verts_column_indices(vertex_head)
868+
point_idxs,color_idxs,color_scale,normal_idxs=_get_verts_column_indices(
869+
vertex_head
870+
)
858871

859872
# Case of no vertices
860873
ifvertex_head.count==0:
861874
verts=torch.zeros((0,3),dtype=torch.float32)
862875
ifcolor_idxsisNone:
863-
returnverts,None
864-
returnverts,torch.zeros((0,3),dtype=torch.float32)
876+
returnverts,None,None
877+
returnverts,torch.zeros((0,3),dtype=torch.float32),None
865878

866879
# Simple case where the only data is the vertices themselves
867880
if (
@@ -870,9 +883,10 @@ def _get_verts(
870883
andvertex[0].ndim==2
871884
andvertex[0].shape[1]==3
872885
):
873-
return_make_tensor(vertex[0],cols=3,dtype=torch.float32),None
886+
return_make_tensor(vertex[0],cols=3,dtype=torch.float32),None,None
874887

875888
vertex_colors=None
889+
vertex_normals=None
876890

877891
iflen(vertex)==1:
878892
# This is the case where the whole vertex element has one type,
@@ -882,6 +896,10 @@ def _get_verts(
882896
vertex_colors=color_scale*torch.tensor(
883897
vertex[0][:,color_idxs],dtype=torch.float32
884898
)
899+
ifnormal_idxsisnotNone:
900+
vertex_normals=torch.tensor(
901+
vertex[0][:,normal_idxs],dtype=torch.float32
902+
)
885903
else:
886904
# The vertex element is heterogeneous. It was read as several arrays,
887905
# part by part, where a part is a set of properties with the same type.
@@ -913,13 +931,22 @@ def _get_verts(
913931
partnum,col=prop_to_partnum_col[color_idxs[color]]
914932
vertex_colors.numpy()[:,color]=vertex[partnum][:,col]
915933
vertex_colors*=color_scale
934+
ifnormal_idxsisnotNone:
935+
vertex_normals=torch.empty(
936+
size=(vertex_head.count,3),dtype=torch.float32
937+
)
938+
foraxisinrange(3):
939+
partnum,col=prop_to_partnum_col[normal_idxs[axis]]
940+
vertex_normals.numpy()[:,axis]=vertex[partnum][:,col]
916941

917-
returnverts,vertex_colors
942+
returnverts,vertex_colors,vertex_normals
918943

919944

920945
def_load_ply(
921946
f,*,path_manager:PathManager
922-
)->Tuple[torch.Tensor,Optional[torch.Tensor],Optional[torch.Tensor]]:
947+
)->Tuple[
948+
torch.Tensor,Optional[torch.Tensor],Optional[torch.Tensor],Optional[torch.Tensor]
949+
]:
923950
"""
924951
Load the data from a .ply file.
925952
@@ -935,10 +962,11 @@ def _load_ply(
935962
verts: FloatTensor of shape (V, 3).
936963
faces: None or LongTensor of vertex indices, shape (F, 3).
937964
vertex_colors: None or FloatTensor of shape (V, 3).
965+
vertex_normals: None or FloatTensor of shape (V, 3).
938966
"""
939967
header,elements=_load_ply_raw(f,path_manager=path_manager)
940968

941-
verts,vertex_colors=_get_verts(header,elements)
969+
verts,vertex_colors,vertex_normals=_get_verts(header,elements)
942970

943971
face=elements.get("face",None)
944972
iffaceisnotNone:
@@ -976,7 +1004,7 @@ def _load_ply(
9761004
iffacesisnotNone:
9771005
_check_faces_indices(faces,max_index=verts.shape[0])
9781006

979-
returnverts,faces,vertex_colors
1007+
returnverts,faces,vertex_colors,vertex_normals
9801008

9811009

9821010
defload_ply(
@@ -1031,7 +1059,7 @@ def load_ply(
10311059

10321060
ifpath_managerisNone:
10331061
path_manager=PathManager()
1034-
verts,faces,_=_load_ply(f,path_manager=path_manager)
1062+
verts,faces,_,_=_load_ply(f,path_manager=path_manager)
10351063
iffacesisNone:
10361064
faces=torch.zeros(0,3,dtype=torch.int64)
10371065

@@ -1211,18 +1239,23 @@ def read(
12111239
ifnotendswith(path,self.known_suffixes):
12121240
returnNone
12131241

1214-
verts,faces,verts_colors=_load_ply(f=path,path_manager=path_manager)
1242+
verts,faces,verts_colors,verts_normals=_load_ply(
1243+
f=path,path_manager=path_manager
1244+
)
12151245
iffacesisNone:
12161246
faces=torch.zeros(0,3,dtype=torch.int64)
12171247

1218-
textures=None
1248+
texture=None
12191249
ifinclude_texturesandverts_colorsisnotNone:
1220-
textures=TexturesVertex([verts_colors.to(device)])
1250+
texture=TexturesVertex([verts_colors.to(device)])
12211251

1252+
ifverts_normalsisnotNone:
1253+
verts_normals= [verts_normals]
12221254
mesh=Meshes(
12231255
verts=[verts.to(device)],
12241256
faces=[faces.to(device)],
1225-
textures=textures,
1257+
textures=texture,
1258+
verts_normals=verts_normals,
12261259
)
12271260
returnmesh
12281261

@@ -1286,12 +1319,14 @@ def read(
12861319
ifnotendswith(path,self.known_suffixes):
12871320
returnNone
12881321

1289-
verts,faces,features=_load_ply(f=path,path_manager=path_manager)
1322+
verts,faces,features,normals=_load_ply(f=path,path_manager=path_manager)
12901323
verts=verts.to(device)
12911324
iffeaturesisnotNone:
12921325
features= [features.to(device)]
1326+
ifnormalsisnotNone:
1327+
normals= [normals.to(device)]
12931328

1294-
pointcloud=Pointclouds(points=[verts],features=features)
1329+
pointcloud=Pointclouds(points=[verts],features=features,normals=normals)
12951330
returnpointcloud
12961331

12971332
defsave(

‎tests/test_io_ply.py‎

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,18 @@ def test_save_load_meshes(self):
216216
[[0,0,0], [0,0,1], [0,1,0], [1,0,0]],dtype=torch.float32
217217
)
218218
faces=torch.tensor([[0,1,2], [0,2,3]])
219+
normals=torch.tensor(
220+
[[0,1,0], [1,0,0], [1,4,1], [1,0,0]],dtype=torch.float32
221+
)
219222
vert_colors=torch.rand_like(verts)
220223
texture=TexturesVertex(verts_features=[vert_colors])
221224

222-
fordo_texturesinitertools.product([True,False]):
225+
fordo_textures,do_normalsinitertools.product([True,False],[True,False]):
223226
mesh=Meshes(
224227
verts=[verts],
225228
faces=[faces],
226229
textures=textureifdo_textureselseNone,
230+
verts_normals=[normals]ifdo_normalselseNone,
227231
)
228232
device=torch.device("cuda:0")
229233

@@ -236,12 +240,57 @@ def test_save_load_meshes(self):
236240
mesh2=mesh2.cpu()
237241
self.assertClose(mesh2.verts_padded(),mesh.verts_padded())
238242
self.assertClose(mesh2.faces_padded(),mesh.faces_padded())
243+
ifdo_normals:
244+
self.assertTrue(mesh.has_verts_normals())
245+
self.assertTrue(mesh2.has_verts_normals())
246+
self.assertClose(
247+
mesh2.verts_normals_padded(),mesh.verts_normals_padded()
248+
)
249+
else:
250+
self.assertFalse(mesh.has_verts_normals())
251+
self.assertFalse(mesh2.has_verts_normals())
252+
self.assertFalse(torch.allclose(mesh2.verts_normals_padded(),normals))
239253
ifdo_textures:
240254
self.assertIsInstance(mesh2.textures,TexturesVertex)
241255
self.assertClose(mesh2.textures.verts_features_list()[0],vert_colors)
242256
else:
243257
self.assertIsNone(mesh2.textures)
244258

259+
deftest_save_load_with_normals(self):
260+
points=torch.tensor(
261+
[[0,0,0], [0,0,1], [0,1,0], [1,0,0]],dtype=torch.float32
262+
)
263+
normals=torch.tensor(
264+
[[0,1,0], [1,0,0], [1,4,1], [1,0,0]],dtype=torch.float32
265+
)
266+
features=torch.rand_like(points)
267+
268+
fordo_features,do_normalsinitertools.product([True,False], [True,False]):
269+
cloud=Pointclouds(
270+
points=[points],
271+
features=[features]ifdo_featureselseNone,
272+
normals=[normals]ifdo_normalselseNone,
273+
)
274+
device=torch.device("cuda:0")
275+
276+
io=IO()
277+
withNamedTemporaryFile(mode="w",suffix=".ply")asf:
278+
io.save_pointcloud(cloud.cuda(),f.name)
279+
f.flush()
280+
cloud2=io.load_pointcloud(f.name,device=device)
281+
self.assertEqual(cloud2.device,device)
282+
cloud2=cloud2.cpu()
283+
self.assertClose(cloud2.points_padded(),cloud.points_padded())
284+
ifdo_normals:
285+
self.assertClose(cloud2.normals_padded(),cloud.normals_padded())
286+
else:
287+
self.assertIsNone(cloud.normals_padded())
288+
self.assertIsNone(cloud2.normals_padded())
289+
ifdo_features:
290+
self.assertClose(cloud2.features_packed(),features)
291+
else:
292+
self.assertIsNone(cloud2.features_packed())
293+
245294
deftest_save_ply_invalid_shapes(self):
246295
# Invalid vertices shape
247296
withself.assertRaises(ValueError)aserror:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp