Skip to content

Commit 499bf14

Browse files
committed
Only overwrite the locals in __build_class__ if it's a class
1 parent 8ac743d commit 499bf14

File tree

8 files changed

+54
-47
lines changed

8 files changed

+54
-47
lines changed

bytecode/src/bytecode.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ bitflags! {
5353
const HAS_DEFAULTS = 0x01;
5454
const HAS_KW_ONLY_DEFAULTS = 0x02;
5555
const HAS_ANNOTATIONS = 0x04;
56+
const NO_NEW_LOCALS = 0x08;
5657
}
5758
}
5859

compiler/src/compile.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,12 +955,21 @@ impl<O: OutputStream> Compiler<O> {
955955

956956
self.emit(Instruction::LoadName {
957957
name: "__name__".to_string(),
958-
scope: bytecode::NameScope::Free,
958+
scope: bytecode::NameScope::Global,
959959
});
960960
self.emit(Instruction::StoreName {
961961
name: "__module__".to_string(),
962962
scope: bytecode::NameScope::Free,
963963
});
964+
self.emit(Instruction::LoadConst {
965+
value: bytecode::Constant::String {
966+
value: qualified_name.clone(),
967+
},
968+
});
969+
self.emit(Instruction::StoreName {
970+
name: "__qualname__".to_string(),
971+
scope: bytecode::NameScope::Free,
972+
});
964973
self.compile_statements(new_body)?;
965974
self.emit(Instruction::LoadConst {
966975
value: bytecode::Constant::None,
@@ -983,7 +992,7 @@ impl<O: OutputStream> Compiler<O> {
983992

984993
// Turn code object into function object:
985994
self.emit(Instruction::MakeFunction {
986-
flags: bytecode::FunctionOpArg::empty(),
995+
flags: bytecode::FunctionOpArg::NO_NEW_LOCALS,
987996
});
988997

989998
self.emit(Instruction::LoadConst {

compiler/src/symboltable.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,8 @@ impl SymbolTableBuilder {
387387
decorator_list,
388388
} => {
389389
self.enter_scope(name, SymbolTableType::Class, statement.location.row());
390+
self.register_name("__module__", SymbolUsage::Assigned)?;
391+
self.register_name("__qualname__", SymbolUsage::Assigned)?;
390392
self.scan_statements(body)?;
391393
self.leave_scope();
392394
self.scan_expressions(bases, &ExpressionContext::Load)?;

vm/src/builtins.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::obj::objbool::{self, IntoPyBool};
1414
use crate::obj::objbytes::PyBytesRef;
1515
use crate::obj::objcode::PyCodeRef;
1616
use crate::obj::objdict::PyDictRef;
17+
use crate::obj::objfunction::PyFunctionRef;
1718
use crate::obj::objint::{self, PyIntRef};
1819
use crate::obj::objiter;
1920
use crate::obj::objstr::{PyString, PyStringRef};
@@ -832,6 +833,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
832833
"exit" => ctx.new_rustfunc(builtin_exit),
833834
"quit" => ctx.new_rustfunc(builtin_exit),
834835
"__import__" => ctx.new_rustfunc(builtin_import),
836+
"__build_class__" => ctx.new_rustfunc(builtin_build_class_),
835837

836838
// Constants
837839
"NotImplemented" => ctx.not_implemented(),
@@ -888,7 +890,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
888890
}
889891

890892
pub fn builtin_build_class_(
891-
function: PyObjectRef,
893+
function: PyFunctionRef,
892894
qualified_name: PyStringRef,
893895
bases: Args<PyClassRef>,
894896
mut kwargs: KwArgs,
@@ -925,10 +927,12 @@ pub fn builtin_build_class_(
925927

926928
let cells = vm.ctx.new_dict();
927929

928-
vm.invoke_with_locals(&function, cells.clone(), namespace.clone())?;
930+
let scope = function
931+
.scope
932+
.new_child_scope_with_locals(cells.clone())
933+
.new_child_scope_with_locals(namespace.clone());
929934

930-
namespace.set_item("__name__", name_obj.clone(), vm)?;
931-
namespace.set_item("__qualname__", qualified_name.into_object(), vm)?;
935+
vm.invoke_python_function_with_scope(&function, vec![].into(), &scope)?;
932936

933937
let class = vm.call_method(
934938
metaclass.as_object(),

vm/src/frame.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::cell::RefCell;
22
use std::fmt;
33

4-
use crate::builtins;
54
use crate::bytecode;
65
use crate::function::PyFuncArgs;
76
use crate::obj::objbool;
@@ -460,7 +459,7 @@ impl Frame {
460459
Ok(None)
461460
}
462461
bytecode::Instruction::LoadBuildClass => {
463-
self.push_value(vm.ctx.new_rustfunc(builtins::builtin_build_class_));
462+
self.push_value(vm.get_attribute(vm.builtins.clone(), "__build_class__")?);
464463
Ok(None)
465464
}
466465
bytecode::Instruction::UnpackSequence { size } => {
@@ -1054,9 +1053,13 @@ impl Frame {
10541053
// pop argc arguments
10551054
// argument: name, args, globals
10561055
let scope = self.scope.clone();
1057-
let func_obj = vm
1058-
.ctx
1059-
.new_function(code_obj, scope, defaults, kw_only_defaults);
1056+
let func_obj = vm.ctx.new_function(
1057+
code_obj,
1058+
scope,
1059+
defaults,
1060+
kw_only_defaults,
1061+
!flags.contains(bytecode::FunctionOpArg::NO_NEW_LOCALS),
1062+
);
10601063

10611064
let name = qualified_name.as_str().split('.').next_back().unwrap();
10621065
vm.set_attr(&func_obj, "__name__", vm.new_str(name.to_string()))?;

vm/src/obj/objfunction.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub struct PyFunction {
1616
pub scope: Scope,
1717
pub defaults: Option<PyTupleRef>,
1818
pub kw_only_defaults: Option<PyDictRef>,
19+
pub new_locals: bool,
1920
}
2021

2122
impl PyFunction {
@@ -24,12 +25,14 @@ impl PyFunction {
2425
scope: Scope,
2526
defaults: Option<PyTupleRef>,
2627
kw_only_defaults: Option<PyDictRef>,
28+
new_locals: bool,
2729
) -> Self {
2830
PyFunction {
2931
code,
3032
scope,
3133
defaults,
3234
kw_only_defaults,
35+
new_locals,
3336
}
3437
}
3538
}

vm/src/pyobject.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,10 @@ impl PyContext {
478478
scope: Scope,
479479
defaults: Option<PyTupleRef>,
480480
kw_only_defaults: Option<PyDictRef>,
481+
new_locals: bool,
481482
) -> PyObjectRef {
482483
PyObject::new(
483-
PyFunction::new(code_obj, scope, defaults, kw_only_defaults),
484+
PyFunction::new(code_obj, scope, defaults, kw_only_defaults, new_locals),
484485
self.function_type(),
485486
Some(self.new_dict()),
486487
)

vm/src/vm.rs

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -570,15 +570,9 @@ impl VirtualMachine {
570570
fn _invoke(&self, func_ref: &PyObjectRef, args: PyFuncArgs) -> PyResult {
571571
vm_trace!("Invoke: {:?} {:?}", func_ref, args);
572572

573-
if let Some(PyFunction {
574-
ref code,
575-
ref scope,
576-
ref defaults,
577-
ref kw_only_defaults,
578-
}) = func_ref.payload()
579-
{
573+
if let Some(py_func) = func_ref.payload() {
580574
self.trace_event(TraceEvent::Call)?;
581-
let res = self.invoke_python_function(code, scope, defaults, kw_only_defaults, args);
575+
let res = self.invoke_python_function(py_func, args);
582576
self.trace_event(TraceEvent::Return)?;
583577
res
584578
} else if let Some(PyMethod {
@@ -634,21 +628,30 @@ impl VirtualMachine {
634628
Ok(())
635629
}
636630

637-
fn invoke_python_function(
631+
pub fn invoke_python_function(&self, func: &PyFunction, func_args: PyFuncArgs) -> PyResult {
632+
self.invoke_python_function_with_scope(func, func_args, &func.scope)
633+
}
634+
635+
pub fn invoke_python_function_with_scope(
638636
&self,
639-
code: &PyCodeRef,
640-
scope: &Scope,
641-
defaults: &Option<PyTupleRef>,
642-
kw_only_defaults: &Option<PyDictRef>,
637+
func: &PyFunction,
643638
func_args: PyFuncArgs,
639+
scope: &Scope,
644640
) -> PyResult {
645-
let scope = scope.new_child_scope(&self.ctx);
641+
let code = &func.code;
642+
643+
let scope = if func.new_locals {
644+
scope.new_child_scope(&self.ctx)
645+
} else {
646+
scope.clone()
647+
};
648+
646649
self.fill_locals_from_args(
647650
&code.code,
648651
&scope.get_locals(),
649652
func_args,
650-
defaults,
651-
kw_only_defaults,
653+
&func.defaults,
654+
&func.kw_only_defaults,
652655
)?;
653656

654657
// Construct frame:
@@ -662,25 +665,6 @@ impl VirtualMachine {
662665
}
663666
}
664667

665-
pub fn invoke_with_locals(
666-
&self,
667-
function: &PyObjectRef,
668-
cells: PyDictRef,
669-
locals: PyDictRef,
670-
) -> PyResult {
671-
if let Some(PyFunction { code, scope, .. }) = &function.payload() {
672-
let scope = scope
673-
.new_child_scope_with_locals(cells)
674-
.new_child_scope_with_locals(locals);
675-
let frame = Frame::new(code.clone(), scope).into_ref(self);
676-
return self.run_frame_full(frame);
677-
}
678-
panic!(
679-
"invoke_with_locals: expected python function, got: {:?}",
680-
*function
681-
);
682-
}
683-
684668
fn fill_locals_from_args(
685669
&self,
686670
code_object: &bytecode::CodeObject,

0 commit comments

Comments
 (0)