diff --git a/Lib/gzip.py b/Lib/gzip.py index 5b20e5ba69..1a3c82ce7e 100644 --- a/Lib/gzip.py +++ b/Lib/gzip.py @@ -15,12 +15,16 @@ FTEXT, FHCRC, FEXTRA, FNAME, FCOMMENT = 1, 2, 4, 8, 16 -READ, WRITE = 1, 2 +READ = 'rb' +WRITE = 'wb' _COMPRESS_LEVEL_FAST = 1 _COMPRESS_LEVEL_TRADEOFF = 6 _COMPRESS_LEVEL_BEST = 9 +READ_BUFFER_SIZE = 128 * 1024 +_WRITE_BUFFER_SIZE = 4 * io.DEFAULT_BUFFER_SIZE + def open(filename, mode="rb", compresslevel=_COMPRESS_LEVEL_BEST, encoding=None, errors=None, newline=None): @@ -118,6 +122,21 @@ class BadGzipFile(OSError): """Exception raised in some cases for invalid gzip files.""" +class _WriteBufferStream(io.RawIOBase): + """Minimal object to pass WriteBuffer flushes into GzipFile""" + def __init__(self, gzip_file): + self.gzip_file = gzip_file + + def write(self, data): + return self.gzip_file._write_raw(data) + + def seekable(self): + return False + + def writable(self): + return True + + class GzipFile(_compression.BaseStream): """The GzipFile class simulates most of the methods of a file object with the exception of the truncate() method. @@ -160,9 +179,10 @@ def __init__(self, filename=None, mode=None, and 9 is slowest and produces the most compression. 0 is no compression at all. The default is 9. - The mtime argument is an optional numeric timestamp to be written - to the last modification time field in the stream when compressing. - If omitted or None, the current time is used. + The optional mtime argument is the timestamp requested by gzip. The time + is in Unix format, i.e., seconds since 00:00:00 UTC, January 1, 1970. + If mtime is omitted or None, the current time is used. Use mtime = 0 + to generate a compressed stream that does not depend on creation time. """ @@ -182,6 +202,7 @@ def __init__(self, filename=None, mode=None, if mode is None: mode = getattr(fileobj, 'mode', 'rb') + if mode.startswith('r'): self.mode = READ raw = _GzipReader(fileobj) @@ -204,6 +225,9 @@ def __init__(self, filename=None, mode=None, zlib.DEF_MEM_LEVEL, 0) self._write_mtime = mtime + self._buffer_size = _WRITE_BUFFER_SIZE + self._buffer = io.BufferedWriter(_WriteBufferStream(self), + buffer_size=self._buffer_size) else: raise ValueError("Invalid mode: {!r}".format(mode)) @@ -212,14 +236,6 @@ def __init__(self, filename=None, mode=None, if self.mode == WRITE: self._write_gzip_header(compresslevel) - @property - def filename(self): - import warnings - warnings.warn("use the name attribute", DeprecationWarning, 2) - if self.mode == WRITE and self.name[-3:] != ".gz": - return self.name + ".gz" - return self.name - @property def mtime(self): """Last modification time read from stream, or None""" @@ -237,6 +253,11 @@ def _init_write(self, filename): self.bufsize = 0 self.offset = 0 # Current file offset for seek(), tell(), etc + def tell(self): + self._check_not_closed() + self._buffer.flush() + return super().tell() + def _write_gzip_header(self, compresslevel): self.fileobj.write(b'\037\213') # magic header self.fileobj.write(b'\010') # compression method @@ -278,6 +299,10 @@ def write(self,data): if self.fileobj is None: raise ValueError("write() on closed GzipFile object") + return self._buffer.write(data) + + def _write_raw(self, data): + # Called by our self._buffer underlying WriteBufferStream. if isinstance(data, (bytes, bytearray)): length = len(data) else: @@ -326,11 +351,11 @@ def closed(self): def close(self): fileobj = self.fileobj - if fileobj is None: + if fileobj is None or self._buffer.closed: return - self.fileobj = None try: if self.mode == WRITE: + self._buffer.flush() fileobj.write(self.compress.flush()) write32u(fileobj, self.crc) # self.size may exceed 2 GiB, or even 4 GiB @@ -338,6 +363,7 @@ def close(self): elif self.mode == READ: self._buffer.close() finally: + self.fileobj = None myfileobj = self.myfileobj if myfileobj: self.myfileobj = None @@ -346,6 +372,7 @@ def close(self): def flush(self,zlib_mode=zlib.Z_SYNC_FLUSH): self._check_not_closed() if self.mode == WRITE: + self._buffer.flush() # Ensure the compressor's buffer is flushed self.fileobj.write(self.compress.flush(zlib_mode)) self.fileobj.flush() @@ -376,6 +403,9 @@ def seekable(self): def seek(self, offset, whence=io.SEEK_SET): if self.mode == WRITE: + self._check_not_closed() + # Flush buffer to ensure validity of self.offset + self._buffer.flush() if whence != io.SEEK_SET: if whence == io.SEEK_CUR: offset = self.offset + offset @@ -384,10 +414,10 @@ def seek(self, offset, whence=io.SEEK_SET): if offset < self.offset: raise OSError('Negative seek in write mode') count = offset - self.offset - chunk = b'\0' * 1024 - for i in range(count // 1024): + chunk = b'\0' * self._buffer_size + for i in range(count // self._buffer_size): self.write(chunk) - self.write(b'\0' * (count % 1024)) + self.write(b'\0' * (count % self._buffer_size)) elif self.mode == READ: self._check_not_closed() return self._buffer.seek(offset, whence) @@ -454,7 +484,7 @@ def _read_gzip_header(fp): class _GzipReader(_compression.DecompressReader): def __init__(self, fp): - super().__init__(_PaddedFile(fp), zlib.decompressobj, + super().__init__(_PaddedFile(fp), zlib._ZlibDecompressor, wbits=-zlib.MAX_WBITS) # Set flag indicating start of a new member self._new_member = True @@ -502,12 +532,13 @@ def read(self, size=-1): self._new_member = False # Read a chunk of data from the file - buf = self._fp.read(io.DEFAULT_BUFFER_SIZE) + if self._decompressor.needs_input: + buf = self._fp.read(READ_BUFFER_SIZE) + uncompress = self._decompressor.decompress(buf, size) + else: + uncompress = self._decompressor.decompress(b"", size) - uncompress = self._decompressor.decompress(buf, size) - if self._decompressor.unconsumed_tail != b"": - self._fp.prepend(self._decompressor.unconsumed_tail) - elif self._decompressor.unused_data != b"": + if self._decompressor.unused_data != b"": # Prepend the already read bytes to the fileobj so they can # be seen by _read_eof() and _read_gzip_header() self._fp.prepend(self._decompressor.unused_data) @@ -518,14 +549,11 @@ def read(self, size=-1): raise EOFError("Compressed file ended before the " "end-of-stream marker was reached") - self._add_read_data( uncompress ) + self._crc = zlib.crc32(uncompress, self._crc) + self._stream_size += len(uncompress) self._pos += len(uncompress) return uncompress - def _add_read_data(self, data): - self._crc = zlib.crc32(data, self._crc) - self._stream_size = self._stream_size + len(data) - def _read_eof(self): # We've read to the end of the file # We check that the computed CRC and size of the @@ -552,43 +580,21 @@ def _rewind(self): self._new_member = True -def _create_simple_gzip_header(compresslevel: int, - mtime = None) -> bytes: - """ - Write a simple gzip header with no extra fields. - :param compresslevel: Compresslevel used to determine the xfl bytes. - :param mtime: The mtime (must support conversion to a 32-bit integer). - :return: A bytes object representing the gzip header. - """ - if mtime is None: - mtime = time.time() - if compresslevel == _COMPRESS_LEVEL_BEST: - xfl = 2 - elif compresslevel == _COMPRESS_LEVEL_FAST: - xfl = 4 - else: - xfl = 0 - # Pack ID1 and ID2 magic bytes, method (8=deflate), header flags (no extra - # fields added to header), mtime, xfl and os (255 for unknown OS). - return struct.pack(" { #let_obj - for name in [(#(#names,)*)] { + for name in [#(#names),*] { vm.__module_set_attr(module, vm.ctx.intern_str(name), obj.clone()).unwrap(); } }}, diff --git a/stdlib/src/zlib.rs b/stdlib/src/zlib.rs index a4828552a7..bcfe93aff5 100644 --- a/stdlib/src/zlib.rs +++ b/stdlib/src/zlib.rs @@ -5,7 +5,7 @@ pub(crate) use zlib::make_module; #[pymodule] mod zlib { use crate::vm::{ - builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyIntRef, PyTypeRef}, + builtins::{PyBaseExceptionRef, PyBytesRef, PyIntRef, PyTypeRef}, common::lock::PyMutex, convert::TryFromBorrowedObject, function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg}, @@ -13,7 +13,6 @@ mod zlib { PyObject, PyPayload, PyResult, VirtualMachine, }; use adler32::RollingAdler32 as Adler32; - use crossbeam_utils::atomic::AtomicCell; use flate2::{ write::ZlibEncoder, Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status, @@ -27,6 +26,17 @@ mod zlib { Z_NO_COMPRESSION, Z_NO_FLUSH, Z_PARTIAL_FLUSH, Z_RLE, Z_SYNC_FLUSH, Z_TREES, }; + // we're statically linking libz-rs, so the compile-time and runtime + // versions will always be the same + #[pyattr(name = "ZLIB_RUNTIME_VERSION")] + #[pyattr] + const ZLIB_VERSION: &str = unsafe { + match std::ffi::CStr::from_ptr(libz_sys::zlibVersion()).to_str() { + Ok(s) => s, + Err(_) => unreachable!(), + } + }; + // copied from zlibmodule.c (commit 530f506ac91338) #[pyattr] const MAX_WBITS: i8 = 15; @@ -80,15 +90,10 @@ mod zlib { } = args; let level = level.ok_or_else(|| new_zlib_error("Bad compression level", vm))?; - let encoded_bytes = if args.wbits.value == MAX_WBITS { - let mut encoder = ZlibEncoder::new(Vec::new(), level); - data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap()); - encoder.finish().unwrap() - } else { - let mut inner = CompressInner::new(InitOptions::new(wbits.value, vm)?.compress(level)); - data.with_ref(|input_bytes| inner.compress(input_bytes, vm))?; - inner.flush(vm)? - }; + let compress = InitOptions::new(wbits.value, vm)?.compress(level); + let mut encoder = ZlibEncoder::new_with_compress(Vec::new(), compress); + data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap()); + let encoded_bytes = encoder.finish().unwrap(); Ok(vm.ctx.new_bytes(encoded_bytes)) } @@ -109,6 +114,11 @@ mod zlib { let header = wbits > 0; let wbits = wbits.unsigned_abs(); match wbits { + // TODO: wbits = 0 should be a valid option: + // > windowBits can also be zero to request that inflate use the window size in + // > the zlib header of the compressed stream. + // but flate2 doesn't expose it + // 0 => ... 9..=15 => Ok(InitOptions::Standard { header, wbits }), 25..=31 => Ok(InitOptions::Gzip { wbits: wbits - 16 }), _ => Err(vm.new_value_error("Invalid initialization option".to_owned())), @@ -131,29 +141,81 @@ mod zlib { } } + #[derive(Clone)] + struct Chunker<'a> { + data1: &'a [u8], + data2: &'a [u8], + } + impl<'a> Chunker<'a> { + fn new(data: &'a [u8]) -> Self { + Self { + data1: data, + data2: &[], + } + } + fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self { + if data1.is_empty() { + Self { + data1: data2, + data2: &[], + } + } else { + Self { data1, data2 } + } + } + fn len(&self) -> usize { + self.data1.len() + self.data2.len() + } + fn is_empty(&self) -> bool { + self.data1.is_empty() + } + fn to_vec(&self) -> Vec { + [self.data1, self.data2].concat() + } + fn chunk(&self) -> &'a [u8] { + self.data1.get(..CHUNKSIZE).unwrap_or(self.data1) + } + fn advance(&mut self, consumed: usize) { + self.data1 = &self.data1[consumed..]; + if self.data1.is_empty() { + self.data1 = std::mem::take(&mut self.data2); + } + } + } + fn _decompress( - mut data: &[u8], + data: &[u8], d: &mut Decompress, bufsize: usize, max_length: Option, is_flush: bool, + zdict: Option<&ArgBytesLike>, + vm: &VirtualMachine, + ) -> PyResult<(Vec, bool)> { + let mut data = Chunker::new(data); + _decompress_chunks(&mut data, d, bufsize, max_length, is_flush, zdict, vm) + } + + fn _decompress_chunks( + data: &mut Chunker<'_>, + d: &mut Decompress, + bufsize: usize, + max_length: Option, + is_flush: bool, + zdict: Option<&ArgBytesLike>, vm: &VirtualMachine, ) -> PyResult<(Vec, bool)> { if data.is_empty() { return Ok((Vec::new(), true)); } + let max_length = max_length.unwrap_or(usize::MAX); let mut buf = Vec::new(); - loop { - let final_chunk = data.len() <= CHUNKSIZE; - let chunk = if final_chunk { - data - } else { - &data[..CHUNKSIZE] - }; - // if this is the final chunk, finish it + 'outer: loop { + let chunk = data.chunk(); let flush = if is_flush { - if final_chunk { + // if this is the final chunk, finish it + if chunk.len() == data.len() { FlushDecompress::Finish } else { FlushDecompress::None @@ -162,34 +224,43 @@ mod zlib { FlushDecompress::Sync }; loop { - let additional = if let Some(max_length) = max_length { - std::cmp::min(bufsize, max_length - buf.capacity()) - } else { - bufsize - }; + let additional = std::cmp::min(bufsize, max_length - buf.capacity()); if additional == 0 { return Ok((buf, false)); } - buf.reserve_exact(additional); + let prev_in = d.total_in(); - let status = d - .decompress_vec(chunk, &mut buf, flush) - .map_err(|_| new_zlib_error("invalid input data", vm))?; + let res = d.decompress_vec(chunk, &mut buf, flush); let consumed = d.total_in() - prev_in; - data = &data[consumed as usize..]; - let stream_end = status == Status::StreamEnd; - if stream_end || data.is_empty() { - // we've reached the end of the stream, we're done - buf.shrink_to_fit(); - return Ok((buf, stream_end)); - } else if !chunk.is_empty() && consumed == 0 { - // we're gonna need a bigger buffer - continue; - } else { - // next chunk - break; - } + + data.advance(consumed as usize); + + match res { + Ok(status) => { + let stream_end = status == Status::StreamEnd; + if stream_end || data.is_empty() { + // we've reached the end of the stream, we're done + buf.shrink_to_fit(); + return Ok((buf, stream_end)); + } else if !chunk.is_empty() && consumed == 0 { + // we're gonna need a bigger buffer + continue; + } else { + // next chunk + continue 'outer; + } + } + Err(e) => { + let Some(zdict) = e.needs_dictionary().and(zdict) else { + return Err(new_zlib_error(&e.to_string(), vm)); + }; + d.set_dictionary(&zdict.borrow_buf()) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + // now try the next chunk + continue 'outer; + } + }; } } } @@ -214,7 +285,8 @@ mod zlib { } = args; data.with_ref(|data| { let mut d = InitOptions::new(wbits.value, vm)?.decompress(); - let (buf, stream_end) = _decompress(data, &mut d, bufsize.value, None, false, vm)?; + let (buf, stream_end) = + _decompress(data, &mut d, bufsize.value, None, false, None, vm)?; if !stream_end { return Err(new_zlib_error( "Error -5 while decompressing data: incomplete or truncated stream", @@ -230,101 +302,116 @@ mod zlib { #[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")] wbits: ArgPrimitiveIndex, #[pyarg(any, optional)] - _zdict: OptionalArg, + zdict: OptionalArg, } #[pyfunction] fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult { - #[allow(unused_mut)] let mut decompress = InitOptions::new(args.wbits.value, vm)?.decompress(); - if let OptionalArg::Present(_dict) = args._zdict { - // FIXME: always fails - // dict.with_ref(|d| decompress.set_dictionary(d)); + let zdict = args.zdict.into_option(); + if let Some(dict) = &zdict { + if args.wbits.value < 0 { + dict.with_ref(|d| decompress.set_dictionary(d)) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + } } + let inner = PyDecompressInner { + decompress: Some(decompress), + eof: false, + zdict, + unused_data: vm.ctx.empty_bytes.clone(), + unconsumed_tail: vm.ctx.empty_bytes.clone(), + }; Ok(PyDecompress { - decompress: PyMutex::new(decompress), - eof: AtomicCell::new(false), - unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), - unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), + inner: PyMutex::new(inner), }) } + + #[derive(Debug)] + struct PyDecompressInner { + decompress: Option, + zdict: Option, + eof: bool, + unused_data: PyBytesRef, + unconsumed_tail: PyBytesRef, + } + #[pyattr] #[pyclass(name = "Decompress")] #[derive(Debug, PyPayload)] struct PyDecompress { - decompress: PyMutex, - eof: AtomicCell, - unused_data: PyMutex, - unconsumed_tail: PyMutex, + inner: PyMutex, } + #[pyclass] impl PyDecompress { #[pygetset] fn eof(&self) -> bool { - self.eof.load() + self.inner.lock().eof } #[pygetset] fn unused_data(&self) -> PyBytesRef { - self.unused_data.lock().clone() + self.inner.lock().unused_data.clone() } #[pygetset] fn unconsumed_tail(&self) -> PyBytesRef { - self.unconsumed_tail.lock().clone() + self.inner.lock().unconsumed_tail.clone() } - fn save_unused_input( - &self, - d: &Decompress, + fn decompress_inner( + inner: &mut PyDecompressInner, data: &[u8], - stream_end: bool, - orig_in: u64, + bufsize: usize, + max_length: Option, + is_flush: bool, vm: &VirtualMachine, - ) { - let leftover = &data[(d.total_in() - orig_in) as usize..]; - - if stream_end && !leftover.is_empty() { - let mut unused_data = self.unused_data.lock(); - let unused: Vec<_> = unused_data - .as_bytes() - .iter() - .chain(leftover) - .copied() - .collect(); - *unused_data = vm.ctx.new_pyref(unused); + ) -> PyResult<(PyResult>, bool)> { + let Some(d) = &mut inner.decompress else { + return Err(new_zlib_error(USE_AFTER_FINISH_ERR, vm)); + }; + + let zdict = if is_flush { None } else { inner.zdict.as_ref() }; + + let prev_in = d.total_in(); + let (ret, stream_end) = + match _decompress(data, d, bufsize, max_length, is_flush, zdict, vm) { + Ok((buf, stream_end)) => (Ok(buf), stream_end), + Err(err) => (Err(err), false), + }; + let consumed = (d.total_in() - prev_in) as usize; + + // save unused input + let unconsumed = &data[consumed..]; + if !unconsumed.is_empty() { + if stream_end { + let unused = [inner.unused_data.as_bytes(), unconsumed].concat(); + inner.unused_data = vm.ctx.new_pyref(unused); + } else { + inner.unconsumed_tail = vm.ctx.new_bytes(unconsumed.to_vec()); + } + } else if !inner.unconsumed_tail.is_empty() { + inner.unconsumed_tail = vm.ctx.empty_bytes.clone(); } + + Ok((ret, stream_end)) } #[pymethod] fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { - let max_length = args.max_length.value; + let max_length: usize = args + .max_length + .map_or(0, |x| x.value) + .try_into() + .map_err(|_| vm.new_value_error("must be non-negative".to_owned()))?; let max_length = (max_length != 0).then_some(max_length); - let data = args.data.borrow_buf(); - let data = &*data; + let data = &*args.data.borrow_buf(); - let mut d = self.decompress.lock(); - let orig_in = d.total_in(); + let inner = &mut *self.inner.lock(); let (ret, stream_end) = - match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, false, vm) { - Ok((buf, true)) => { - self.eof.store(true); - (Ok(buf), true) - } - Ok((buf, false)) => (Ok(buf), false), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, data, stream_end, orig_in, vm); - - let leftover = if stream_end { - b"" - } else { - &data[(d.total_in() - orig_in) as usize..] - }; + Self::decompress_inner(inner, data, DEF_BUF_SIZE, max_length, false, vm)?; - let mut unconsumed_tail = self.unconsumed_tail.lock(); - if !leftover.is_empty() || !unconsumed_tail.is_empty() { - *unconsumed_tail = PyBytes::from(leftover.to_owned()).into_ref(&vm.ctx); - } + inner.eof |= stream_end; ret } @@ -332,36 +419,22 @@ mod zlib { #[pymethod] fn flush(&self, length: OptionalArg, vm: &VirtualMachine) -> PyResult> { let length = match length { - OptionalArg::Present(l) => { - let l: isize = l.into(); - if l <= 0 { - return Err( - vm.new_value_error("length must be greater than zero".to_owned()) - ); - } else { - l as usize - } + OptionalArg::Present(ArgSize { value }) if value <= 0 => { + return Err(vm.new_value_error("length must be greater than zero".to_owned())) } + OptionalArg::Present(ArgSize { value }) => value as usize, OptionalArg::Missing => DEF_BUF_SIZE, }; - let mut data = self.unconsumed_tail.lock(); - let mut d = self.decompress.lock(); + let inner = &mut *self.inner.lock(); + let data = std::mem::replace(&mut inner.unconsumed_tail, vm.ctx.empty_bytes.clone()); - let orig_in = d.total_in(); - - let (ret, stream_end) = match _decompress(&data, &mut d, length, None, true, vm) { - Ok((buf, stream_end)) => (Ok(buf), stream_end), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, &data, stream_end, orig_in, vm); + let (ret, _) = Self::decompress_inner(inner, &data, length, None, true, vm)?; - *data = PyBytes::from(Vec::new()).into_ref(&vm.ctx); + if inner.eof { + inner.decompress = None; + } - // TODO: drop the inner decompressor, somehow - // if stream_end { - // - // } ret } } @@ -370,11 +443,8 @@ mod zlib { struct DecompressArgs { #[pyarg(positional)] data: ArgBytesLike, - #[pyarg( - any, - default = "rustpython_vm::function::ArgPrimitiveIndex { value: 0 }" - )] - max_length: ArgPrimitiveIndex, + #[pyarg(any, optional)] + max_length: OptionalArg, } #[derive(FromArgs)] @@ -384,13 +454,13 @@ mod zlib { level: Level, // only DEFLATED is valid right now, it's w/e #[pyarg(any, default = "DEFLATED")] - _method: i32, + method: i32, #[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")] wbits: ArgPrimitiveIndex, - #[pyarg(any, name = "_memLevel", default = "DEF_MEM_LEVEL")] - _mem_level: u8, + #[pyarg(any, name = "memLevel", default = "DEF_MEM_LEVEL")] + mem_level: u8, #[pyarg(any, default = "Z_DEFAULT_STRATEGY")] - _strategy: i32, + strategy: i32, #[pyarg(any, optional)] zdict: Option, } @@ -417,8 +487,7 @@ mod zlib { #[derive(Debug)] struct CompressInner { - compress: Compress, - unconsumed: Vec, + compress: Option, } #[pyattr] @@ -436,10 +505,17 @@ mod zlib { data.with_ref(|b| inner.compress(b, vm)) } - // TODO: mode argument isn't used #[pymethod] - fn flush(&self, _mode: OptionalArg, vm: &VirtualMachine) -> PyResult> { - self.inner.lock().flush(vm) + fn flush(&self, mode: OptionalArg, vm: &VirtualMachine) -> PyResult> { + let mode = match mode.unwrap_or(Z_FINISH) { + Z_NO_FLUSH => return Ok(vec![]), + Z_PARTIAL_FLUSH => FlushCompress::Partial, + Z_SYNC_FLUSH => FlushCompress::Sync, + Z_FULL_FLUSH => FlushCompress::Full, + Z_FINISH => FlushCompress::Finish, + _ => return Err(new_zlib_error("invalid mode", vm)), + }; + self.inner.lock().flush(mode, vm) } // TODO: This is an optional feature of Compress @@ -456,59 +532,55 @@ mod zlib { impl CompressInner { fn new(compress: Compress) -> Self { Self { - compress, - unconsumed: Vec::new(), + compress: Some(compress), } } + + fn get_compress(&mut self, vm: &VirtualMachine) -> PyResult<&mut Compress> { + self.compress + .as_mut() + .ok_or_else(|| new_zlib_error(USE_AFTER_FINISH_ERR, vm)) + } + fn compress(&mut self, data: &[u8], vm: &VirtualMachine) -> PyResult> { - let orig_in = self.compress.total_in() as usize; - let mut cur_in = 0; - let unconsumed = std::mem::take(&mut self.unconsumed); + let c = self.get_compress(vm)?; let mut buf = Vec::new(); - 'outer: for chunk in unconsumed.chunks(CHUNKSIZE).chain(data.chunks(CHUNKSIZE)) { - while cur_in < chunk.len() { + for mut chunk in data.chunks(CHUNKSIZE) { + while !chunk.is_empty() { buf.reserve(DEF_BUF_SIZE); - let status = self - .compress - .compress_vec(&chunk[cur_in..], &mut buf, FlushCompress::None) - .map_err(|_| { - self.unconsumed.extend_from_slice(&data[cur_in..]); - new_zlib_error("error while compressing", vm) - })?; - cur_in = (self.compress.total_in() as usize) - orig_in; - match status { - Status::Ok => continue, - Status::StreamEnd => break 'outer, - _ => break, - } + let prev_in = c.total_in(); + c.compress_vec(chunk, &mut buf, FlushCompress::None) + .map_err(|_| new_zlib_error("error while compressing", vm))?; + let consumed = c.total_in() - prev_in; + chunk = &chunk[consumed as usize..]; } } - self.unconsumed.extend_from_slice(&data[cur_in..]); buf.shrink_to_fit(); Ok(buf) } - // TODO: flush mode (FlushDecompress) parameter - fn flush(&mut self, vm: &VirtualMachine) -> PyResult> { - let data = std::mem::take(&mut self.unconsumed); - let mut data_it = data.chunks(CHUNKSIZE); + fn flush(&mut self, mode: FlushCompress, vm: &VirtualMachine) -> PyResult> { + let c = self.get_compress(vm)?; let mut buf = Vec::new(); - loop { - let chunk = data_it.next().unwrap_or(&[]); + let status = loop { if buf.len() == buf.capacity() { buf.reserve(DEF_BUF_SIZE); } - let status = self - .compress - .compress_vec(chunk, &mut buf, FlushCompress::Finish) + let status = c + .compress_vec(&[], &mut buf, mode) .map_err(|_| new_zlib_error("error while compressing", vm))?; - match status { - Status::StreamEnd => break, - _ => continue, + if buf.len() != buf.capacity() { + break status; } + }; + + match status { + Status::Ok | Status::BufError => {} + Status::StreamEnd if mode == FlushCompress::Finish => self.compress = None, + Status::StreamEnd => return Err(new_zlib_error("unexpected eof", vm)), } buf.shrink_to_fit(); @@ -520,6 +592,8 @@ mod zlib { vm.new_exception_msg(vm.class("zlib", "error"), message.to_owned()) } + const USE_AFTER_FINISH_ERR: &str = "Error -2: inconsistent stream state"; + struct Level(Option); impl Level { @@ -551,130 +625,116 @@ mod zlib { #[pyattr] #[pyclass(name = "_ZlibDecompressor")] #[derive(Debug, PyPayload)] - pub struct ZlibDecompressor { - decompress: PyMutex, - unused_data: PyMutex, - unconsumed_tail: PyMutex, + struct ZlibDecompressor { + inner: PyMutex, + } + + #[derive(Debug)] + struct ZlibDecompressorInner { + decompress: Decompress, + unused_data: PyBytesRef, + input_buffer: Vec, + zdict: Option, + eof: bool, + needs_input: bool, } impl Constructor for ZlibDecompressor { - type Args = (); - - fn py_new(cls: PyTypeRef, _args: Self::Args, vm: &VirtualMachine) -> PyResult { - let decompress = Decompress::new(true); - let zlib_decompressor = Self { - decompress: PyMutex::new(decompress), - unused_data: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), - unconsumed_tail: PyMutex::new(PyBytes::from(vec![]).into_ref(&vm.ctx)), + type Args = DecompressobjArgs; + + fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { + let mut decompress = InitOptions::new(args.wbits.value, vm)?.decompress(); + let zdict = args.zdict.into_option(); + if let Some(dict) = &zdict { + if args.wbits.value < 0 { + dict.with_ref(|d| decompress.set_dictionary(d)) + .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + } + } + let inner = ZlibDecompressorInner { + decompress, + unused_data: vm.ctx.empty_bytes.clone(), + input_buffer: Vec::new(), + zdict, + eof: false, + needs_input: true, }; - zlib_decompressor - .into_ref_with_type(vm, cls) - .map(Into::into) + Self { + inner: PyMutex::new(inner), + } + .into_ref_with_type(vm, cls) + .map(Into::into) } } #[pyclass(with(Constructor))] impl ZlibDecompressor { #[pygetset] - fn unused_data(&self) -> PyBytesRef { - self.unused_data.lock().clone() + fn eof(&self) -> bool { + self.inner.lock().eof } #[pygetset] - fn unconsumed_tail(&self) -> PyBytesRef { - self.unconsumed_tail.lock().clone() + fn unused_data(&self) -> PyBytesRef { + self.inner.lock().unused_data.clone() } - fn save_unused_input( - &self, - d: &Decompress, - data: &[u8], - stream_end: bool, - orig_in: u64, - vm: &VirtualMachine, - ) { - let leftover = &data[(d.total_in() - orig_in) as usize..]; - - if stream_end && !leftover.is_empty() { - let mut unused_data = self.unused_data.lock(); - let unused: Vec<_> = unused_data - .as_bytes() - .iter() - .chain(leftover) - .copied() - .collect(); - *unused_data = vm.ctx.new_pyref(unused); - } + #[pygetset] + fn needs_input(&self) -> bool { + self.inner.lock().needs_input } #[pymethod] - fn decompress(&self, args: PyBytesRef, vm: &VirtualMachine) -> PyResult> { - // let max_length = args.max_length.value; - // let max_length = (max_length != 0).then_some(max_length); - let max_length = None; - let data = args.as_bytes(); - - let mut d = self.decompress.lock(); - let orig_in = d.total_in(); - - let (ret, stream_end) = - match _decompress(data, &mut d, DEF_BUF_SIZE, max_length, false, vm) { - Ok((buf, true)) => { - // Eof is true - (Ok(buf), true) - } - Ok((buf, false)) => (Ok(buf), false), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, data, stream_end, orig_in, vm); + fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { + let max_length = args + .max_length + .into_option() + .and_then(|ArgSize { value }| usize::try_from(value).ok()); + let data = &*args.data.borrow_buf(); - let leftover = if stream_end { - b"" - } else { - &data[(d.total_in() - orig_in) as usize..] - }; + let inner = &mut *self.inner.lock(); - let mut unconsumed_tail = self.unconsumed_tail.lock(); - if !leftover.is_empty() || !unconsumed_tail.is_empty() { - *unconsumed_tail = PyBytes::from(leftover.to_owned()).into_ref(&vm.ctx); + if inner.eof { + return Err(vm.new_eof_error("End of stream already reached".to_owned())); } - ret - } + let input_buffer = &mut inner.input_buffer; + let d = &mut inner.decompress; - #[pymethod] - fn flush(&self, length: OptionalArg, vm: &VirtualMachine) -> PyResult> { - let length = match length { - OptionalArg::Present(l) => { - let l: isize = l.into(); - if l <= 0 { - return Err( - vm.new_value_error("length must be greater than zero".to_owned()) - ); - } else { - l as usize - } - } - OptionalArg::Missing => DEF_BUF_SIZE, - }; + let mut chunks = Chunker::chain(input_buffer, data); - let mut data = self.unconsumed_tail.lock(); - let mut d = self.decompress.lock(); + let zdict = inner.zdict.as_ref(); + let bufsize = DEF_BUF_SIZE; - let orig_in = d.total_in(); + let prev_len = chunks.len(); + let (ret, stream_end) = + match _decompress_chunks(&mut chunks, d, bufsize, max_length, false, zdict, vm) { + Ok((buf, stream_end)) => (Ok(buf), stream_end), + Err(err) => (Err(err), false), + }; + let consumed = prev_len - chunks.len(); - let (ret, stream_end) = match _decompress(&data, &mut d, length, None, true, vm) { - Ok((buf, stream_end)) => (Ok(buf), stream_end), - Err(err) => (Err(err), false), - }; - self.save_unused_input(&d, &data, stream_end, orig_in, vm); + inner.eof |= stream_end; - *data = PyBytes::from(Vec::new()).into_ref(&vm.ctx); + if inner.eof { + inner.needs_input = false; + if !chunks.is_empty() { + inner.unused_data = vm.ctx.new_bytes(chunks.to_vec()); + } + } else if chunks.is_empty() { + input_buffer.clear(); + inner.needs_input = true; + } else { + inner.needs_input = false; + if let Some(n_consumed_from_data) = consumed.checked_sub(input_buffer.len()) { + input_buffer.clear(); + input_buffer.extend_from_slice(&data[n_consumed_from_data..]); + } else { + input_buffer.drain(..consumed); + input_buffer.extend_from_slice(data); + } + } - // TODO: drop the inner decompressor, somehow - // if stream_end { - // - // } ret }