Skip to content

Commit 5d82e2c

Browse files
committed
Fix equality check in list_count, list_index and list_contains
1 parent bf19d65 commit 5d82e2c

File tree

3 files changed

+34
-10
lines changed

3 files changed

+34
-10
lines changed

tests/snippets/list.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,21 @@
9292
x = [1]
9393
x.append(x)
9494
assert x in x
95+
assert x.index(x) == 1
96+
assert x.count(x) == 1
97+
98+
class Foo(object):
99+
def __eq__(self, x):
100+
return False
101+
102+
foo = Foo()
103+
foo1 = Foo()
104+
x = [1, foo, 2, foo, []]
105+
assert foo in x
106+
assert 2 in x
107+
assert x.index(foo) == 1
108+
assert x.count(foo) == 2
109+
assert x.index(2) == 2
110+
assert [] in x
111+
assert x.index([]) == 4
112+
assert foo1 not in x

vm/src/obj/objlist.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::super::pyobject::{
2-
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
2+
IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
33
};
44
use super::super::vm::{ReprGuard, VirtualMachine};
55
use super::objbool;
@@ -234,9 +234,13 @@ fn list_count(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
234234
let elements = get_elements(zelf);
235235
let mut count: usize = 0;
236236
for element in elements.iter() {
237-
let is_eq = vm._eq(element.clone(), value.clone())?;
238-
if objbool::boolval(vm, is_eq)? {
237+
if value.is(&element) {
239238
count += 1;
239+
} else {
240+
let is_eq = vm._eq(element.clone(), value.clone())?;
241+
if objbool::boolval(vm, is_eq)? {
242+
count += 1;
243+
}
240244
}
241245
}
242246
Ok(vm.context().new_int(count))
@@ -262,6 +266,9 @@ fn list_index(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
262266
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
263267
);
264268
for (index, element) in get_elements(list).iter().enumerate() {
269+
if needle.is(&element) {
270+
return Ok(vm.context().new_int(index));
271+
}
265272
let py_equal = vm._eq(needle.clone(), element.clone())?;
266273
if objbool::get_value(&py_equal) {
267274
return Ok(vm.context().new_int(index));
@@ -335,6 +342,9 @@ fn list_contains(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
335342
required = [(list, Some(vm.ctx.list_type())), (needle, None)]
336343
);
337344
for element in get_elements(list).iter() {
345+
if needle.is(&element) {
346+
return Ok(vm.new_bool(true));
347+
}
338348
match vm._eq(needle.clone(), element.clone()) {
339349
Ok(value) => {
340350
if objbool::get_value(&value) {

vm/src/vm.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,9 @@ impl VirtualMachine {
608608
}
609609

610610
pub fn _eq(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {
611-
if a.is(&b) {
612-
Ok(self.new_bool(true))
613-
} else {
614-
self.call_or_unsupported(a, b, "__eq__", "__eq__", |vm, _, _| {
615-
Ok(vm.new_bool(false))
616-
})
617-
}
611+
self.call_or_unsupported(a, b, "__eq__", "__eq__", |vm, a, b| {
612+
Ok(vm.new_bool(a.is(&b)))
613+
})
618614
}
619615

620616
pub fn _ne(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {

0 commit comments

Comments
 (0)