Skip to content

Commit b86e803

Browse files
author
Oscar Shrimpton
committed
Implement .indices(len) of slice (Fixes RustPython#1431)
range.__getitem now also uses slice.indices() internally. CPython: https://github.com/python/cpython/blob/master/Objects/sliceobject.c#L373
1 parent 40bbb6b commit b86e803

File tree

2 files changed

+113
-72
lines changed

2 files changed

+113
-72
lines changed

vm/src/obj/objrange.rs

Lines changed: 11 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -339,77 +339,19 @@ impl PyRange {
339339
fn getitem(&self, subscript: RangeIndex, vm: &VirtualMachine) -> PyResult {
340340
match subscript {
341341
RangeIndex::Slice(slice) => {
342-
let range_start = self.start.as_bigint();
343-
let range_step = self.step.as_bigint();
344-
let range_length = &self.length();
345-
346-
let substep = if let Some(slice_step) = slice.step_index(vm)? {
347-
if slice_step.is_zero() {
348-
return Err(vm.new_value_error("slice step cannot be zero".to_string()));
349-
}
350-
slice_step
351-
} else {
352-
BigInt::one()
353-
};
354-
355-
let negative_step = substep.is_negative();
356-
let lower_bound = if negative_step {
357-
-BigInt::one()
358-
} else {
359-
BigInt::zero()
360-
};
361-
let upper_bound = if negative_step {
362-
&lower_bound + range_length
363-
} else {
364-
range_length.clone()
365-
};
366-
367-
let substart = if let Some(slice_start) = slice.start_index(vm)? {
368-
if slice_start.is_negative() {
369-
let tmp = slice_start + range_length;
370-
if tmp < lower_bound {
371-
lower_bound.clone()
372-
} else {
373-
tmp.clone()
374-
}
375-
} else if slice_start > upper_bound {
376-
upper_bound.clone()
377-
} else {
378-
slice_start.clone()
379-
}
380-
} else if negative_step {
381-
upper_bound.clone()
382-
} else {
383-
lower_bound.clone()
384-
};
385-
386-
let substop = if let Some(slice_stop) = slice.stop_index(vm)? {
387-
if slice_stop.is_negative() {
388-
let tmp = slice_stop + range_length;
389-
if tmp < lower_bound {
390-
lower_bound.clone()
391-
} else {
392-
tmp.clone()
393-
}
394-
} else if slice_stop > upper_bound {
395-
upper_bound.clone()
396-
} else {
397-
slice_stop.clone()
398-
}
399-
} else if negative_step {
400-
lower_bound.clone()
401-
} else {
402-
upper_bound.clone()
403-
};
404-
405-
let step = range_step * &substep;
406-
let start = range_start + (&substart * range_step);
407-
let stop = range_start + (&substop * range_step);
342+
let (mut substart, mut substop, mut substep) =
343+
slice.inner_indices(&self.length(), vm)?;
344+
let range_step = self.step(vm);
345+
let range_start = self.start(vm);
346+
347+
substep *= range_step.as_bigint();
348+
substart = (substart * range_step.as_bigint()) + range_start.as_bigint();
349+
substop = (substop * range_step.as_bigint()) + range_start.as_bigint();
408350

409351
Ok(PyRange {
410-
start: PyInt::new(start).into_ref(vm),
411-
stop: PyInt::new(stop).into_ref(vm),
412-
step: PyInt::new(step).into_ref(vm),
352+
start: PyInt::new(substart).into_ref(vm),
353+
stop: PyInt::new(substop).into_ref(vm),
354+
step: PyInt::new(substep).into_ref(vm),
413355
}
414356
.into_ref(vm)
415357
.into_object())

vm/src/obj/objslice.rs

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
use num_bigint::BigInt;
2-
31
use super::objint::PyInt;
42
use super::objtype::{class_has_attr, PyClassRef};
53
use crate::function::{OptionalArg, PyFuncArgs};
64
use crate::pyobject::{
7-
IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol,
5+
IdProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryIntoRef,
6+
TypeProtocol,
87
};
98
use crate::vm::VirtualMachine;
9+
use num_bigint::{BigInt, ToBigInt};
10+
use num_traits::{One, Signed, Zero};
1011

1112
#[pyclass]
1213
#[derive(Debug)]
@@ -157,6 +158,92 @@ impl PySlice {
157158
Ok(eq)
158159
}
159160

161+
pub(crate) fn inner_indices(
162+
&self,
163+
length: &BigInt,
164+
vm: &VirtualMachine,
165+
) -> PyResult<(BigInt, BigInt, BigInt)> {
166+
// Calculate step
167+
let step: BigInt;
168+
if vm.is_none(&self.step(vm)) {
169+
step = One::one();
170+
} else {
171+
// Clone the value, not the reference.
172+
let this_step: PyRef<PyInt> = self.step(vm).try_into_ref(vm)?;
173+
step = this_step.as_bigint().clone();
174+
175+
if step.is_zero() {
176+
return Err(vm.new_value_error("slice step cannot be zero.".to_owned()));
177+
}
178+
}
179+
180+
// For convenience
181+
let backwards = step.is_negative();
182+
183+
// Each end of the array
184+
let lower = if backwards {
185+
-1_i8.to_bigint().unwrap()
186+
} else {
187+
Zero::zero()
188+
};
189+
190+
let upper = if backwards {
191+
lower.clone() + length
192+
} else {
193+
length.clone()
194+
};
195+
196+
// Calculate start
197+
let mut start: BigInt;
198+
if vm.is_none(&self.start(vm)) {
199+
// Default
200+
start = if backwards {
201+
upper.clone()
202+
} else {
203+
lower.clone()
204+
};
205+
} else {
206+
let this_start: PyRef<PyInt> = self.start(vm).try_into_ref(vm)?;
207+
start = this_start.as_bigint().clone();
208+
209+
if start < Zero::zero() {
210+
// From end of array
211+
start += length;
212+
213+
if start < lower {
214+
start = lower.clone();
215+
}
216+
} else if start > upper {
217+
start = upper.clone();
218+
}
219+
}
220+
221+
// Calculate Stop
222+
let mut stop: BigInt;
223+
if vm.is_none(&self.stop(vm)) {
224+
stop = if backwards {
225+
lower.clone()
226+
} else {
227+
upper.clone()
228+
};
229+
} else {
230+
let this_stop: PyRef<PyInt> = self.stop(vm).try_into_ref(vm)?;
231+
stop = this_stop.as_bigint().clone();
232+
233+
if stop < Zero::zero() {
234+
// From end of array
235+
stop += length;
236+
if stop < lower {
237+
stop = lower.clone();
238+
}
239+
} else if stop > upper {
240+
stop = upper.clone();
241+
}
242+
}
243+
244+
Ok((start, stop, step))
245+
}
246+
160247
#[pymethod(name = "__eq__")]
161248
fn eq(&self, rhs: PyObjectRef, vm: &VirtualMachine) -> PyResult {
162249
if let Some(rhs) = rhs.payload::<PySlice>() {
@@ -221,6 +308,18 @@ impl PySlice {
221308
fn hash(&self, vm: &VirtualMachine) -> PyResult<()> {
222309
Err(vm.new_type_error("unhashable type".to_string()))
223310
}
311+
312+
#[pymethod(name = "indices")]
313+
fn indices(&self, length: PyObjectRef, vm: &VirtualMachine) -> PyResult {
314+
if let Some(length) = length.payload::<PyInt>() {
315+
let (start, stop, step) = self.inner_indices(length.as_bigint(), vm)?;
316+
Ok(vm
317+
.ctx
318+
.new_tuple(vec![vm.new_int(start), vm.new_int(stop), vm.new_int(step)]))
319+
} else {
320+
Ok(vm.ctx.not_implemented())
321+
}
322+
}
224323
}
225324

226325
fn to_index_value(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult<Option<BigInt>> {

0 commit comments

Comments
 (0)