diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index 6851462e52..9fdb1716ac 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -2207,18 +2207,24 @@ mod _sqlite { self.inner.lock().take(); } + fn ensure_connection_open(&self, vm: &VirtualMachine) -> PyResult<()> { + if self.connection.is_closed() { + Err(new_programming_error( + vm, + "Cannot operate on a closed database".to_owned(), + )) + } else { + Ok(()) + } + } + #[pymethod] fn read( &self, length: OptionalArg, vm: &VirtualMachine, ) -> PyResult> { - if self.connection.is_closed() { - return Err(new_programming_error( - vm, - "Cannot operate on a closed database".to_owned(), - )); - } + self.ensure_connection_open(vm)?; let mut length = length.unwrap_or(-1); let mut inner = self.inner(vm)?; @@ -2245,6 +2251,7 @@ mod _sqlite { #[pymethod] fn write(&self, data: PyBuffer, vm: &VirtualMachine) -> PyResult<()> { + self.ensure_connection_open(vm)?; let mut inner = self.inner(vm)?; let blob_len = inner.blob.bytes(); let length = Self::expect_write(blob_len, data.desc.len, inner.offset, vm)?; @@ -2260,6 +2267,7 @@ mod _sqlite { #[pymethod] fn tell(&self, vm: &VirtualMachine) -> PyResult { + self.ensure_connection_open(vm)?; self.inner(vm).map(|x| x.offset) } @@ -2270,6 +2278,7 @@ mod _sqlite { origin: OptionalArg, vm: &VirtualMachine, ) -> PyResult<()> { + self.ensure_connection_open(vm)?; let origin = origin.unwrap_or(libc::SEEK_SET); let mut inner = self.inner(vm)?; let blob_len = inner.blob.bytes(); @@ -2299,12 +2308,14 @@ mod _sqlite { #[pymethod] fn __enter__(zelf: PyRef, vm: &VirtualMachine) -> PyResult> { + zelf.ensure_connection_open(vm)?; let _ = zelf.inner(vm)?; Ok(zelf) } #[pymethod] fn __exit__(&self, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + self.ensure_connection_open(vm)?; let _ = self.inner(vm)?; self.close(); Ok(()) @@ -2351,6 +2362,7 @@ mod _sqlite { } fn subscript(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { + self.ensure_connection_open(vm)?; let inner = self.inner(vm)?; if let Some(index) = needle.try_index_opt(vm) { let blob_len = inner.blob.bytes(); @@ -2396,6 +2408,7 @@ mod _sqlite { let Some(value) = value else { return Err(vm.new_type_error("Blob doesn't support slice deletion")); }; + self.ensure_connection_open(vm)?; let inner = self.inner(vm)?; if let Some(index) = needle.try_index_opt(vm) {