@@ -65,8 +65,9 @@ def asdl_of(name, obj):
6565class EmitVisitor (asdl .VisitorBase ):
6666 """Visit that emits lines"""
6767
68- def __init__ (self , file ):
68+ def __init__ (self , file , typeinfo ):
6969 self .file = file
70+ self .typeinfo = typeinfo
7071 self .identifiers = set ()
7172 super (EmitVisitor , self ).__init__ ()
7273
@@ -163,10 +164,13 @@ def rust_field(field_name):
163164 return field_name
164165
165166
167+ def product_has_expr (product ):
168+ return any (f .type != "identifier" for f in product .fields )
169+
170+
166171class TypeInfoEmitVisitor (EmitVisitor ):
167172 def __init__ (self , file , typeinfo ):
168- self .typeinfo = typeinfo
169- super ().__init__ (file )
173+ super ().__init__ (file , typeinfo )
170174
171175 def has_userdata (self , typ ):
172176 return self .typeinfo [typ ].has_userdata
@@ -300,7 +304,7 @@ def visitProduct(self, product, name, depth):
300304 if product .attributes :
301305 dataname = rustname + "Data"
302306 self .emit_attrs (depth )
303- has_expr = any ( f . type != "identifier" for f in product . fields )
307+ has_expr = product_has_expr ( product )
304308 if has_expr :
305309 datadef = f"{ dataname } { generics } "
306310 else :
@@ -327,7 +331,11 @@ def visitModule(self, mod, depth):
327331 self .emit ("type Error;" , depth + 1 )
328332 self .emit (
329333 "fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;" ,
330- depth + 2 ,
334+ depth + 1 ,
335+ )
336+ self .emit (
337+ "fn map_located<T>(&mut self, located: Located<T, U>) -> Result<Located<T, Self::TargetU>, Self::Error> { let custom = self.map_user(located.custom)?; Ok(Located { range: located.range, custom, node: located.node }) }" ,
338+ depth + 1 ,
331339 )
332340 for dfn in mod .dfns :
333341 self .visit (dfn , depth + 2 )
@@ -352,7 +360,7 @@ def visitModule(self, mod, depth):
352360 depth ,
353361 )
354362 self .emit (
355- "Ok(Located { custom: folder.map_user( node.custom)? , range: node.range, node: f(folder, node.node)? })" ,
363+ "let node = folder.map_located(node)?; Ok(Located { custom: node.custom, range: node.range, node: f(folder, node.node)? })" ,
356364 depth + 1 ,
357365 )
358366 self .emit ("}" , depth )
@@ -575,11 +583,15 @@ def visitSum(self, sum, name, depth):
575583 rustname = enumname = get_rust_type (name )
576584 if sum .attributes :
577585 rustname = enumname + "Kind"
586+ if is_simple (sum ) or not self .typeinfo [name ].has_userdata :
587+ custom = ""
588+ else :
589+ custom = "<LocationRange>"
578590
579- self .emit (f"impl NamedNode for ast::{ rustname } {{" , depth )
591+ self .emit (f"impl NamedNode for ast::{ rustname } { custom } {{" , depth )
580592 self .emit (f"const NAME: &'static str = { json .dumps (name )} ;" , depth + 1 )
581593 self .emit ("}" , depth )
582- self .emit (f"impl Node for ast::{ rustname } {{" , depth )
594+ self .emit (f"impl Node for ast::{ rustname } { custom } {{" , depth )
583595 self .emit (
584596 "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {" , depth + 1
585597 )
@@ -609,11 +621,16 @@ def visitProduct(self, product, name, depth):
609621 structname = get_rust_type (name )
610622 if product .attributes :
611623 structname += "Data"
624+ custom = ""
625+ if product_has_expr (product ):
626+ custom = "<LocationRange>"
627+ else :
628+ custom = ""
612629
613- self .emit (f"impl NamedNode for ast::{ structname } {{" , depth )
630+ self .emit (f"impl NamedNode for ast::{ structname } { custom } {{" , depth )
614631 self .emit (f"const NAME: &'static str = { json .dumps (name )} ;" , depth + 1 )
615632 self .emit ("}" , depth )
616- self .emit (f"impl Node for ast::{ structname } {{" , depth )
633+ self .emit (f"impl Node for ast::{ structname } { custom } {{" , depth )
617634 self .emit (
618635 "fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {" , depth + 1
619636 )
@@ -656,7 +673,14 @@ def gen_sum_fromobj(self, sum, sumname, enumname, rustname, depth):
656673 for cons in sum .types :
657674 self .emit (f"if _cls.is(Node{ cons .name } ::static_type()) {{" , depth )
658675 if cons .fields :
659- self .emit (f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } {{" , depth + 1 )
676+ if self .typeinfo [cons .name ].has_userdata :
677+ generics = "::<LocationRange>"
678+ else :
679+ generics = ""
680+ self .emit (
681+ f"ast::{ rustname } ::{ cons .name } (ast::{ enumname } { cons .name } { generics } {{" ,
682+ depth + 1 ,
683+ )
660684 self .gen_construction_fields (cons , sumname , depth + 1 )
661685 self .emit ("})" , depth + 1 )
662686 else :
@@ -691,7 +715,7 @@ def gen_construction(self, cons_path, cons, name, depth):
691715 def extract_location (self , typename , depth ):
692716 row = self .decode_field (asdl .Field ("int" , "lineno" ), typename )
693717 column = self .decode_field (asdl .Field ("int" , "col_offset" ), typename )
694- self .emit (f"let _location = ast:: Location::new({ row } , { column } );" , depth )
718+ self .emit (f"let _location = Location::new({ row } , { column } );" , depth )
695719
696720 def decode_field (self , field , typename ):
697721 name = json .dumps (field .name )
@@ -717,86 +741,20 @@ def write_ast_def(mod, typeinfo, f):
717741 """
718742 #![allow(clippy::derive_partial_eq_without_eq)]
719743
720- pub use crate::constant::*;
721- pub use rustpython_compiler_core::text_size::{TextSize, TextRange};
744+ pub use crate::{Located, constant::*} ;
745+ pub use rustpython_compiler_core::{ text_size::{TextSize, TextRange} };
722746
723747 type Ident = String;
724748 \n
725749 """
726750 )
727751 )
728- StructVisitor (f , typeinfo ).emit_attrs (0 )
729- f .write (
730- textwrap .dedent (
731- """
732- pub struct Located<T, U = ()> {
733- pub range: TextRange,
734- pub custom: U,
735- pub node: T,
736- }
737-
738- impl<T> Located<T> {
739- pub fn new(start: TextSize, end: TextSize, node: T) -> Self {
740- Self { range: TextRange::new(start, end), custom: (), node }
741- }
742-
743- /// Creates a new node that spans the position specified by `range`.
744- pub fn with_range(node: T, range: TextRange) -> Self {
745- Self {
746- range,
747- custom: (),
748- node,
749- }
750- }
751-
752- /// Returns the absolute start position of the node from the beginning of the document.
753- #[inline]
754- pub const fn start(&self) -> TextSize {
755- self.range.start()
756- }
757-
758- /// Returns the node
759- #[inline]
760- pub fn node(&self) -> &T {
761- &self.node
762- }
763-
764- /// Consumes self and returns the node.
765- #[inline]
766- pub fn into_node(self) -> T {
767- self.node
768- }
769-
770- /// Returns the `range` of the node. The range offsets are absolute to the start of the document.
771- #[inline]
772- pub const fn range(&self) -> TextRange {
773- self.range
774- }
775-
776- /// Returns the absolute position at which the node ends in the source document.
777- #[inline]
778- pub const fn end(&self) -> TextSize {
779- self.range.end()
780- }
781- }
782-
783- impl<T, U> std::ops::Deref for Located<T, U> {
784- type Target = T;
785-
786- fn deref(&self) -> &Self::Target {
787- &self.node
788- }
789- }
790- \n
791- """ .lstrip ()
792- )
793- )
794752
795753 c = ChainOfVisitors (StructVisitor (f , typeinfo ), FoldModuleVisitor (f , typeinfo ))
796754 c .visit (mod )
797755
798756
799- def write_ast_mod (mod , f ):
757+ def write_ast_mod (mod , typeinfo , f ):
800758 f .write (
801759 textwrap .dedent (
802760 """
@@ -809,7 +767,11 @@ def write_ast_mod(mod, f):
809767 )
810768 )
811769
812- c = ChainOfVisitors (ClassDefVisitor (f ), TraitImplVisitor (f ), ExtendModuleVisitor (f ))
770+ c = ChainOfVisitors (
771+ ClassDefVisitor (f , typeinfo ),
772+ TraitImplVisitor (f , typeinfo ),
773+ ExtendModuleVisitor (f , typeinfo ),
774+ )
813775 c .visit (mod )
814776
815777
@@ -830,7 +792,7 @@ def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
830792 write_ast_def (mod , typeinfo , def_file )
831793
832794 mod_file .write (auto_gen_msg )
833- write_ast_mod (mod , mod_file )
795+ write_ast_mod (mod , typeinfo , mod_file )
834796
835797 print (f"{ ast_def_filename } , { ast_mod_filename } regenerated." )
836798
0 commit comments