Skip to content

Commit bd5772d

Browse files
committed
Implement dict.__eq__
1 parent 939f109 commit bd5772d

2 files changed

Lines changed: 45 additions & 12 deletions

File tree

tests/snippets/dict.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from testutils import assertRaises
22

3-
def dict_eq(d1, d2):
4-
return (all(k in d2 and d1[k] == d2[k] for k in d1)
5-
and all(k in d1 and d1[k] == d2[k] for k in d2))
3+
assert dict(a=2, b=3) == {'a': 2, 'b': 3}
4+
assert dict({'a': 2, 'b': 3}, b=4) == {'a': 2, 'b': 4}
5+
assert dict([('a', 2), ('b', 3)]) == {'a': 2, 'b': 3}
66

7-
8-
assert dict_eq(dict(a=2, b=3), {'a': 2, 'b': 3})
9-
assert dict_eq(dict({'a': 2, 'b': 3}, b=4), {'a': 2, 'b': 4})
10-
assert dict_eq(dict([('a', 2), ('b', 3)]), {'a': 2, 'b': 3})
7+
assert {} == {}
8+
assert not {'a': 2} == {}
9+
assert not {} == {'a': 2}
10+
assert not {'b': 2} == {'a': 2}
11+
assert not {'a': 4} == {'a': 2}
12+
assert {'a': 2} == {'a': 2}
1113

1214
a = {'g': 5}
1315
b = {'a': a, 'd': 9}
1416
c = dict(b)
1517
c['d'] = 3
1618
c['a']['g'] = 2
17-
assert dict_eq(a, {'g': 2})
18-
assert dict_eq(b, {'a': a, 'd': 9})
19+
assert a == {'g': 2}
20+
assert b == {'a': a, 'd': 9}
1921

2022
a.clear()
2123
assert len(a) == 0
@@ -142,10 +144,10 @@ def __eq__(self, other):
142144

143145
y = x.copy()
144146
x['c'] = 12
145-
assert dict_eq(y, {'a': 2, 'b': 10})
147+
assert y == {'a': 2, 'b': 10}
146148

147149
y.update({'c': 19, "d": -1, 'b': 12})
148-
assert dict_eq(y, {'a': 2, 'b': 12, 'c': 19, 'd': -1})
150+
assert y == {'a': 2, 'b': 12, 'c': 19, 'd': -1}
149151

150152
y.update(y)
151-
assert dict_eq(y, {'a': 2, 'b': 12, 'c': 19, 'd': -1}) # hasn't changed
153+
assert y == {'a': 2, 'b': 12, 'c': 19, 'd': -1} # hasn't changed

vm/src/obj/objdict.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::pyobject::{
77
};
88
use crate::vm::{ReprGuard, VirtualMachine};
99

10+
use super::objbool;
1011
use super::objiter;
1112
use super::objstr;
1213
use crate::dictdatatype;
@@ -96,6 +97,35 @@ impl PyDictRef {
9697
!self.entries.borrow().is_empty()
9798
}
9899

100+
fn inner_eq(self, other: &PyDict, vm: &VirtualMachine) -> PyResult<bool> {
101+
if other.entries.borrow().len() != self.entries.borrow().len() {
102+
return Ok(false);
103+
}
104+
for (k, v1) in self {
105+
match other.entries.borrow().get(vm, &k)? {
106+
Some(v2) => {
107+
let value = objbool::boolval(vm, vm._eq(v1, v2)?)?;
108+
if !value {
109+
return Ok(false);
110+
}
111+
}
112+
None => {
113+
return Ok(false);
114+
}
115+
}
116+
}
117+
return Ok(true);
118+
}
119+
120+
fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
121+
if let Some(other) = other.payload::<PyDict>() {
122+
let eq = self.inner_eq(other, vm)?;
123+
Ok(vm.ctx.new_bool(eq))
124+
} else {
125+
Ok(vm.ctx.not_implemented())
126+
}
127+
}
128+
99129
fn len(self, _vm: &VirtualMachine) -> usize {
100130
self.entries.borrow().len()
101131
}
@@ -387,6 +417,7 @@ pub fn init(context: &PyContext) {
387417
"__len__" => context.new_rustfunc(PyDictRef::len),
388418
"__contains__" => context.new_rustfunc(PyDictRef::contains),
389419
"__delitem__" => context.new_rustfunc(PyDictRef::inner_delitem),
420+
"__eq__" => context.new_rustfunc(PyDictRef::eq),
390421
"__getitem__" => context.new_rustfunc(PyDictRef::inner_getitem),
391422
"__iter__" => context.new_rustfunc(PyDictRef::iter),
392423
"__new__" => context.new_rustfunc(PyDictRef::new),

0 commit comments

Comments
 (0)