Skip to content

Commit 1322c8d

Browse files
committed
Change socket parameters to match CPython
1 parent c84d351 commit 1322c8d

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

vm/src/stdlib/socket.rs

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::obj::objbytes;
22
use crate::obj::objint;
3+
use crate::obj::objsequence::get_elements;
34
use crate::obj::objstr;
4-
55
use crate::pyobject::{
66
PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol,
77
};
@@ -125,14 +125,20 @@ fn socket_connect(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
125125
arg_check!(
126126
vm,
127127
args,
128-
required = [(zelf, None), (address, Some(vm.ctx.str_type()))]
128+
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
129129
);
130130

131+
let elements = get_elements(address);
132+
let host = objstr::get_value(&elements[0]);
133+
let port = objint::get_value(&elements[1]);
134+
135+
let address_string = format!("{}:{}", host, port.to_string());
136+
131137
let mut mut_obj = zelf.borrow_mut();
132138

133139
match mut_obj.payload {
134140
PyObjectPayload::Socket { ref mut socket } => {
135-
if let Ok(stream) = TcpStream::connect(objstr::get_value(&address)) {
141+
if let Ok(stream) = TcpStream::connect(address_string) {
136142
socket.con = Some(Connection::TcpStream(stream));
137143
Ok(vm.get_none())
138144
} else {
@@ -148,14 +154,20 @@ fn socket_bind(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
148154
arg_check!(
149155
vm,
150156
args,
151-
required = [(zelf, None), (address, Some(vm.ctx.str_type()))]
157+
required = [(zelf, None), (address, Some(vm.ctx.tuple_type()))]
152158
);
153159

160+
let elements = get_elements(address);
161+
let host = objstr::get_value(&elements[0]);
162+
let port = objint::get_value(&elements[1]);
163+
164+
let address_string = format!("{}:{}", host, port.to_string());
165+
154166
let mut mut_obj = zelf.borrow_mut();
155167

156168
match mut_obj.payload {
157169
PyObjectPayload::Socket { ref mut socket } => {
158-
if let Ok(stream) = TcpListener::bind(objstr::get_value(&address)) {
170+
if let Ok(stream) = TcpListener::bind(address_string) {
159171
socket.con = Some(Connection::TcpListener(stream));
160172
Ok(vm.get_none())
161173
} else {
@@ -194,9 +206,13 @@ fn socket_accept(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
194206
con: Some(Connection::TcpStream(tcp_stream)),
195207
};
196208

209+
let sock_obj = PyObject::new(PyObjectPayload::Socket { socket }, mut_obj.typ());
210+
211+
let elements = vec![sock_obj, vm.get_none()];
212+
197213
Ok(PyObject::new(
198-
PyObjectPayload::Socket { socket },
199-
mut_obj.typ(),
214+
PyObjectPayload::Sequence { elements },
215+
vm.ctx.tuple_type(),
200216
))
201217
}
202218
_ => Err(vm.new_type_error("".to_string())),

0 commit comments

Comments
 (0)