11use crate :: function:: OptionalArg ;
2- use crate :: obj:: objbytes:: PyBytesRef ;
3- use crate :: pyobject:: { PyObjectRef , PyResult } ;
2+ use crate :: obj:: objbytearray:: { PyByteArray , PyByteArrayRef } ;
3+ use crate :: obj:: objbyteinner:: PyBytesLike ;
4+ use crate :: obj:: objbytes:: { PyBytes , PyBytesRef } ;
5+ use crate :: obj:: objstr:: { PyString , PyStringRef } ;
6+ use crate :: pyobject:: { PyObjectRef , PyResult , TryFromObject , TypeProtocol } ;
47use crate :: vm:: VirtualMachine ;
8+
59use crc:: { crc32, Hasher32 } ;
10+ use itertools:: Itertools ;
11+
12+ enum SerializedData {
13+ Bytes ( PyBytesRef ) ,
14+ Buffer ( PyByteArrayRef ) ,
15+ Ascii ( PyStringRef ) ,
16+ }
17+
18+ impl TryFromObject for SerializedData {
19+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
20+ match_class ! ( match obj {
21+ b @ PyBytes => Ok ( SerializedData :: Bytes ( b) ) ,
22+ b @ PyByteArray => Ok ( SerializedData :: Buffer ( b) ) ,
23+ a @ PyString => {
24+ if a. as_str( ) . is_ascii( ) {
25+ Ok ( SerializedData :: Ascii ( a) )
26+ } else {
27+ Err ( vm. new_value_error(
28+ "string argument should contain only ASCII characters" . to_string( ) ,
29+ ) )
30+ }
31+ }
32+ obj => Err ( vm. new_type_error( format!(
33+ "argument should be bytes, buffer or ASCII string, not '{}'" ,
34+ obj. class( ) . name,
35+ ) ) ) ,
36+ } )
37+ }
38+ }
39+
40+ impl SerializedData {
41+ #[ inline]
42+ pub fn with_ref < R > ( & self , f : impl FnOnce ( & [ u8 ] ) -> R ) -> R {
43+ match self {
44+ SerializedData :: Bytes ( b) => f ( b. get_value ( ) ) ,
45+ SerializedData :: Buffer ( b) => f ( & b. inner . borrow ( ) . elements ) ,
46+ SerializedData :: Ascii ( a) => f ( a. as_str ( ) . as_bytes ( ) ) ,
47+ }
48+ }
49+ }
650
751fn hex_nibble ( n : u8 ) -> u8 {
852 match n {
@@ -12,15 +56,15 @@ fn hex_nibble(n: u8) -> u8 {
1256 }
1357}
1458
15- fn binascii_hexlify ( data : PyBytesRef , vm : & VirtualMachine ) -> PyResult {
16- let bytes = data. get_value ( ) ;
17- let mut hex = Vec :: < u8 > :: with_capacity ( bytes. len ( ) * 2 ) ;
18- for b in bytes. iter ( ) {
19- hex. push ( hex_nibble ( b >> 4 ) ) ;
20- hex. push ( hex_nibble ( b & 0xf ) ) ;
21- }
22-
23- Ok ( vm . ctx . new_bytes ( hex ) )
59+ fn binascii_hexlify ( data : PyBytesLike , _vm : & VirtualMachine ) -> Vec < u8 > {
60+ data. with_ref ( |bytes| {
61+ let mut hex = Vec :: < u8 > :: with_capacity ( bytes. len ( ) * 2 ) ;
62+ for b in bytes. iter ( ) {
63+ hex. push ( hex_nibble ( b >> 4 ) ) ;
64+ hex. push ( hex_nibble ( b & 0xf ) ) ;
65+ }
66+ hex
67+ } )
2468}
2569
2670fn unhex_nibble ( c : u8 ) -> Option < u8 > {
@@ -32,37 +76,66 @@ fn unhex_nibble(c: u8) -> Option<u8> {
3276 }
3377}
3478
35- fn binascii_unhexlify ( hexstr : PyBytesRef , vm : & VirtualMachine ) -> PyResult {
36- // TODO: allow 'str' hexstrings as well
37- let hex_bytes = hexstr. get_value ( ) ;
38- if hex_bytes. len ( ) % 2 != 0 {
39- return Err ( vm. new_value_error ( "Odd-length string" . to_string ( ) ) ) ;
40- }
79+ fn binascii_unhexlify ( data : SerializedData , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
80+ data. with_ref ( |hex_bytes| {
81+ if hex_bytes. len ( ) % 2 != 0 {
82+ return Err ( vm. new_value_error ( "Odd-length string" . to_string ( ) ) ) ;
83+ }
4184
42- let mut unhex = Vec :: < u8 > :: with_capacity ( hex_bytes. len ( ) / 2 ) ;
43- for i in ( 0 ..hex_bytes. len ( ) ) . step_by ( 2 ) {
44- let n1 = unhex_nibble ( hex_bytes[ i] ) ;
45- let n2 = unhex_nibble ( hex_bytes[ i + 1 ] ) ;
46- if let ( Some ( n1) , Some ( n2) ) = ( n1, n2) {
47- unhex. push ( n1 << 4 | n2) ;
48- } else {
49- return Err ( vm. new_value_error ( "Non-hexadecimal digit found" . to_string ( ) ) ) ;
85+ let mut unhex = Vec :: < u8 > :: with_capacity ( hex_bytes. len ( ) / 2 ) ;
86+ for ( n1, n2) in hex_bytes. iter ( ) . tuples ( ) {
87+ if let ( Some ( n1) , Some ( n2) ) = ( unhex_nibble ( * n1) , unhex_nibble ( * n2) ) {
88+ unhex. push ( n1 << 4 | n2) ;
89+ } else {
90+ return Err ( vm. new_value_error ( "Non-hexadecimal digit found" . to_string ( ) ) ) ;
91+ }
5092 }
51- }
5293
53- Ok ( vm. ctx . new_bytes ( unhex) )
94+ Ok ( unhex)
95+ } )
5496}
5597
56- fn binascii_crc32 ( data : PyBytesRef , value : OptionalArg < u32 > , vm : & VirtualMachine ) -> PyResult {
57- let bytes = data. get_value ( ) ;
58- let crc = value. unwrap_or ( 0u32 ) ;
98+ fn binascii_crc32 ( data : SerializedData , value : OptionalArg < u32 > , vm : & VirtualMachine ) -> PyResult {
99+ let crc = value. unwrap_or ( 0 ) ;
59100
60101 let mut digest = crc32:: Digest :: new_with_initial ( crc32:: IEEE , crc) ;
61- digest. write ( & bytes) ;
102+ data . with_ref ( |bytes| digest. write ( & bytes) ) ;
62103
63104 Ok ( vm. ctx . new_int ( digest. sum32 ( ) ) )
64105}
65106
107+ #[ derive( FromArgs ) ]
108+ struct NewlineArg {
109+ #[ pyarg( keyword_only, default = "true" ) ]
110+ newline : bool ,
111+ }
112+
/// Trims a single trailing newline (`b'\n'`) from the end of the byte string,
/// if one is present.
///
/// Only one newline is removed; any other trailing byte (including a space
/// or `\r`) leaves the input untouched. An empty slice is returned as-is.
fn trim_newline(b: &[u8]) -> &[u8] {
    // The suffix must be exactly "\n": matching a longer suffix such as
    // "\n " would both miss plain newlines and strip the wrong byte.
    match b.split_last() {
        Some((b'\n', rest)) => rest,
        _ => b,
    }
}
121+
122+ fn binascii_a2b_base64 ( s : SerializedData , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
123+ s. with_ref ( |b| base64:: decode ( trim_newline ( b) ) )
124+ . map_err ( |err| vm. new_value_error ( format ! ( "error decoding base64: {}" , err) ) )
125+ }
126+
127+ fn binascii_b2a_base64 (
128+ data : PyBytesLike ,
129+ NewlineArg { newline } : NewlineArg ,
130+ _vm : & VirtualMachine ,
131+ ) -> Vec < u8 > {
132+ let mut encoded = data. with_ref ( base64:: encode) . into_bytes ( ) ;
133+ if newline {
134+ encoded. push ( b'\n' ) ;
135+ }
136+ encoded
137+ }
138+
66139pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
67140 let ctx = & vm. ctx ;
68141
@@ -72,5 +145,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
72145 "unhexlify" => ctx. new_rustfunc( binascii_unhexlify) ,
73146 "a2b_hex" => ctx. new_rustfunc( binascii_unhexlify) ,
74147 "crc32" => ctx. new_rustfunc( binascii_crc32) ,
148+ "a2b_base64" => ctx. new_rustfunc( binascii_a2b_base64) ,
149+ "b2a_base64" => ctx. new_rustfunc( binascii_b2a_base64) ,
75150 } )
76151}
0 commit comments