@@ -5,6 +5,9 @@ use std::cell::RefCell;
55use std:: fs:: File ;
66use std:: io:: prelude:: * ;
77use std:: io:: BufReader ;
8+ use std:: io:: Cursor ;
9+ use std:: io:: SeekFrom ;
10+
811use std:: path:: PathBuf ;
912
1013use num_bigint:: ToBigInt ;
@@ -33,7 +36,7 @@ fn compute_c_flag(mode: &str) -> u16 {
3336
3437#[ derive( Debug ) ]
3538struct PyStringIO {
36- data : RefCell < String > ,
39+ data : RefCell < Cursor < Vec < u8 > > > ,
3740}
3841
3942type PyStringIORef = PyRef < PyStringIO > ;
@@ -45,19 +48,68 @@ impl PyValue for PyStringIO {
4548}
4649
4750impl PyStringIORef {
48- fn write ( self , data : objstr:: PyStringRef , _vm : & VirtualMachine ) {
49- let data = data. value . clone ( ) ;
50- self . data . borrow_mut ( ) . push_str ( & data) ;
51+ //write string to underlying vector
52+ fn write ( self , data : objstr:: PyStringRef , vm : & VirtualMachine ) -> PyResult {
53+ let bytes = & data. value . clone ( ) . into_bytes ( ) ;
54+ let length = bytes. len ( ) ;
55+
56+ let mut cursor = self . data . borrow_mut ( ) ;
57+ match cursor. write_all ( bytes) {
58+ Ok ( _) => Ok ( vm. ctx . new_int ( length) ) ,
59+ Err ( _) => Err ( vm. new_type_error ( "Error Writing String" . to_string ( ) ) ) ,
60+ }
5161 }
5262
53- fn getvalue ( self , _vm : & VirtualMachine ) -> String {
54- self . data . borrow ( ) . clone ( )
63+ //return the entire contents of the underlying
64+ fn getvalue ( self , vm : & VirtualMachine ) -> PyResult {
65+ match String :: from_utf8 ( self . data . borrow ( ) . clone ( ) . into_inner ( ) ) {
66+ Ok ( result) => Ok ( vm. ctx . new_str ( result) ) ,
67+ Err ( _) => Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ,
68+ }
5569 }
5670
57- fn read ( self , _vm : & VirtualMachine ) -> String {
58- let data = self . data . borrow ( ) . clone ( ) ;
59- self . data . borrow_mut ( ) . clear ( ) ;
60- data
71+ //skip to the jth position
72+ fn seek ( self , offset : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
73+ let position = objint:: get_value ( & offset) . to_u64 ( ) . unwrap ( ) ;
74+ if let Err ( _) = self
75+ . data
76+ . borrow_mut ( )
77+ . seek ( SeekFrom :: Start ( position. clone ( ) ) )
78+ {
79+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
80+ }
81+
82+ Ok ( vm. ctx . new_int ( position) )
83+ }
84+
85+ //Read k bytes from the object and return.
86+ //If k is undefined || k == -1, then we read all bytes until the end of the file.
87+ //This also increments the stream position by the value of k
88+ fn read ( self , bytes : OptionalArg < Option < PyObjectRef > > , vm : & VirtualMachine ) -> PyResult {
89+ let mut buffer = String :: new ( ) ;
90+
91+ match bytes {
92+ OptionalArg :: Present ( Some ( ref integer) ) => {
93+ let k = objint:: get_value ( integer) . to_u64 ( ) . unwrap ( ) ;
94+ let mut handle = self . data . borrow ( ) . clone ( ) . take ( k) ;
95+
96+ //read bytes into string
97+ if let Err ( _) = handle. read_to_string ( & mut buffer) {
98+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
99+ }
100+
101+ //the take above consumes the struct value
102+ //we add this back in with the takes into_inner method
103+ self . data . replace ( handle. into_inner ( ) ) ;
104+ }
105+ _ => {
106+ if let Err ( _) = self . data . borrow_mut ( ) . read_to_string ( & mut buffer) {
107+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
108+ }
109+ }
110+ } ;
111+
112+ Ok ( vm. ctx . new_str ( buffer) )
61113 }
62114}
63115
@@ -72,14 +124,14 @@ fn string_io_new(
72124 } ;
73125
74126 PyStringIO {
75- data : RefCell :: new ( raw_string) ,
127+ data : RefCell :: new ( Cursor :: new ( raw_string. into_bytes ( ) ) ) ,
76128 }
77129 . into_ref_with_type ( vm, cls)
78130}
79131
80- #[ derive( Debug , Default , Clone ) ]
132+ #[ derive( Debug ) ]
81133struct PyBytesIO {
82- data : RefCell < Vec < u8 > > ,
134+ data : RefCell < Cursor < Vec < u8 > > > ,
83135}
84136
85137type PyBytesIORef = PyRef < PyBytesIO > ;
@@ -91,19 +143,65 @@ impl PyValue for PyBytesIO {
91143}
92144
93145impl PyBytesIORef {
94- fn write ( self , data : objbytes:: PyBytesRef , _vm : & VirtualMachine ) {
95- let data = data. get_value ( ) ;
96- self . data . borrow_mut ( ) . extend ( data) ;
146+ //write string to underlying vector
147+ fn write ( self , data : objbytes:: PyBytesRef , vm : & VirtualMachine ) -> PyResult {
148+ let bytes = data. get_value ( ) ;
149+ let length = bytes. len ( ) ;
150+
151+ let mut cursor = self . data . borrow_mut ( ) ;
152+ match cursor. write_all ( bytes) {
153+ Ok ( _) => Ok ( vm. ctx . new_int ( length) ) ,
154+ Err ( _) => Err ( vm. new_type_error ( "Error Writing String" . to_string ( ) ) ) ,
155+ }
97156 }
98157
158+ //return the entire contents of the underlying
99159 fn getvalue ( self , vm : & VirtualMachine ) -> PyResult {
100- Ok ( vm. ctx . new_bytes ( self . data . borrow ( ) . clone ( ) ) )
160+ Ok ( vm. ctx . new_bytes ( self . data . borrow ( ) . clone ( ) . into_inner ( ) ) )
161+ }
162+
163+ //skip to the jth position
164+ fn seek ( self , offset : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
165+ let position = objint:: get_value ( & offset) . to_u64 ( ) . unwrap ( ) ;
166+ if let Err ( _) = self
167+ . data
168+ . borrow_mut ( )
169+ . seek ( SeekFrom :: Start ( position. clone ( ) ) )
170+ {
171+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
172+ }
173+
174+ Ok ( vm. ctx . new_int ( position) )
101175 }
102176
103- fn read ( self , vm : & VirtualMachine ) -> PyResult {
104- let data = self . data . borrow ( ) . clone ( ) ;
105- self . data . borrow_mut ( ) . clear ( ) ;
106- Ok ( vm. ctx . new_bytes ( data) )
177+ //Read k bytes from the object and return.
178+ //If k is undefined || k == -1, then we read all bytes until the end of the file.
179+ //This also increments the stream position by the value of k
180+ fn read ( self , bytes : OptionalArg < Option < PyObjectRef > > , vm : & VirtualMachine ) -> PyResult {
181+ let mut buffer = Vec :: new ( ) ;
182+
183+ match bytes {
184+ OptionalArg :: Present ( Some ( ref integer) ) => {
185+ let k = objint:: get_value ( integer) . to_u64 ( ) . unwrap ( ) ;
186+ let mut handle = self . data . borrow ( ) . clone ( ) . take ( k) ;
187+
188+ //read bytes into string
189+ if let Err ( _) = handle. read_to_end ( & mut buffer) {
190+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
191+ }
192+
193+ //the take above consumes the struct value
194+ //we add this back in with the takes into_inner method
195+ self . data . replace ( handle. into_inner ( ) ) ;
196+ }
197+ _ => {
198+ if let Err ( _) = self . data . borrow_mut ( ) . read_to_end ( & mut buffer) {
199+ return Err ( vm. new_value_error ( "Error Retrieving Value" . to_string ( ) ) ) ;
200+ }
201+ }
202+ } ;
203+
204+ Ok ( vm. ctx . new_bytes ( buffer) )
107205 }
108206}
109207
@@ -118,7 +216,7 @@ fn bytes_io_new(
118216 } ;
119217
120218 PyBytesIO {
121- data : RefCell :: new ( raw_bytes) ,
219+ data : RefCell :: new ( Cursor :: new ( raw_bytes) ) ,
122220 }
123221 . into_ref_with_type ( vm, cls)
124222}
@@ -514,6 +612,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
514612 //StringIO: in-memory text
515613 let string_io = py_class ! ( ctx, "StringIO" , text_io_base. clone( ) , {
516614 "__new__" => ctx. new_rustfunc( string_io_new) ,
615+ "seek" => ctx. new_rustfunc( PyStringIORef :: seek) ,
517616 "read" => ctx. new_rustfunc( PyStringIORef :: read) ,
518617 "write" => ctx. new_rustfunc( PyStringIORef :: write) ,
519618 "getvalue" => ctx. new_rustfunc( PyStringIORef :: getvalue)
@@ -523,6 +622,8 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
523622 let bytes_io = py_class ! ( ctx, "BytesIO" , buffered_io_base. clone( ) , {
524623 "__new__" => ctx. new_rustfunc( bytes_io_new) ,
525624 "read" => ctx. new_rustfunc( PyBytesIORef :: read) ,
625+ "read1" => ctx. new_rustfunc( PyBytesIORef :: read) ,
626+ "seek" => ctx. new_rustfunc( PyBytesIORef :: seek) ,
526627 "write" => ctx. new_rustfunc( PyBytesIORef :: write) ,
527628 "getvalue" => ctx. new_rustfunc( PyBytesIORef :: getvalue)
528629 } ) ;
@@ -627,4 +728,5 @@ mod tests {
627728 Err ( "invalid mode: 'a++'" . to_string( ) )
628729 ) ;
629730 }
731+
630732}
0 commit comments