Skip to content

Commit 08614bf

Browse files
authored
Merge pull request RustPython#2830 from DimitrisJim/bytearray_init_alloc
Add __alloc__ to bytearray and fix its __init__ issues.
2 parents a4fd014 + 04c9ae0 commit 08614bf

File tree

3 files changed

+39
-27
lines changed

3 files changed

+39
-27
lines changed

Lib/test/test_bytes.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,8 +1449,11 @@ def test_irepeat_1char(self):
14491449
self.assertEqual(b, b1)
14501450
self.assertIs(b, b1)
14511451

1452-
# TODO: RUSTPYTHON
1453-
@unittest.expectedFailure
1452+
# NOTE: RUSTPYTHON:
1453+
#
1454+
# The second instance of self.assertGreater was replaced with
1455+
# self.assertGreaterEqual since, in RustPython, the underlying storage
1456+
# is a Vec which doesn't require trailing null byte.
14541457
def test_alloc(self):
14551458
b = bytearray()
14561459
alloc = b.__alloc__()
@@ -1459,12 +1462,15 @@ def test_alloc(self):
14591462
for i in range(100):
14601463
b += b"x"
14611464
alloc = b.__alloc__()
1462-
self.assertGreater(alloc, len(b)) # including trailing null byte
1465+
self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched
14631466
if alloc not in seq:
14641467
seq.append(alloc)
14651468

1466-
# TODO: RUSTPYTHON
1467-
@unittest.expectedFailure
1469+
# NOTE: RUSTPYTHON:
1470+
#
1471+
# The usages of self.assertGreater were replaced with
1472+
# self.assertGreaterEqual since, in RustPython, the underlying storage
1473+
# is a Vec which doesn't require trailing null byte.
14681474
def test_init_alloc(self):
14691475
b = bytearray()
14701476
def g():
@@ -1475,12 +1481,12 @@ def g():
14751481
self.assertEqual(len(b), len(a))
14761482
self.assertLessEqual(len(b), i)
14771483
alloc = b.__alloc__()
1478-
self.assertGreater(alloc, len(b)) # including trailing null byte
1484+
self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched
14791485
b.__init__(g())
14801486
self.assertEqual(list(b), list(range(1, 100)))
14811487
self.assertEqual(len(b), 99)
14821488
alloc = b.__alloc__()
1483-
self.assertGreater(alloc, len(b))
1489+
self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched
14841490

14851491
def test_extend(self):
14861492
orig = b'hello'
@@ -1999,8 +2005,6 @@ class ByteArraySubclassTest(SubclassTest, unittest.TestCase):
19992005
basetype = bytearray
20002006
type2test = ByteArraySubclass
20012007

2002-
# TODO: RUSTPYTHON
2003-
@unittest.expectedFailure
20042008
def test_init_override(self):
20052009
class subclass(bytearray):
20062010
def __init__(me, newarg=1, *args, **kwargs):

vm/src/builtins/bytearray.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use crate::common::lock::{
1818
PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyRwLock, PyRwLockReadGuard,
1919
PyRwLockWriteGuard,
2020
};
21-
use crate::function::{OptionalArg, OptionalOption};
21+
use crate::function::{FuncArgs, OptionalArg, OptionalOption};
2222
use crate::sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex};
2323
use crate::slots::{
2424
BufferProtocol, Comparable, Hashable, Iterable, PyComparisonOp, PyIter, Unhashable,
@@ -45,7 +45,7 @@ use std::mem::size_of;
4545
/// - any object implementing the buffer API.\n \
4646
/// - an integer";
4747
#[pyclass(module = false, name = "bytearray")]
48-
#[derive(Debug)]
48+
#[derive(Debug, Default)]
4949
pub struct PyByteArray {
5050
inner: PyRwLock<PyBytesInner>,
5151
exports: AtomicCell<usize>,
@@ -102,12 +102,16 @@ pub(crate) fn init(context: &PyContext) {
102102
#[pyimpl(flags(BASETYPE), with(Hashable, Comparable, BufferProtocol, Iterable))]
103103
impl PyByteArray {
104104
#[pyslot]
105-
fn tp_new(
106-
cls: PyTypeRef,
107-
options: ByteInnerNewOptions,
108-
vm: &VirtualMachine,
109-
) -> PyResult<PyRef<Self>> {
110-
options.get_bytearray(cls, vm)
105+
fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
106+
PyByteArray::default().into_ref_with_type(vm, cls)
107+
}
108+
109+
#[pymethod(magic)]
110+
fn init(&self, options: ByteInnerNewOptions, vm: &VirtualMachine) -> PyResult<()> {
111+
// First unpack bytearray and *then* get a lock to set it.
112+
let mut inner = options.get_bytearray_inner(vm)?;
113+
std::mem::swap(&mut *self.inner_mut(), &mut inner);
114+
Ok(())
111115
}
112116

113117
#[inline]
@@ -124,6 +128,11 @@ impl PyByteArray {
124128
self.inner().repr("bytearray(", ")")
125129
}
126130

131+
#[pymethod(magic)]
132+
fn alloc(&self) -> usize {
133+
self.inner().capacity()
134+
}
135+
127136
#[pymethod(name = "__len__")]
128137
fn len(&self) -> usize {
129138
self.borrow_buf().len()

vm/src/bytesinner.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use num_bigint::BigInt;
44
use num_traits::ToPrimitive;
55

66
use crate::anystr::{self, AnyStr, AnyStrContainer, AnyStrWrapper};
7-
use crate::builtins::bytearray::{PyByteArray, PyByteArrayRef};
7+
use crate::builtins::bytearray::PyByteArray;
88
use crate::builtins::bytes::{PyBytes, PyBytesRef};
99
use crate::builtins::int::{PyInt, PyIntRef};
1010
use crate::builtins::pystr::{self, PyStr, PyStrRef};
@@ -120,12 +120,8 @@ impl ByteInnerNewOptions {
120120
PyBytes::from(inner).into_ref_with_type(vm, cls)
121121
}
122122

123-
pub fn get_bytearray(
124-
mut self,
125-
cls: PyTypeRef,
126-
vm: &VirtualMachine,
127-
) -> PyResult<PyByteArrayRef> {
128-
let inner = if let OptionalArg::Present(source) = self.source.take() {
123+
pub fn get_bytearray_inner(mut self, vm: &VirtualMachine) -> PyResult<PyBytesInner> {
124+
if let OptionalArg::Present(source) = self.source.take() {
129125
match_class!(match source {
130126
s @ PyStr => Self::get_value_from_string(s, self.encoding, self.errors, vm),
131127
i @ PyInt => {
@@ -139,9 +135,7 @@ impl ByteInnerNewOptions {
139135
})
140136
} else {
141137
self.check_args(vm).map(|_| vec![].into())
142-
}?;
143-
144-
PyByteArray::from(inner).into_ref_with_type(vm, cls)
138+
}
145139
}
146140
}
147141

@@ -293,6 +287,11 @@ impl PyBytesInner {
293287
self.elements.len()
294288
}
295289

290+
#[inline]
291+
pub fn capacity(&self) -> usize {
292+
self.elements.capacity()
293+
}
294+
296295
#[inline]
297296
pub fn is_empty(&self) -> bool {
298297
self.elements.is_empty()

0 commit comments

Comments
 (0)