Skip to content

Commit 225e742

Browse files
authored
Merge pull request RustPython#3092 from youknowone/str-safe
Fix PyStr operations to be safe
2 parents 8962376 + 795738d commit 225e742

File tree

18 files changed: 250 additions and 200 deletions

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

benches/microbenchmarks.rs

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,11 @@ fn bench_rustpy_code(group: &mut BenchmarkGroup<WallTime>, bench: &MicroBenchmar
132132
if let Some(idx) = iterations {
133133
scope
134134
.locals
135-
.set_item(vm.ctx.new_ascii_str(b"ITERATIONS"), vm.ctx.new_int(idx), vm)
135+
.set_item(
136+
vm.ctx.new_ascii_literal(crate::utils::ascii!("ITERATIONS")),
137+
vm.ctx.new_int(idx),
138+
vm,
139+
)
136140
.expect("Error adding ITERATIONS local variable");
137141
}
138142
let setup_result = vm.run_code_obj(setup_code.clone(), scope.clone());

vm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,7 @@ flame = { version = "0.2", optional = true }
113113
flamer = { version = "0.4", optional = true }
114114

115115
rustpython-common = { path = "../common" }
116+
ascii = "1.0.0"
116117

117118
[target.'cfg(unix)'.dependencies]
118119
exitcode = "1.1.2"

vm/src/builtins/dict.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -364,7 +364,9 @@ impl PyDict {
364364
if let Some((key, value)) = self.entries.pop_back() {
365365
Ok(vm.ctx.new_tuple(vec![key, value]))
366366
} else {
367-
let err_msg = vm.ctx.new_ascii_str(b"popitem(): dictionary is empty");
367+
let err_msg = vm
368+
.ctx
369+
.new_ascii_literal(crate::utils::ascii!("popitem(): dictionary is empty"));
368370
Err(vm.new_key_error(err_msg))
369371
}
370372
}

vm/src/builtins/pystr.rs

Lines changed: 43 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -113,14 +113,6 @@ where
113113
s.as_ref().to_owned().into()
114114
}
115115
}
116-
impl<T> From<(&T, PyStrKind)> for PyStr
117-
where
118-
T: AsRef<[u8]> + ?Sized,
119-
{
120-
fn from((s, k): (&T, PyStrKind)) -> PyStr {
121-
(s.as_ref().to_owned().into_boxed_slice(), k).into()
122-
}
123-
}
124116

