@@ -859,6 +859,21 @@ class DirectNamedTuple2(NamedTuple):
859859self .assertFalse (pytree .is_namedtuple (cls ))
860860self .assertFalse (pytree .is_namedtuple_class (cls ))
861861
862+ @parametrize (
863+ "pytree" ,
864+ [
865+ subtest (py_pytree ,name = "py" ),
866+ subtest (cxx_pytree ,name = "cxx" ),
867+ ],
868+ )
869+ def test_enum_treespec_roundtrip (self ,pytree ):
870+ data = {TestEnum .A :5 }
871+ spec = pytree .tree_structure (data )
872+
873+ serialized = pytree .treespec_dumps (spec )
874+ deserialized_spec = pytree .treespec_loads (serialized )
875+ self .assertEqual (spec ,deserialized_spec )
876+
862877
863878class TestPythonPytree (TestCase ):
864879def test_deprecated_register_pytree_node (self ):
@@ -1096,14 +1111,6 @@ def test_pytree_serialize_enum(self):
10961111serialized_spec = py_pytree .treespec_dumps (spec )
10971112self .assertIsInstance (serialized_spec ,str )
10981113
1099- def test_enum_treespec_roundtrip (self ):
1100- data = {TestEnum .A :5 }
1101- spec = py_pytree .tree_structure (data )
1102-
1103- serialized = py_pytree .treespec_dumps (spec )
1104- deserialized_spec = py_pytree .treespec_loads (serialized )
1105- self .assertEqual (spec ,deserialized_spec )
1106-
11071114def test_pytree_serialize_namedtuple (self ):
11081115Point1 = namedtuple ("Point1" , ["x" ,"y" ])
11091116py_pytree ._register_namedtuple (