Skip to content

Commit 4313abc

Browse files
committed
Fix type.__new__ for enum
1 parent dd6ca78 commit 4313abc

File tree

2 files changed

+64
-22
lines changed

2 files changed

+64
-22
lines changed

vm/src/obj/objmappingproxy.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use super::objstr::PyStringRef;
22
use super::objtype::{self, PyClassRef};
3-
use crate::pyobject::{PyClassImpl, PyContext, PyRef, PyResult, PyValue};
3+
use crate::function::OptionalArg;
4+
use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue};
45
use crate::vm::VirtualMachine;
56

67
#[pyclass]
@@ -23,6 +24,14 @@ impl PyMappingProxy {
2324
PyMappingProxy { class }
2425
}
2526

27+
#[pymethod]
28+
fn get(&self, key: PyStringRef, default: OptionalArg, vm: &VirtualMachine) -> PyObjectRef {
29+
let default = default.into_option();
30+
objtype::class_get_attr(&self.class, key.as_str())
31+
.or(default)
32+
.unwrap_or_else(|| vm.get_none())
33+
}
34+
2635
#[pymethod(name = "__getitem__")]
2736
pub fn getitem(&self, key: PyStringRef, vm: &VirtualMachine) -> PyResult {
2837
if let Some(value) = objtype::class_get_attr(&self.class, key.as_str()) {

vm/src/obj/objtype.rs

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::cell::RefCell;
22
use std::collections::HashMap;
33
use std::fmt;
44

5-
use crate::function::{Args, KwArgs, PyFuncArgs};
5+
use crate::function::PyFuncArgs;
66
use crate::pyobject::{
77
IdProtocol, PyAttributes, PyContext, PyIterable, PyObject, PyObjectRef, PyRef, PyResult,
88
PyValue, TypeProtocol,
@@ -203,6 +203,11 @@ impl PyClassRef {
203203
}
204204
}
205205

206+
fn type_mro(cls: PyClassRef, vm: &VirtualMachine) -> PyObjectRef {
207+
vm.ctx
208+
.new_list(cls.mro.iter().map(|x| x.clone().into_object()).collect())
209+
}
210+
206211
/*
207212
* The magical type type
208213
*/
@@ -213,13 +218,14 @@ pub fn init(ctx: &PyContext) {
213218
type(name, bases, dict) -> a new type";
214219

215220
extend_class!(&ctx, &ctx.types.type_type, {
221+
"mro" => ctx.new_rustfunc(type_mro),
216222
"__call__" => ctx.new_rustfunc(type_call),
217223
"__dict__" =>
218224
PropertyBuilder::new(ctx)
219225
.add_getter(type_dict)
220226
.add_setter(type_dict_setter)
221227
.create(),
222-
"__new__" => ctx.new_rustfunc(type_new),
228+
"__new__" => ctx.new_classmethod(type_new),
223229
"__mro__" =>
224230
PropertyBuilder::new(ctx)
225231
.add_getter(PyClassRef::mro)
@@ -260,13 +266,30 @@ pub fn issubclass(subclass: &PyClassRef, cls: &PyClassRef) -> bool {
260266
subclass.is(cls) || mro.iter().any(|c| c.is(cls.as_object()))
261267
}
262268

263-
pub fn type_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult {
269+
pub fn type_new(
270+
zelf: PyClassRef,
271+
cls: PyClassRef,
272+
args: PyFuncArgs,
273+
vm: &VirtualMachine,
274+
) -> PyResult {
264275
vm_trace!("type.__new__ {:?}", args);
265-
if args.args.len() == 2 {
266-
Ok(args.args[1].class().into_object())
267-
} else if args.args.len() == 4 {
268-
let (typ, name, bases, dict) = args.bind(vm)?;
269-
type_new_class(vm, typ, name, bases, dict).map(PyRef::into_object)
276+
277+
if !issubclass(&cls, &zelf) {
278+
return Err(vm.new_type_error(format!(
279+
"{zelf}.__new__({cls}): {cls} is not a subtype of {zelf}",
280+
zelf = zelf.name,
281+
cls = cls.name,
282+
)));
283+
}
284+
285+
// let new = class_get_super_attr(&zelf, "__new__").expect("Couldn't find __new__");
286+
287+
// vm.invoke(&new, args.insert(cls.into_object()));
288+
if args.args.len() == 1 {
289+
Ok(args.args[0].class().into_object())
290+
} else if args.args.len() == 3 {
291+
let (name, bases, dict) = args.bind(vm)?;
292+
type_new_class(vm, cls, name, bases, dict).map(PyRef::into_object)
270293
} else {
271294
Err(vm.new_type_error("type() takes 1 or 3 arguments".to_string()))
272295
}
@@ -284,15 +307,21 @@ pub fn type_new_class(
284307
new(typ.clone(), name.as_str(), bases, dict.to_attributes())
285308
}
286309

287-
pub fn type_call(class: PyClassRef, args: Args, kwargs: KwArgs, vm: &VirtualMachine) -> PyResult {
310+
pub fn type_call(class: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult {
288311
vm_trace!("type_call: {:?}", class);
289312
let new = class_get_attr(&class, "__new__").expect("All types should have a __new__.");
290-
let new_wrapped = vm.call_get_descriptor(new, class.into_object())?;
291-
let obj = vm.invoke(&new_wrapped, (&args, &kwargs))?;
313+
let new_wrapped = vm.call_get_descriptor(new, class.clone().into_object())?;
314+
// TODO: don't do this, init __new__ based on tp_new
315+
let new_args = if class.is(&vm.ctx.types.type_type) {
316+
args.insert(class.into_object())
317+
} else {
318+
args.clone()
319+
};
320+
let obj = vm.invoke(&new_wrapped, new_args)?;
292321

293322
if let Some(init_method_or_err) = vm.get_method(obj.clone(), "__init__") {
294323
let init_method = init_method_or_err?;
295-
let res = vm.invoke(&init_method, (&args, &kwargs))?;
324+
let res = vm.invoke(&init_method, args)?;
296325
if !res.is(&vm.get_none()) {
297326
return Err(vm.new_type_error("__init__ must return None".to_string()));
298327
}
@@ -314,15 +343,19 @@ fn type_dict_setter(_instance: PyClassRef, _value: PyObjectRef, vm: &VirtualMach
314343
pub fn class_get_attr(class: &PyClassRef, attr_name: &str) -> Option<PyObjectRef> {
315344
flame_guard!(format!("class_get_attr({:?})", attr_name));
316345

317-
if let Some(item) = class.attributes.borrow().get(attr_name).cloned() {
318-
return Some(item);
319-
}
320-
for class in &class.mro {
321-
if let Some(item) = class.attributes.borrow().get(attr_name).cloned() {
322-
return Some(item);
323-
}
324-
}
325-
None
346+
class
347+
.attributes
348+
.borrow()
349+
.get(attr_name)
350+
.cloned()
351+
.or_else(|| class_get_super_attr(class, attr_name))
352+
}
353+
354+
pub fn class_get_super_attr(class: &PyClassRef, attr_name: &str) -> Option<PyObjectRef> {
355+
class
356+
.mro
357+
.iter()
358+
.find_map(|class| class.attributes.borrow().get(attr_name).cloned())
326359
}
327360

328361
// This is the internal has_attr implementation for fast lookup on a class.

0 commit comments

Comments
 (0)