@@ -441,9 +441,11 @@ def __init__(self, globals):
441441self .locals = {}
442442self .overwrite_errors = {}
443443self .unconditional_adds = {}
444+ self .method_annotations = {}
444445
445446def add_fn (self ,name ,args ,body ,* ,locals = None ,return_type = MISSING ,
446- overwrite_error = False ,unconditional_add = False ,decorator = None ):
447+ overwrite_error = False ,unconditional_add = False ,decorator = None ,
448+ annotation_fields = None ):
447449if locals is not None :
448450self .locals .update (locals )
449451
@@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
464466
465467self .names .append (name )
466468
467- if return_type is not MISSING :
468- self .locals [f'__dataclass_{ name } _return_type__' ]= return_type
469- return_annotation = f'->__dataclass_{ name } _return_type__'
470- else :
471- return_annotation = ''
469+ if annotation_fields is not None :
470+ self .method_annotations [name ]= (annotation_fields ,return_type )
471+
472472args = ',' .join (args )
473473body = '\n ' .join (body )
474474
475475# Compute the text of the entire function, add it to the text we're generating.
476- self .src .append (f'{ f'{ decorator } \n ' if decorator else '' } def{ name } ({ args } ){ return_annotation } :\n { body } ' )
476+ self .src .append (f'{ f'{ decorator } \n ' if decorator else '' } def{ name } ({ args } ):\n { body } ' )
477477
478478def add_fns_to_class (self ,cls ):
479479# The source to all of the functions we're generating.
@@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
509509# Now that we've generated the functions, assign them into cls.
510510for name ,fn in zip (self .names ,fns ):
511511fn .__qualname__ = f"{ cls .__qualname__ } .{ fn .__name__ } "
512+
513+ try :
514+ annotation_fields ,return_type = self .method_annotations [name ]
515+ except KeyError :
516+ pass
517+ else :
518+ annotate_fn = _make_annotate_function (cls ,name ,annotation_fields ,return_type )
519+ fn .__annotate__ = annotate_fn
520+
512521if self .unconditional_adds .get (name ,False ):
513522setattr (cls ,name ,fn )
514523else :
@@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
524533raise TypeError (error_msg )
525534
526535
536+ def _make_annotate_function (__class__ ,method_name ,annotation_fields ,return_type ):
537+ # Create an __annotate__ function for a dataclass
538+ # Try to return annotations in the same format as they would be
539+ # from a regular __init__ function
540+
541+ def __annotate__ (format ,/ ):
542+ Format = annotationlib .Format
543+ match format :
544+ case Format .VALUE | Format .FORWARDREF | Format .STRING :
545+ cls_annotations = {}
546+ for base in reversed (__class__ .__mro__ ):
547+ cls_annotations .update (
548+ annotationlib .get_annotations (base ,format = format )
549+ )
550+
551+ new_annotations = {}
552+ for k in annotation_fields :
553+ new_annotations [k ]= cls_annotations [k ]
554+
555+ if return_type is not MISSING :
556+ if format == Format .STRING :
557+ new_annotations ["return" ]= annotationlib .type_repr (return_type )
558+ else :
559+ new_annotations ["return" ]= return_type
560+
561+ return new_annotations
562+
563+ case _:
564+ raise NotImplementedError (format )
565+
566+ # This is a flag for _add_slots to know it needs to regenerate this method
567+ # In order to remove references to the original class when it is replaced
568+ __annotate__ .__generated_by_dataclasses__ = True
569+ __annotate__ .__qualname__ = f"{ __class__ .__qualname__ } .{ method_name } .__annotate__"
570+
571+ return __annotate__
572+
573+
527574def _field_assign (frozen ,name ,value ,self_name ):
528575# If we're a frozen class, then assign to our fields in __init__
529576# via object.__setattr__. Otherwise, just use a simple
@@ -612,7 +659,7 @@ def _init_param(f):
612659elif f .default_factory is not MISSING :
613660# There's a factory function. Set a marker.
614661default = '=__dataclass_HAS_DEFAULT_FACTORY__'
615- return f'{ f .name } :__dataclass_type_ { f . name } __ { default } '
662+ return f'{ f .name } { default } '
616663
617664
618665def _init_fn (fields ,std_fields ,kw_only_fields ,frozen ,has_post_init ,
@@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
635682raise TypeError (f'non-default argument{ f .name !r} '
636683f'follows default argument{ seen_default .name !r} ' )
637684
638- locals = {** {f'__dataclass_type_{ f .name } __' :f .type for f in fields },
639- ** {'__dataclass_HAS_DEFAULT_FACTORY__' :_HAS_DEFAULT_FACTORY ,
640- '__dataclass_builtins_object__' :object ,
641- }
642- }
685+ annotation_fields = [f .name for f in fields if f .init ]
686+
687+ locals = {'__dataclass_HAS_DEFAULT_FACTORY__' :_HAS_DEFAULT_FACTORY ,
688+ '__dataclass_builtins_object__' :object }
643689
644690body_lines = []
645691for f in fields :
@@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
670716 [self_name ]+ _init_params ,
671717body_lines ,
672718locals = locals ,
673- return_type = None )
719+ return_type = None ,
720+ annotation_fields = annotation_fields )
674721
675722
676723def _frozen_get_del_attr (cls ,fields ,func_builder ):
@@ -1336,6 +1383,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13361383or _update_func_cell_for__class__ (member .fdel ,cls ,newcls )):
13371384break
13381385
1386+ # Get new annotations to remove references to the original class
1387+ # in forward references
1388+ newcls_ann = annotationlib .get_annotations (
1389+ newcls ,format = annotationlib .Format .FORWARDREF )
1390+
1391+ # Fix references in dataclass Fields
1392+ for f in getattr (newcls ,_FIELDS ).values ():
1393+ try :
1394+ ann = newcls_ann [f .name ]
1395+ except KeyError :
1396+ pass
1397+ else :
1398+ f .type = ann
1399+
1400+ # Fix the class reference in the __annotate__ method
1401+ init_annotate = newcls .__init__ .__annotate__
1402+ if getattr (init_annotate ,"__generated_by_dataclasses__" ,False ):
1403+ _update_func_cell_for__class__ (init_annotate ,cls ,newcls )
1404+
13391405return newcls
13401406
13411407