@@ -5,6 +5,7 @@ use std::ops::{AddAssign, SubAssign};
55use std:: rc:: Rc ;
66
77use num_bigint:: BigInt ;
8+ use num_traits:: sign:: Signed ;
89use num_traits:: ToPrimitive ;
910
1011use crate :: function:: { Args , OptionalArg , PyFuncArgs } ;
@@ -848,6 +849,110 @@ impl PyItertoolsProduct {
848849 }
849850}
850851
852+ #[ pyclass]
853+ #[ derive( Debug ) ]
854+ struct PyItertoolsCombinations {
855+ pool : Vec < PyObjectRef > ,
856+ indices : RefCell < Vec < usize > > ,
857+ r : Cell < usize > ,
858+ exhausted : Cell < bool > ,
859+ }
860+
861+ impl PyValue for PyItertoolsCombinations {
862+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
863+ vm. class ( "itertools" , "combinations" )
864+ }
865+ }
866+
867+ #[ pyimpl]
868+ impl PyItertoolsCombinations {
869+ #[ pyslot( new) ]
870+ fn tp_new (
871+ cls : PyClassRef ,
872+ iterable : PyObjectRef ,
873+ r : PyIntRef ,
874+ vm : & VirtualMachine ,
875+ ) -> PyResult < PyRef < Self > > {
876+ let iter = get_iter ( vm, & iterable) ?;
877+ let pool = get_all ( vm, & iter) ?;
878+
879+ let r = r. as_bigint ( ) ;
880+ if r. is_negative ( ) {
881+ return Err ( vm. new_value_error ( "r must be non-negative" . to_string ( ) ) ) ;
882+ }
883+ let r = r. to_usize ( ) . unwrap ( ) ;
884+
885+ let n = pool. len ( ) ;
886+
887+ PyItertoolsCombinations {
888+ pool,
889+ indices : RefCell :: new ( ( 0 ..r) . collect ( ) ) ,
890+ r : Cell :: new ( r) ,
891+ exhausted : Cell :: new ( r > n) ,
892+ }
893+ . into_ref_with_type ( vm, cls)
894+ }
895+
896+ #[ pymethod( name = "__iter__" ) ]
897+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
898+ zelf
899+ }
900+
901+ #[ pymethod( name = "__next__" ) ]
902+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
903+ // stop signal
904+ if self . exhausted . get ( ) {
905+ return Err ( new_stop_iteration ( vm) ) ;
906+ }
907+
908+ let n = self . pool . len ( ) ;
909+ let r = self . r . get ( ) ;
910+
911+ let res = PyTuple :: from (
912+ self . pool
913+ . iter ( )
914+ . enumerate ( )
915+ . filter ( |( idx, _) | self . indices . borrow ( ) . contains ( & idx) )
916+ . map ( |( _, num) | num. clone ( ) )
917+ . collect :: < Vec < PyObjectRef > > ( ) ,
918+ ) ;
919+
920+ let mut indices = self . indices . borrow_mut ( ) ;
921+ let mut sentinel = false ;
922+
923+ // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
924+ let mut idx = r - 1 ;
925+ loop {
926+ if indices[ idx] != idx + n - r {
927+ sentinel = true ;
928+ break ;
929+ }
930+
931+ if idx != 0 {
932+ idx -= 1 ;
933+ } else {
934+ break ;
935+ }
936+ }
937+ // If no suitable index is found, then the indices are all at
938+ // their maximum value and we're done.
939+ if !sentinel {
940+ self . exhausted . set ( true ) ;
941+ }
942+
943+ // Increment the current index which we know is not at its
944+ // maximum. Then move back to the right setting each index
945+ // to its lowest possible value (one higher than the index
946+ // to its left -- this maintains the sort order invariant).
947+ indices[ idx] += 1 ;
948+ for j in idx + 1 ..r {
949+ indices[ j] = indices[ j - 1 ] + 1 ;
950+ }
951+
952+ Ok ( res. into_ref ( vm) . into_object ( ) )
953+ }
954+ }
955+
851956pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
852957 let ctx = & vm. ctx ;
853958
@@ -858,6 +963,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
858963
859964 let compress = PyItertoolsCompress :: make_class ( ctx) ;
860965
966+ let combinations = ctx. new_class ( "combinations" , ctx. object ( ) ) ;
967+ PyItertoolsCombinations :: extend_class ( ctx, & combinations) ;
968+
861969 let count = ctx. new_class ( "count" , ctx. object ( ) ) ;
862970 PyItertoolsCount :: extend_class ( ctx, & count) ;
863971
@@ -887,6 +995,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
887995 "accumulate" => accumulate,
888996 "chain" => chain,
889997 "compress" => compress,
998+ "combinations" => combinations,
890999 "count" => count,
8911000 "dropwhile" => dropwhile,
8921001 "islice" => islice,
0 commit comments