Skip to content

Commit 16b2b42

Browse files
committed
Add itertools.combinations()
re: RustPython#1361
1 parent 53b3911 commit 16b2b42

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

tests/snippets/stdlib_itertools.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def assert_matches_seq(it, seq):
301301
assert list(t[0]) == []
302302

303303
# itertools.product
304+
304305
it = itertools.product([1, 2], [3, 4])
305306
assert (1, 3) == next(it)
306307
assert (1, 4) == next(it)
@@ -321,3 +322,25 @@ def assert_matches_seq(it, seq):
321322
itertools.product(None)
322323
with assert_raises(TypeError):
323324
itertools.product([1, 2], repeat=None)
325+
326+
# itertools.combinations
327+
328+
it = itertools.combinations([1, 2, 3, 4], 2)
329+
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
330+
331+
it = itertools.combinations([1, 2, 3], 1)
332+
assert list(it) == [(1,), (2,), (3,)]
333+
334+
it = itertools.combinations([1, 2, 3], 2)
335+
assert next(it) == (1, 2)
336+
assert next(it) == (1, 3)
337+
assert next(it) == (2, 3)
338+
with assert_raises(StopIteration):
339+
next(it)
340+
341+
it = itertools.combinations([1, 2, 3], 4)
342+
with assert_raises(StopIteration):
343+
next(it)
344+
345+
with assert_raises(ValueError):
346+
itertools.combinations([1, 2, 3, 4], -2)

vm/src/stdlib/itertools.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
55
use std::rc::Rc;
66

77
use num_bigint::BigInt;
8+
use num_traits::sign::Signed;
89
use num_traits::ToPrimitive;
910

1011
use crate::function::{Args, OptionalArg, PyFuncArgs};
@@ -848,6 +849,110 @@ impl PyItertoolsProduct {
848849
}
849850
}
850851

852+
#[pyclass]
853+
#[derive(Debug)]
854+
struct PyItertoolsCombinations {
855+
pool: Vec<PyObjectRef>,
856+
indices: RefCell<Vec<usize>>,
857+
r: Cell<usize>,
858+
exhausted: Cell<bool>,
859+
}
860+
861+
impl PyValue for PyItertoolsCombinations {
862+
fn class(vm: &VirtualMachine) -> PyClassRef {
863+
vm.class("itertools", "combinations")
864+
}
865+
}
866+
867+
#[pyimpl]
868+
impl PyItertoolsCombinations {
869+
#[pyslot(new)]
870+
fn tp_new(
871+
cls: PyClassRef,
872+
iterable: PyObjectRef,
873+
r: PyIntRef,
874+
vm: &VirtualMachine,
875+
) -> PyResult<PyRef<Self>> {
876+
let iter = get_iter(vm, &iterable)?;
877+
let pool = get_all(vm, &iter)?;
878+
879+
let r = r.as_bigint();
880+
if r.is_negative() {
881+
return Err(vm.new_value_error("r must be non-negative".to_string()));
882+
}
883+
let r = r.to_usize().unwrap();
884+
885+
let n = pool.len();
886+
887+
PyItertoolsCombinations {
888+
pool,
889+
indices: RefCell::new((0..r).collect()),
890+
r: Cell::new(r),
891+
exhausted: Cell::new(r > n),
892+
}
893+
.into_ref_with_type(vm, cls)
894+
}
895+
896+
#[pymethod(name = "__iter__")]
897+
fn iter(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
898+
zelf
899+
}
900+
901+
#[pymethod(name = "__next__")]
902+
fn next(&self, vm: &VirtualMachine) -> PyResult {
903+
// stop signal
904+
if self.exhausted.get() {
905+
return Err(new_stop_iteration(vm));
906+
}
907+
908+
let n = self.pool.len();
909+
let r = self.r.get();
910+
911+
let res = PyTuple::from(
912+
self.pool
913+
.iter()
914+
.enumerate()
915+
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
916+
.map(|(_, num)| num.clone())
917+
.collect::<Vec<PyObjectRef>>(),
918+
);
919+
920+
let mut indices = self.indices.borrow_mut();
921+
let mut sentinel = false;
922+
923+
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
924+
let mut idx = r - 1;
925+
loop {
926+
if indices[idx] != idx + n - r {
927+
sentinel = true;
928+
break;
929+
}
930+
931+
if idx != 0 {
932+
idx -= 1;
933+
} else {
934+
break;
935+
}
936+
}
937+
// If no suitable index is found, then the indices are all at
938+
// their maximum value and we're done.
939+
if !sentinel {
940+
self.exhausted.set(true);
941+
}
942+
943+
// Increment the current index which we know is not at its
944+
// maximum. Then move back to the right setting each index
945+
// to its lowest possible value (one higher than the index
946+
// to its left -- this maintains the sort order invariant).
947+
indices[idx] += 1;
948+
for j in idx + 1..r {
949+
indices[j] = indices[j - 1] + 1;
950+
}
951+
952+
Ok(res.into_ref(vm).into_object())
953+
}
954+
}
955+
851956
pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
852957
let ctx = &vm.ctx;
853958

@@ -858,6 +963,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
858963

859964
let compress = PyItertoolsCompress::make_class(ctx);
860965

966+
let combinations = ctx.new_class("combinations", ctx.object());
967+
PyItertoolsCombinations::extend_class(ctx, &combinations);
968+
861969
let count = ctx.new_class("count", ctx.object());
862970
PyItertoolsCount::extend_class(ctx, &count);
863971

@@ -887,6 +995,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
887995
"accumulate" => accumulate,
888996
"chain" => chain,
889997
"compress" => compress,
998+
"combinations" => combinations,
890999
"count" => count,
8911000
"dropwhile" => dropwhile,
8921001
"islice" => islice,

0 commit comments

Comments
 (0)