125117
impl From<String> for PyStr {
126118
fn from(s: String) -> PyStr {
@@ -134,10 +126,7 @@ impl From<Box<str>> for PyStr {
134126
// doing the check is ~10x faster for ascii, and is actually only 2% slower worst case for
135127
// non-ascii; see https://github.com/RustPython/RustPython/pull/2586#issuecomment-844611532
136128
let is_ascii = value.is_ascii();
137-
let bytes = unsafe {
138-
// SAFETY: Box<str> and Box<[u8]> have same layout
139-
Box::from_raw(Box::into_raw(value) as _)
140-
};
129+
let bytes = value.into_boxed_bytes();
141130
let kind = if is_ascii {
142131
PyStrKind::Ascii
143132
} else {
@@ -152,24 +141,6 @@ impl From<Box<str>> for PyStr {
152141
}
153142
}
154143

155-
impl From<(Vec<u8>, PyStrKind)> for PyStr {
156-
fn from((s, kind): (Vec<u8>, PyStrKind)) -> PyStr {
157-
(s.into_boxed_slice(), kind).into()
158-
}
159-
}
160-
161-
impl From<(Box<[u8]>, PyStrKind)> for PyStr {
162-
fn from((bytes, kind): (Box<[u8]>, PyStrKind)) -> PyStr {
163-
let s = Self {
164-
bytes,
165-
kind: kind.new_data(),
166-
hash: Radium::new(hash::SENTINEL),
167-
};
168-
debug_assert!(matches!(s.kind, PyStrKindData::Ascii) || !s.as_str().is_ascii());
169-
s
170-
}
171-
}
172-
173144
pub type PyStrRef = PyRef<PyStr>;
174145

175146
impl fmt::Display for PyStr {
@@ -241,7 +212,11 @@ impl PyStrIterator {
241212
fn reduce(&self, vm: &VirtualMachine) -> PyResult {
242213
let iter = vm.get_attribute(vm.builtins.clone(), "iter")?;
243214
Ok(vm.ctx.new_tuple(match self.status.load() {
244-
Exhausted => vec![iter, vm.ctx.new_tuple(vec![vm.ctx.new_ascii_str(b"")])],
215+
Exhausted => vec![
216+
iter,
217+
vm.ctx
218+
.new_tuple(vec![vm.ctx.new_ascii_literal(crate::utils::ascii!(""))]),
219+
],
245220
Active => vec![
246221
iter,
247222
vm.ctx.new_tuple(vec![self.string.clone().into_object()]),
@@ -322,13 +297,20 @@ impl SlotConstructor for PyStr {
322297
}
323298

324299
impl PyStr {
325-
/// SAFETY: Given 's' must be valid data for given 'kind'
326-
unsafe fn new_str_unchecked(s: String, kind: PyStrKind) -> Self {
327-
Self {
328-
bytes: Box::from_raw(Box::into_raw(s.into_boxed_str()) as _),
300+
/// SAFETY: Given 'bytes' must be valid data for given 'kind'
301+
pub(crate) unsafe fn new_str_unchecked(bytes: Vec<u8>, kind: PyStrKind) -> Self {
302+
let s = Self {
303+
bytes: bytes.into_boxed_slice(),
329304
kind: kind.new_data(),
330305
hash: Radium::new(hash::SENTINEL),
331-
}
306+
};
307+
debug_assert!(matches!(s.kind, PyStrKindData::Ascii) || !s.as_str().is_ascii());
308+
s
309+
}
310+
311+
/// SAFETY: Given 'bytes' must be ascii
312+
unsafe fn new_ascii_unchecked(bytes: Vec<u8>) -> Self {
313+
Self::new_str_unchecked(bytes, PyStrKind::Ascii)
332314
}
333315

334316
fn new_substr(&self, s: String) -> Self {
@@ -338,8 +320,8 @@ impl PyStr {
338320
PyStrKind::Utf8
339321
};
340322
unsafe {
341-
// SAFETY: kind is safely calculated for substring
342-
Self::new_str_unchecked(s, kind)
323+
// SAFETY: kind is properly decided for substring
324+
Self::new_str_unchecked(s.into_bytes(), kind)
343325
}
344326
}
345327

@@ -367,11 +349,11 @@ impl PyStr {
367349
#[pymethod(magic)]
368350
fn add(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
369351
if let Some(other) = other.payload::<PyStr>() {
370-
let kind = zelf.kind.kind() | other.kind.kind();
371352
let bytes = zelf.as_str().py_add(other.as_ref());
372353
Ok(unsafe {
373-
// SAFETY: safe kind operation
374-
Self::new_str_unchecked(bytes, kind)
354+
// SAFETY: `kind` is safely decided
355+
let kind = zelf.kind.kind() | other.kind.kind();
356+
Self::new_str_unchecked(bytes.into_bytes(), kind)
375357
}
376358
.into_pyobject(vm))
377359
} else if let Some(radd) = vm.get_method(other.clone(), "__radd__") {
@@ -625,18 +607,23 @@ impl PyStr {
625607
|v, s, vm| {
626608
v.as_bytes()
627609
.split_str(s)
628-
.map(|s| vm.ctx.new_ascii_str(s))
610+
.map(|s| {
611+
unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.into_pyobject(vm)
612+
})
629613
.collect()
630614
},
631615
|v, s, n, vm| {
632616
v.as_bytes()
633617
.splitn_str(n, s)
634-
.map(|s| vm.ctx.new_ascii_str(s))
618+
.map(|s| {
619+
unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.into_pyobject(vm)
620+
})
635621
.collect()
636622
},
637623
|v, n, vm| {
638-
v.as_bytes()
639-
.py_split_whitespace(n, |s| vm.ctx.new_ascii_str(s))
624+
v.as_bytes().py_split_whitespace(n, |s| {
625+
unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.into_pyobject(vm)
626+
})
640627
},
641628
),
642629
PyStrKind::Utf8 => self.as_str().py_split(
@@ -993,7 +980,7 @@ impl PyStr {
993980
if has_mid {
994981
sep.into_object()
995982
} else {
996-
vm.ctx.new_ascii_str(b"")
983+
vm.ctx.new_ascii_literal(crate::utils::ascii!(""))
997984
},
998985
self.new_substr(back),
999986
)
@@ -1012,7 +999,7 @@ impl PyStr {
1012999
if has_mid {
10131000
sep.into_object()
10141001
} else {
1015-
vm.ctx.new_ascii_str(b"")
1002+
vm.ctx.new_ascii_literal(crate::utils::ascii!(""))
10161003
},
10171004
self.new_substr(back),
10181005
)
@@ -1410,8 +1397,10 @@ impl PySliceableSequence for PyStr {
14101397
// this is an ascii string
14111398
let mut v = self.bytes[range].to_vec();
14121399
v.reverse();
1413-
// TODO: from_utf8_unchecked?
1414-
String::from_utf8(v).unwrap()
1400+
unsafe {
1401+
// SAFETY: an ascii string is always utf8
1402+
String::from_utf8_unchecked(v)
1403+
}
14151404
} else {
14161405
let mut s = String::with_capacity(self.bytes.len());
14171406
s.extend(
@@ -1556,7 +1545,11 @@ mod tests {
15561545
table.set_item("a", vm.ctx.new_utf8_str("🎅"), &vm).unwrap();
15571546
table.set_item("b", vm.ctx.none(), &vm).unwrap();
15581547
table
1559-
.set_item("c", vm.ctx.new_ascii_str(b"xda"), &vm)
1548+
.set_item(
1549+
"c",
1550+
vm.ctx.new_ascii_literal(crate::utils::ascii!("xda")),
1551+
&vm,
1552+
)
15601553
.unwrap();
15611554
let translated = PyStr::maketrans(
15621555
table.into_object(),

vm/src/builtins/pytype.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,7 @@ impl PyType {
344344
Some(found)
345345
}
346346
})
347-
.unwrap_or_else(|| vm.ctx.new_ascii_str(b"builtins"))
347+
.unwrap_or_else(|| vm.ctx.new_ascii_literal(crate::utils::ascii!("builtins")))
348348
}
349349

350350
#[pyproperty(magic, setter)]

vm/src/builtins/set.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,9 @@ impl PySetInner {
236236
if let Some((key, _)) = self.content.pop_back() {
237237
Ok(key)
238238
} else {
239-
let err_msg = vm.ctx.new_ascii_str(b"pop from an empty set");
239+
let err_msg = vm
240+
.ctx
241+
.new_ascii_literal(crate::utils::ascii!("pop from an empty set"));
240242
Err(vm.new_key_error(err_msg))
241243
}
242244
}

vm/src/codecs.rs

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -367,7 +367,10 @@ fn strict_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult {
367367
fn ignore_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyObjectRef, usize)> {
368368
if is_encode_ish_err(&err, vm) || is_decode_err(&err, vm) {
369369
let range = extract_unicode_error_range(&err, vm)?;
370-
Ok((vm.ctx.new_ascii_str(b""), range.end))
370+
Ok((
371+
vm.ctx.new_ascii_literal(crate::utils::ascii!("")),
372+
range.end,
373+
))
371374
} else {
372375
Err(bad_err_type(err, vm))
373376
}

vm/src/dictdatatype.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -775,12 +775,12 @@ mod tests {
775775
assert_eq!(0, dict.len());
776776

777777
let key1 = vm.ctx.new_bool(true);
778-
let value1 = vm.ctx.new_ascii_str(b"abc");
778+
let value1 = vm.ctx.new_ascii_literal(crate::utils::ascii!("abc"));
779779
dict.insert(&vm, key1.clone(), value1.clone()).unwrap();
780780
assert_eq!(1, dict.len());
781781

782-
let key2 = vm.ctx.new_ascii_str(b"x");
783-
let value2 = vm.ctx.new_ascii_str(b"def");
782+
let key2 = vm.ctx.new_ascii_literal(crate::utils::ascii!("x"));
783+
let value2 = vm.ctx.new_ascii_literal(crate::utils::ascii!("def"));
784784
dict.insert(&vm, key2.clone(), value2.clone()).unwrap();
785785
assert_eq!(2, dict.len());
786786

vm/src/py_io.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,9 @@ pub fn file_readline(obj: &PyObjectRef, size: Option<usize>, vm: &VirtualMachine
6565
let eof_err = || {
6666
vm.new_exception(
6767
vm.ctx.exceptions.eof_error.clone(),
68-
vec![vm.ctx.new_ascii_str(b"EOF when reading a line")],
68+
vec![vm
69+
.ctx
70+
.new_ascii_literal(crate::utils::ascii!("EOF when reading a line"))],
6971
)
7072
};
7173
let ret = match_class!(match ret {

0 commit comments

Comments
 (0)