@@ -780,9 +780,9 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
780780
781781def _get_verts_column_indices (
782782vertex_head :_PlyElementType ,
783- )-> Tuple [List [int ],Optional [List [int ]],float ]:
783+ )-> Tuple [List [int ],Optional [List [int ]],float , Optional [ List [ int ]] ]:
784784"""
785- Get the columns of verts andverts_colors in the vertex
785+ Get the columns of verts, verts_colors, andverts_normals in the vertex
786786 element of a parsed ply file, together with a color scale factor.
787787 When the colors are in byte format, they are scaled from 0..255 to [0,1].
788788 Otherwise they are not scaled.
@@ -793,11 +793,14 @@ def _get_verts_column_indices(
793793 property double x
794794 property double y
795795 property double z
796+ property double nx
797+ property double ny
798+ property double nz
796799 property uchar red
797800 property uchar green
798801 property uchar blue
799802
800- then the return value will be ([0,1,2], [6,7,8], 1.0/255)
803+ then the return value will be ([0,1,2], [6,7,8], 1.0/255, [3,4,5] )
801804
802805 Args:
803806 vertex_head: as returned from load_ply_raw.
@@ -807,9 +810,12 @@ def _get_verts_column_indices(
807810 color_idxs: List[int] of 3 color columns if they are present,
808811 otherwise None.
809812 color_scale: value to scale colors by.
813+ normal_idxs: List[int] of 3 normals columns if they are present,
814+ otherwise None.
810815 """
811816point_idxs :List [Optional [int ]]= [None ,None ,None ]
812817color_idxs :List [Optional [int ]]= [None ,None ,None ]
818+ normal_idxs :List [Optional [int ]]= [None ,None ,None ]
813819for i ,prop in enumerate (vertex_head .properties ):
814820if prop .list_size_type is not None :
815821raise ValueError ("Invalid vertices in file: did not expect list." )
@@ -819,6 +825,9 @@ def _get_verts_column_indices(
819825for j ,name in enumerate (["red" ,"green" ,"blue" ]):
820826if prop .name == name :
821827color_idxs [j ]= i
828+ for j ,name in enumerate (["nx" ,"ny" ,"nz" ]):
829+ if prop .name == name :
830+ normal_idxs [j ]= i
822831if None in point_idxs :
823832raise ValueError ("Invalid vertices in file." )
824833color_scale = 1.0
@@ -831,21 +840,23 @@ def _get_verts_column_indices(
831840point_idxs ,
832841None if None in color_idxs else cast (List [int ],color_idxs ),
833842color_scale ,
843+ None if None in normal_idxs else cast (List [int ],normal_idxs ),
834844 )
835845
836846
837847def _get_verts (
838848header :_PlyHeader ,elements :dict
839- )-> Tuple [torch .Tensor ,Optional [torch .Tensor ]]:
849+ )-> Tuple [torch .Tensor ,Optional [torch .Tensor ], Optional [ torch . Tensor ] ]:
840850"""
841- Get the vertex locations andcolors from a parsed ply file.
851+ Get the vertex locations, colors andnormals from a parsed ply file.
842852
843853 Args:
844854 header, elements: as returned from load_ply_raw.
845855
846856 Returns:
847857 verts: FloatTensor of shape (V, 3).
848858 vertex_colors: None or FloatTensor of shape (V, 3).
859+ vertex_normals: None or FloatTensor of shape (V, 3).
849860 """
850861
851862vertex = elements .get ("vertex" ,None )
@@ -854,14 +865,16 @@ def _get_verts(
854865if not isinstance (vertex ,list ):
855866raise ValueError ("Invalid vertices in file." )
856867vertex_head = next (head for head in header .elements if head .name == "vertex" )
857- point_idxs ,color_idxs ,color_scale = _get_verts_column_indices (vertex_head )
868+ point_idxs ,color_idxs ,color_scale ,normal_idxs = _get_verts_column_indices (
869+ vertex_head
870+ )
858871
859872# Case of no vertices
860873if vertex_head .count == 0 :
861874verts = torch .zeros ((0 ,3 ),dtype = torch .float32 )
862875if color_idxs is None :
863- return verts ,None
864- return verts ,torch .zeros ((0 ,3 ),dtype = torch .float32 )
876+ return verts ,None , None
877+ return verts ,torch .zeros ((0 ,3 ),dtype = torch .float32 ), None
865878
866879# Simple case where the only data is the vertices themselves
867880if (
@@ -870,9 +883,10 @@ def _get_verts(
870883and vertex [0 ].ndim == 2
871884and vertex [0 ].shape [1 ]== 3
872885 ):
873- return _make_tensor (vertex [0 ],cols = 3 ,dtype = torch .float32 ),None
886+ return _make_tensor (vertex [0 ],cols = 3 ,dtype = torch .float32 ),None , None
874887
875888vertex_colors = None
889+ vertex_normals = None
876890
877891if len (vertex )== 1 :
878892# This is the case where the whole vertex element has one type,
@@ -882,6 +896,10 @@ def _get_verts(
882896vertex_colors = color_scale * torch .tensor (
883897vertex [0 ][:,color_idxs ],dtype = torch .float32
884898 )
899+ if normal_idxs is not None :
900+ vertex_normals = torch .tensor (
901+ vertex [0 ][:,normal_idxs ],dtype = torch .float32
902+ )
885903else :
886904# The vertex element is heterogeneous. It was read as several arrays,
887905# part by part, where a part is a set of properties with the same type.
@@ -913,13 +931,22 @@ def _get_verts(
913931partnum ,col = prop_to_partnum_col [color_idxs [color ]]
914932vertex_colors .numpy ()[:,color ]= vertex [partnum ][:,col ]
915933vertex_colors *= color_scale
934+ if normal_idxs is not None :
935+ vertex_normals = torch .empty (
936+ size = (vertex_head .count ,3 ),dtype = torch .float32
937+ )
938+ for axis in range (3 ):
939+ partnum ,col = prop_to_partnum_col [normal_idxs [axis ]]
940+ vertex_normals .numpy ()[:,axis ]= vertex [partnum ][:,col ]
916941
917- return verts ,vertex_colors
942+ return verts ,vertex_colors , vertex_normals
918943
919944
920945def _load_ply (
921946f ,* ,path_manager :PathManager
922- )-> Tuple [torch .Tensor ,Optional [torch .Tensor ],Optional [torch .Tensor ]]:
947+ )-> Tuple [
948+ torch .Tensor ,Optional [torch .Tensor ],Optional [torch .Tensor ],Optional [torch .Tensor ]
949+ ]:
923950"""
924951 Load the data from a .ply file.
925952
@@ -935,10 +962,11 @@ def _load_ply(
935962 verts: FloatTensor of shape (V, 3).
936963 faces: None or LongTensor of vertex indices, shape (F, 3).
937964 vertex_colors: None or FloatTensor of shape (V, 3).
965+ vertex_normals: None or FloatTensor of shape (V, 3).
938966 """
939967header ,elements = _load_ply_raw (f ,path_manager = path_manager )
940968
941- verts ,vertex_colors = _get_verts (header ,elements )
969+ verts ,vertex_colors , vertex_normals = _get_verts (header ,elements )
942970
943971face = elements .get ("face" ,None )
944972if face is not None :
@@ -976,7 +1004,7 @@ def _load_ply(
9761004if faces is not None :
9771005_check_faces_indices (faces ,max_index = verts .shape [0 ])
9781006
979- return verts ,faces ,vertex_colors
1007+ return verts ,faces ,vertex_colors , vertex_normals
9801008
9811009
9821010def load_ply (
@@ -1031,7 +1059,7 @@ def load_ply(
10311059
10321060if path_manager is None :
10331061path_manager = PathManager ()
1034- verts ,faces ,_ = _load_ply (f ,path_manager = path_manager )
1062+ verts ,faces ,_ , _ = _load_ply (f ,path_manager = path_manager )
10351063if faces is None :
10361064faces = torch .zeros (0 ,3 ,dtype = torch .int64 )
10371065
@@ -1211,18 +1239,23 @@ def read(
12111239if not endswith (path ,self .known_suffixes ):
12121240return None
12131241
1214- verts ,faces ,verts_colors = _load_ply (f = path ,path_manager = path_manager )
1242+ verts ,faces ,verts_colors ,verts_normals = _load_ply (
1243+ f = path ,path_manager = path_manager
1244+ )
12151245if faces is None :
12161246faces = torch .zeros (0 ,3 ,dtype = torch .int64 )
12171247
1218- textures = None
1248+ texture = None
12191249if include_textures and verts_colors is not None :
1220- textures = TexturesVertex ([verts_colors .to (device )])
1250+ texture = TexturesVertex ([verts_colors .to (device )])
12211251
1252+ if verts_normals is not None :
1253+ verts_normals = [verts_normals ]
12221254mesh = Meshes (
12231255verts = [verts .to (device )],
12241256faces = [faces .to (device )],
1225- textures = textures ,
1257+ textures = texture ,
1258+ verts_normals = verts_normals ,
12261259 )
12271260return mesh
12281261
@@ -1286,12 +1319,14 @@ def read(
12861319if not endswith (path ,self .known_suffixes ):
12871320return None
12881321
1289- verts ,faces ,features = _load_ply (f = path ,path_manager = path_manager )
1322+ verts ,faces ,features , normals = _load_ply (f = path ,path_manager = path_manager )
12901323verts = verts .to (device )
12911324if features is not None :
12921325features = [features .to (device )]
1326+ if normals is not None :
1327+ normals = [normals .to (device )]
12931328
1294- pointcloud = Pointclouds (points = [verts ],features = features )
1329+ pointcloud = Pointclouds (points = [verts ],features = features , normals = normals )
12951330return pointcloud
12961331
12971332def save (