@@ -87,6 +87,14 @@ impl CompileContext {
8787 }
8888}
8989
90+ #[ derive( Debug , Clone , Copy , PartialEq ) ]
91+ enum ComprehensionType {
92+ Generator ,
93+ List ,
94+ Set ,
95+ Dict ,
96+ }
97+
9098/// Compile an located_ast::Mod produced from rustpython_parser::parse()
9199pub fn compile_top (
92100 ast : & located_ast:: Mod ,
@@ -2431,6 +2439,8 @@ impl Compiler {
24312439 ) ;
24322440 Ok ( ( ) )
24332441 } ,
2442+ ComprehensionType :: List ,
2443+ Self :: contains_await ( elt) ,
24342444 ) ?;
24352445 }
24362446 Expr :: SetComp ( located_ast:: ExprSetComp {
@@ -2452,6 +2462,8 @@ impl Compiler {
24522462 ) ;
24532463 Ok ( ( ) )
24542464 } ,
2465+ ComprehensionType :: Set ,
2466+ Self :: contains_await ( elt) ,
24552467 ) ?;
24562468 }
24572469 Expr :: DictComp ( located_ast:: ExprDictComp {
@@ -2480,19 +2492,28 @@ impl Compiler {
24802492
24812493 Ok ( ( ) )
24822494 } ,
2495+ ComprehensionType :: Dict ,
2496+ Self :: contains_await ( key) || Self :: contains_await ( value) ,
24832497 ) ?;
24842498 }
24852499 Expr :: GeneratorExp ( located_ast:: ExprGeneratorExp {
24862500 elt, generators, ..
24872501 } ) => {
2488- self . compile_comprehension ( "<genexpr>" , None , generators, & |compiler| {
2489- compiler. compile_comprehension_element ( elt) ?;
2490- compiler. mark_generator ( ) ;
2491- emit ! ( compiler, Instruction :: YieldValue ) ;
2492- emit ! ( compiler, Instruction :: Pop ) ;
2502+ self . compile_comprehension (
2503+ "<genexpr>" ,
2504+ None ,
2505+ generators,
2506+ & |compiler| {
2507+ compiler. compile_comprehension_element ( elt) ?;
2508+ compiler. mark_generator ( ) ;
2509+ emit ! ( compiler, Instruction :: YieldValue ) ;
2510+ emit ! ( compiler, Instruction :: Pop ) ;
24932511
2494- Ok ( ( ) )
2495- } ) ?;
2512+ Ok ( ( ) )
2513+ } ,
2514+ ComprehensionType :: Generator ,
2515+ Self :: contains_await ( elt) ,
2516+ ) ?;
24962517 }
24972518 Expr :: Starred ( _) => {
24982519 return Err ( self . error ( CodegenErrorType :: InvalidStarExpr ) ) ;
@@ -2744,9 +2765,35 @@ impl Compiler {
27442765 init_collection : Option < Instruction > ,
27452766 generators : & [ located_ast:: Comprehension ] ,
27462767 compile_element : & dyn Fn ( & mut Self ) -> CompileResult < ( ) > ,
2768+ comprehension_type : ComprehensionType ,
2769+ element_contains_await : bool ,
27472770 ) -> CompileResult < ( ) > {
27482771 let prev_ctx = self . ctx ;
2749- let is_async = generators. iter ( ) . any ( |g| g. is_async ) ;
2772+ let has_an_async_gen = generators. iter ( ) . any ( |g| g. is_async ) ;
2773+
2774+ // async comprehensions are allowed in various contexts:
2775+ // - list/set/dict comprehensions in async functions
2776+ // - always for generator expressions
2777+ // Note: generators have to be treated specially since their async version is a fundamentally
2778+ // different type (aiter vs iter) instead of just an awaitable.
2779+
2780+ // for if it actually is async, we check if any generator is async or if the element contains await
2781+
2782+ // if the element expression contains await, but the context doesn't allow for async,
2783+ // then we continue on here with is_async=false and will produce a syntax once the await is hit
2784+
2785+ let is_async_list_set_dict_comprehension = comprehension_type
2786+ != ComprehensionType :: Generator
2787+ && ( has_an_async_gen || element_contains_await) // does it have to be async? (uses await or async for)
2788+ && prev_ctx. func == FunctionContext :: AsyncFunction ; // is it allowed to be async? (in an async function)
2789+
2790+ let is_async_generator_comprehension = comprehension_type == ComprehensionType :: Generator
2791+ && ( has_an_async_gen || element_contains_await) ;
2792+
2793+ // since one is for generators, and one for not generators, they should never both be true
2794+ debug_assert ! ( !( is_async_list_set_dict_comprehension && is_async_generator_comprehension) ) ;
2795+
2796+ let is_async = is_async_list_set_dict_comprehension || is_async_generator_comprehension;
27502797
27512798 self . ctx = CompileContext {
27522799 loop_data : None ,
@@ -2838,7 +2885,7 @@ impl Compiler {
28382885
28392886 // End of for loop:
28402887 self . switch_to_block ( after_block) ;
2841- if is_async {
2888+ if has_an_async_gen {
28422889 emit ! ( self , Instruction :: EndAsyncFor ) ;
28432890 }
28442891 }
@@ -2877,19 +2924,23 @@ impl Compiler {
28772924 self . compile_expression ( & generators[ 0 ] . iter ) ?;
28782925
28792926 // Get iterator / turn item into an iterator
2880- if is_async {
2927+ if has_an_async_gen {
28812928 emit ! ( self , Instruction :: GetAIter ) ;
28822929 } else {
28832930 emit ! ( self , Instruction :: GetIter ) ;
28842931 } ;
28852932
28862933 // Call just created <listcomp> function:
28872934 emit ! ( self , Instruction :: CallFunctionPositional { nargs: 1 } ) ;
2888- if is_async {
2935+ if is_async_list_set_dict_comprehension {
2936+ // async, but not a generator and not an async for
2937+ // in this case, we end up with an awaitable
2938+ // that evaluates to the list/set/dict, so here we add an await
28892939 emit ! ( self , Instruction :: GetAwaitable ) ;
28902940 self . emit_load_const ( ConstantData :: None ) ;
28912941 emit ! ( self , Instruction :: YieldFrom ) ;
28922942 }
2943+
28932944 Ok ( ( ) )
28942945 }
28952946
@@ -3016,6 +3067,117 @@ impl Compiler {
30163067 fn mark_generator ( & mut self ) {
30173068 self . current_code_info ( ) . flags |= bytecode:: CodeFlags :: IS_GENERATOR
30183069 }
3070+
3071+ /// Whether the expression contains an await expression and
3072+ /// thus requires the function to be async.
3073+ /// Async with and async for are statements, so I won't check for them here
3074+ fn contains_await ( expression : & located_ast:: Expr ) -> bool {
3075+ use located_ast:: * ;
3076+
3077+ match & expression {
3078+ Expr :: Call ( ExprCall {
3079+ func,
3080+ args,
3081+ keywords,
3082+ ..
3083+ } ) => {
3084+ Self :: contains_await ( func)
3085+ || args. iter ( ) . any ( Self :: contains_await)
3086+ || keywords. iter ( ) . any ( |kw| Self :: contains_await ( & kw. value ) )
3087+ }
3088+ Expr :: BoolOp ( ExprBoolOp { values, .. } ) => values. iter ( ) . any ( Self :: contains_await) ,
3089+ Expr :: BinOp ( ExprBinOp { left, right, .. } ) => {
3090+ Self :: contains_await ( left) || Self :: contains_await ( right)
3091+ }
3092+ Expr :: Subscript ( ExprSubscript { value, slice, .. } ) => {
3093+ Self :: contains_await ( value) || Self :: contains_await ( slice)
3094+ }
3095+ Expr :: UnaryOp ( ExprUnaryOp { operand, .. } ) => Self :: contains_await ( operand) ,
3096+ Expr :: Attribute ( ExprAttribute { value, .. } ) => Self :: contains_await ( value) ,
3097+ Expr :: Compare ( ExprCompare {
3098+ left, comparators, ..
3099+ } ) => Self :: contains_await ( left) || comparators. iter ( ) . any ( Self :: contains_await) ,
3100+ Expr :: Constant ( ExprConstant { .. } ) => false ,
3101+ Expr :: List ( ExprList { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3102+ Expr :: Tuple ( ExprTuple { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3103+ Expr :: Set ( ExprSet { elts, .. } ) => elts. iter ( ) . any ( Self :: contains_await) ,
3104+ Expr :: Dict ( ExprDict { keys, values, .. } ) => {
3105+ keys. iter ( )
3106+ . any ( |key| key. as_ref ( ) . map_or ( false , Self :: contains_await) )
3107+ || values. iter ( ) . any ( Self :: contains_await)
3108+ }
3109+ Expr :: Slice ( ExprSlice {
3110+ lower, upper, step, ..
3111+ } ) => {
3112+ lower. as_ref ( ) . map_or ( false , |l| Self :: contains_await ( l) )
3113+ || upper. as_ref ( ) . map_or ( false , |u| Self :: contains_await ( u) )
3114+ || step. as_ref ( ) . map_or ( false , |s| Self :: contains_await ( s) )
3115+ }
3116+ Expr :: Yield ( ExprYield { value, .. } ) => {
3117+ value. as_ref ( ) . map_or ( false , |v| Self :: contains_await ( v) )
3118+ }
3119+ Expr :: Await ( ExprAwait { .. } ) => true ,
3120+ Expr :: YieldFrom ( ExprYieldFrom { value, .. } ) => Self :: contains_await ( value) ,
3121+ Expr :: JoinedStr ( ExprJoinedStr { values, .. } ) => {
3122+ values. iter ( ) . any ( Self :: contains_await)
3123+ }
3124+ Expr :: FormattedValue ( ExprFormattedValue {
3125+ value,
3126+ conversion : _,
3127+ format_spec,
3128+ ..
3129+ } ) => {
3130+ Self :: contains_await ( value)
3131+ || format_spec
3132+ . as_ref ( )
3133+ . map_or ( false , |fs| Self :: contains_await ( fs) )
3134+ }
3135+ Expr :: Name ( located_ast:: ExprName { .. } ) => false ,
3136+ Expr :: Lambda ( located_ast:: ExprLambda { body, .. } ) => Self :: contains_await ( body) ,
3137+ Expr :: ListComp ( located_ast:: ExprListComp {
3138+ elt, generators, ..
3139+ } ) => {
3140+ Self :: contains_await ( elt)
3141+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3142+ }
3143+ Expr :: SetComp ( located_ast:: ExprSetComp {
3144+ elt, generators, ..
3145+ } ) => {
3146+ Self :: contains_await ( elt)
3147+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3148+ }
3149+ Expr :: DictComp ( located_ast:: ExprDictComp {
3150+ key,
3151+ value,
3152+ generators,
3153+ ..
3154+ } ) => {
3155+ Self :: contains_await ( key)
3156+ || Self :: contains_await ( value)
3157+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3158+ }
3159+ Expr :: GeneratorExp ( located_ast:: ExprGeneratorExp {
3160+ elt, generators, ..
3161+ } ) => {
3162+ Self :: contains_await ( elt)
3163+ || generators. iter ( ) . any ( |gen| Self :: contains_await ( & gen. iter ) )
3164+ }
3165+ Expr :: Starred ( expr) => Self :: contains_await ( & expr. value ) ,
3166+ Expr :: IfExp ( located_ast:: ExprIfExp {
3167+ test, body, orelse, ..
3168+ } ) => {
3169+ Self :: contains_await ( test)
3170+ || Self :: contains_await ( body)
3171+ || Self :: contains_await ( orelse)
3172+ }
3173+
3174+ Expr :: NamedExpr ( located_ast:: ExprNamedExpr {
3175+ target,
3176+ value,
3177+ range : _,
3178+ } ) => Self :: contains_await ( target) || Self :: contains_await ( value) ,
3179+ }
3180+ }
30193181}
30203182
30213183trait EmitArg < Arg : OpArgType > {
0 commit comments