Skip to content

Commit 2f29798

Browse files
committed
Fixed __contains__ comparison on lists, tuples and iter objects
1 parent a450c0e commit 2f29798

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

vm/src/obj/objiter.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::super::pyobject::{
77
TypeProtocol,
88
};
99
use super::super::vm::VirtualMachine;
10+
use super::objbool;
1011
use super::objstr;
1112
use super::objtype; // Required for arg_check! to use isinstance
1213

@@ -61,13 +62,16 @@ fn iter_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
6162
);
6263
loop {
6364
match vm.call_method(&iter, "__next__", vec![]) {
64-
Ok(element) => {
65-
if &element == needle {
66-
return Ok(vm.new_bool(true));
67-
} else {
68-
continue;
65+
Ok(element) => match vm.call_method(needle, "__eq__", vec![element.clone()]) {
66+
Ok(value) => {
67+
if objbool::get_value(&value) {
68+
return Ok(vm.new_bool(true));
69+
} else {
70+
continue;
71+
}
6972
}
70-
}
73+
Err(_) => return Err(vm.new_type_error("".to_string())),
74+
},
7175
Err(_) => return Ok(vm.new_bool(false)),
7276
}
7377
}

vm/src/obj/objlist.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::super::pyobject::{
22
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
33
};
44
use super::super::vm::VirtualMachine;
5+
use super::objbool;
56
use super::objiter;
67
use super::objsequence::{seq_equal, PySliceableSequence};
78
use super::objstr;
@@ -162,16 +163,21 @@ fn reverse(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
162163
}
163164
}
164165

165-
pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
166-
trace!("list.len called with: {:?}", args);
166+
fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
167+
trace!("list.contains called with: {:?}", args);
167168
arg_check!(
168169
vm,
169170
args,
170-
required = [(list, Some(vm.ctx.list_type())), (x, None)]
171+
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
171172
);
172173
for element in get_elements(list).iter() {
173-
if x == element {
174-
return Ok(vm.new_bool(true));
174+
match vm.call_method(needle, "__eq__", vec![element.clone()]) {
175+
Ok(value) => {
176+
if objbool::get_value(&value) {
177+
return Ok(vm.new_bool(true));
178+
}
179+
}
180+
Err(_) => return Err(vm.new_type_error("".to_string())),
175181
}
176182
}
177183

@@ -181,7 +187,7 @@ pub fn contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
181187
pub fn init(context: &PyContext) {
182188
let ref list_type = context.list_type;
183189
list_type.set_attr("__add__", context.new_rustfunc(list_add));
184-
list_type.set_attr("__contains__", context.new_rustfunc(contains));
190+
list_type.set_attr("__contains__", context.new_rustfunc(list_contains));
185191
list_type.set_attr("__eq__", context.new_rustfunc(list_eq));
186192
list_type.set_attr("__len__", context.new_rustfunc(list_len));
187193
list_type.set_attr("__new__", context.new_rustfunc(list_new));

vm/src/obj/objtuple.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use super::super::pyobject::{
22
AttributeProtocol, PyContext, PyFuncArgs, PyObjectKind, PyObjectRef, PyResult, TypeProtocol,
33
};
44
use super::super::vm::VirtualMachine;
5+
use super::objbool;
56
use super::objsequence::seq_equal;
67
use super::objstr;
78
use super::objtype;
@@ -54,11 +55,16 @@ pub fn tuple_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
5455
arg_check!(
5556
vm,
5657
args,
57-
required = [(tuple, Some(vm.ctx.tuple_type())), (x, None)]
58+
required = [(tuple, Some(vm.ctx.tuple_type())), (needle, None)]
5859
);
5960
for element in get_elements(tuple).iter() {
60-
if x == element {
61-
return Ok(vm.new_bool(true));
61+
match vm.call_method(needle, "__eq__", vec![element.clone()]) {
62+
Ok(value) => {
63+
if objbool::get_value(&value) {
64+
return Ok(vm.new_bool(true));
65+
}
66+
}
67+
Err(_) => return Err(vm.new_type_error("".to_string())),
6268
}
6369
}
6470

0 commit comments

Comments
 (0)