diff --git a/Cargo.lock b/Cargo.lock index f96b527dfe..aef688eb21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3278,7 +3278,6 @@ dependencies = [ "widestring", "windows", "windows-sys 0.59.0", - "winreg", ] [[package]] @@ -4674,16 +4673,6 @@ version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" -[[package]] -name = "winreg" -version = "0.55.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb5a765337c50e9ec252c2069be9bf91c7df47afb103b642ba3a53bf8101be97" -dependencies = [ - "cfg-if", - "windows-sys 0.59.0", -] - [[package]] name = "winresource" version = "0.1.27" diff --git a/Lib/test/test_winreg.py b/Lib/test/test_winreg.py new file mode 100644 index 0000000000..924a962781 --- /dev/null +++ b/Lib/test/test_winreg.py @@ -0,0 +1,525 @@ +# Test the windows specific win32reg module. +# Only win32reg functions not hit here: FlushKey, LoadKey and SaveKey + +import gc +import os, sys, errno +import threading +import unittest +from platform import machine, win32_edition +from test.support import cpython_only, import_helper + +# Do this first so test will be skipped if module doesn't exist +import_helper.import_module('winreg', required_on=['win']) +# Now import everything +from winreg import * + +try: + REMOTE_NAME = sys.argv[sys.argv.index("--remote")+1] +except (IndexError, ValueError): + REMOTE_NAME = None + +# tuple of (major, minor) +WIN_VER = sys.getwindowsversion()[:2] +# Some tests should only run on 64-bit architectures where WOW64 will be. +WIN64_MACHINE = True if machine() == "AMD64" else False + +# Starting with Windows 7 and Windows Server 2008 R2, WOW64 no longer uses +# registry reflection and formerly reflected keys are shared instead. +# Windows 7 and Windows Server 2008 R2 are version 6.1. Due to this, some +# tests are only valid up until 6.1 +HAS_REFLECTION = True if WIN_VER < (6, 1) else False + +# Use a per-process key to prevent concurrent test runs (buildbot!) from +# stomping on each other. +test_key_base = "Python Test Key [%d] - Delete Me" % (os.getpid(),) +test_key_name = "SOFTWARE\\" + test_key_base +# On OS'es that support reflection we should test with a reflected key +test_reflect_key_name = "SOFTWARE\\Classes\\" + test_key_base + +test_data = [ + ("Int Value", 45, REG_DWORD), + ("Qword Value", 0x1122334455667788, REG_QWORD), + ("String Val", "A string value", REG_SZ), + ("StringExpand", "The path is %path%", REG_EXPAND_SZ), + ("Multi-string", ["Lots", "of", "string", "values"], REG_MULTI_SZ), + ("Multi-nul", ["", "", "", ""], REG_MULTI_SZ), + ("Raw Data", b"binary\x00data", REG_BINARY), + ("Big String", "x"*(2**14-1), REG_SZ), + ("Big Binary", b"x"*(2**14), REG_BINARY), + # Two and three kanjis, meaning: "Japan" and "Japanese". + ("Japanese 日本", "日本語", REG_SZ), +] + + +@cpython_only +class HeapTypeTests(unittest.TestCase): + def test_have_gc(self): + self.assertTrue(gc.is_tracked(HKEYType)) + + def test_immutable(self): + with self.assertRaisesRegex(TypeError, "immutable"): + HKEYType.foo = "bar" + + +class BaseWinregTests(unittest.TestCase): + + def setUp(self): + # Make sure that the test key is absent when the test + # starts. + self.delete_tree(HKEY_CURRENT_USER, test_key_name) + + def delete_tree(self, root, subkey): + try: + hkey = OpenKey(root, subkey, 0, KEY_ALL_ACCESS) + except OSError: + # subkey does not exist + return + while True: + try: + subsubkey = EnumKey(hkey, 0) + except OSError: + # no more subkeys + break + self.delete_tree(hkey, subsubkey) + CloseKey(hkey) + DeleteKey(root, subkey) + + def _write_test_data(self, root_key, subkeystr="sub_key", + CreateKey=CreateKey): + # Set the default value for this key. + SetValue(root_key, test_key_name, REG_SZ, "Default value") + key = CreateKey(root_key, test_key_name) + self.assertTrue(key.handle != 0) + # Create a sub-key + sub_key = CreateKey(key, subkeystr) + # Give the sub-key some named values + + for value_name, value_data, value_type in test_data: + SetValueEx(sub_key, value_name, 0, value_type, value_data) + + # Check we wrote as many items as we thought. + nkeys, nvalues, since_mod = QueryInfoKey(key) + self.assertEqual(nkeys, 1, "Not the correct number of sub keys") + self.assertEqual(nvalues, 1, "Not the correct number of values") + nkeys, nvalues, since_mod = QueryInfoKey(sub_key) + self.assertEqual(nkeys, 0, "Not the correct number of sub keys") + self.assertEqual(nvalues, len(test_data), + "Not the correct number of values") + # Close this key this way... + # (but before we do, copy the key as an integer - this allows + # us to test that the key really gets closed). + int_sub_key = int(sub_key) + CloseKey(sub_key) + try: + QueryInfoKey(int_sub_key) + self.fail("It appears the CloseKey() function does " + "not close the actual key!") + except OSError: + pass + # ... and close that key that way :-) + int_key = int(key) + key.Close() + try: + QueryInfoKey(int_key) + self.fail("It appears the key.Close() function " + "does not close the actual key!") + except OSError: + pass + def _read_test_data(self, root_key, subkeystr="sub_key", OpenKey=OpenKey): + # Check we can get default value for this key. + val = QueryValue(root_key, test_key_name) + self.assertEqual(val, "Default value", + "Registry didn't give back the correct value") + + key = OpenKey(root_key, test_key_name) + # Read the sub-keys + with OpenKey(key, subkeystr) as sub_key: + # Check I can enumerate over the values. + index = 0 + while 1: + try: + data = EnumValue(sub_key, index) + except OSError: + break + self.assertEqual(data in test_data, True, + "Didn't read back the correct test data") + index = index + 1 + self.assertEqual(index, len(test_data), + "Didn't read the correct number of items") + # Check I can directly access each item + for value_name, value_data, value_type in test_data: + read_val, read_typ = QueryValueEx(sub_key, value_name) + self.assertEqual(read_val, value_data, + "Could not directly read the value") + self.assertEqual(read_typ, value_type, + "Could not directly read the value") + sub_key.Close() + # Enumerate our main key. + read_val = EnumKey(key, 0) + self.assertEqual(read_val, subkeystr, "Read subkey value wrong") + try: + EnumKey(key, 1) + self.fail("Was able to get a second key when I only have one!") + except OSError: + pass + + key.Close() + + def _delete_test_data(self, root_key, subkeystr="sub_key"): + key = OpenKey(root_key, test_key_name, 0, KEY_ALL_ACCESS) + sub_key = OpenKey(key, subkeystr, 0, KEY_ALL_ACCESS) + # It is not necessary to delete the values before deleting + # the key (although subkeys must not exist). We delete them + # manually just to prove we can :-) + for value_name, value_data, value_type in test_data: + DeleteValue(sub_key, value_name) + + nkeys, nvalues, since_mod = QueryInfoKey(sub_key) + self.assertEqual(nkeys, 0, "subkey not empty before delete") + self.assertEqual(nvalues, 0, "subkey not empty before delete") + sub_key.Close() + DeleteKey(key, subkeystr) + + try: + # Shouldn't be able to delete it twice! + DeleteKey(key, subkeystr) + self.fail("Deleting the key twice succeeded") + except OSError: + pass + key.Close() + DeleteKey(root_key, test_key_name) + # Opening should now fail! + try: + key = OpenKey(root_key, test_key_name) + self.fail("Could open the non-existent key") + except OSError: # Use this error name this time + pass + + def _test_all(self, root_key, subkeystr="sub_key"): + self._write_test_data(root_key, subkeystr) + self._read_test_data(root_key, subkeystr) + self._delete_test_data(root_key, subkeystr) + + def _test_named_args(self, key, sub_key): + with CreateKeyEx(key=key, sub_key=sub_key, reserved=0, + access=KEY_ALL_ACCESS) as ckey: + self.assertTrue(ckey.handle != 0) + + with OpenKeyEx(key=key, sub_key=sub_key, reserved=0, + access=KEY_ALL_ACCESS) as okey: + self.assertTrue(okey.handle != 0) + + +class LocalWinregTests(BaseWinregTests): + + def test_registry_works(self): + self._test_all(HKEY_CURRENT_USER) + self._test_all(HKEY_CURRENT_USER, "日本-subkey") + + def test_registry_works_extended_functions(self): + # Substitute the regular CreateKey and OpenKey calls with their + # extended counterparts. + # Note: DeleteKeyEx is not used here because it is platform dependent + cke = lambda key, sub_key: CreateKeyEx(key, sub_key, 0, KEY_ALL_ACCESS) + self._write_test_data(HKEY_CURRENT_USER, CreateKey=cke) + + oke = lambda key, sub_key: OpenKeyEx(key, sub_key, 0, KEY_READ) + self._read_test_data(HKEY_CURRENT_USER, OpenKey=oke) + + self._delete_test_data(HKEY_CURRENT_USER) + + def test_named_arguments(self): + self._test_named_args(HKEY_CURRENT_USER, test_key_name) + # Use the regular DeleteKey to clean up + # DeleteKeyEx takes named args and is tested separately + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_connect_registry_to_local_machine_works(self): + # perform minimal ConnectRegistry test which just invokes it + h = ConnectRegistry(None, HKEY_LOCAL_MACHINE) + self.assertNotEqual(h.handle, 0) + h.Close() + self.assertEqual(h.handle, 0) + + def test_nonexistent_remote_registry(self): + connect = lambda: ConnectRegistry("abcdefghijkl", HKEY_CURRENT_USER) + self.assertRaises(OSError, connect) + + def testExpandEnvironmentStrings(self): + r = ExpandEnvironmentStrings("%windir%\\test") + self.assertEqual(type(r), str) + self.assertEqual(r, os.environ["windir"] + "\\test") + + def test_context_manager(self): + # ensure that the handle is closed if an exception occurs + try: + with ConnectRegistry(None, HKEY_LOCAL_MACHINE) as h: + self.assertNotEqual(h.handle, 0) + raise OSError + except OSError: + self.assertEqual(h.handle, 0) + + def test_changing_value(self): + # Issue2810: A race condition in 2.6 and 3.1 may cause + # EnumValue or QueryValue to raise "WindowsError: More data is + # available" + done = False + + class VeryActiveThread(threading.Thread): + def run(self): + with CreateKey(HKEY_CURRENT_USER, test_key_name) as key: + use_short = True + long_string = 'x'*2000 + while not done: + s = 'x' if use_short else long_string + use_short = not use_short + SetValue(key, 'changing_value', REG_SZ, s) + + thread = VeryActiveThread() + thread.start() + try: + with CreateKey(HKEY_CURRENT_USER, + test_key_name+'\\changing_value') as key: + for _ in range(1000): + num_subkeys, num_values, t = QueryInfoKey(key) + for i in range(num_values): + name = EnumValue(key, i) + QueryValue(key, name[0]) + finally: + done = True + thread.join() + DeleteKey(HKEY_CURRENT_USER, test_key_name+'\\changing_value') + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_long_key(self): + # Issue2810, in 2.6 and 3.1 when the key name was exactly 256 + # characters, EnumKey raised "WindowsError: More data is + # available" + name = 'x'*256 + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as key: + SetValue(key, name, REG_SZ, 'x') + num_subkeys, num_values, t = QueryInfoKey(key) + EnumKey(key, 0) + finally: + DeleteKey(HKEY_CURRENT_USER, '\\'.join((test_key_name, name))) + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_dynamic_key(self): + # Issue2810, when the value is dynamically generated, these + # raise "WindowsError: More data is available" in 2.6 and 3.1 + try: + EnumValue(HKEY_PERFORMANCE_DATA, 0) + except OSError as e: + if e.errno in (errno.EPERM, errno.EACCES): + self.skipTest("access denied to registry key " + "(are you running in a non-interactive session?)") + raise + QueryValueEx(HKEY_PERFORMANCE_DATA, "") + + # Reflection requires XP x64/Vista at a minimum. XP doesn't have this stuff + # or DeleteKeyEx so make sure their use raises NotImplementedError + @unittest.skipUnless(WIN_VER < (5, 2), "Requires Windows XP") + def test_reflection_unsupported(self): + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + + key = OpenKey(HKEY_CURRENT_USER, test_key_name) + self.assertNotEqual(key.handle, 0) + + with self.assertRaises(NotImplementedError): + DisableReflectionKey(key) + with self.assertRaises(NotImplementedError): + EnableReflectionKey(key) + with self.assertRaises(NotImplementedError): + QueryReflectionKey(key) + with self.assertRaises(NotImplementedError): + DeleteKeyEx(HKEY_CURRENT_USER, test_key_name) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_setvalueex_value_range(self): + # Test for Issue #14420, accept proper ranges for SetValueEx. + # Py2Reg, which gets called by SetValueEx, was using PyLong_AsLong, + # thus raising OverflowError. The implementation now uses + # PyLong_AsUnsignedLong to match DWORD's size. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + SetValueEx(ck, "test_name", None, REG_DWORD, 0x80000000) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_setvalueex_negative_one_check(self): + # Test for Issue #43984, check -1 was not set by SetValueEx. + # Py2Reg, which gets called by SetValueEx, wasn't checking return + # value by PyLong_AsUnsignedLong, thus setting -1 as value in the registry. + # The implementation now checks PyLong_AsUnsignedLong return value to assure + # the value set was not -1. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + with self.assertRaises(OverflowError): + SetValueEx(ck, "test_name_dword", None, REG_DWORD, -1) + SetValueEx(ck, "test_name_qword", None, REG_QWORD, -1) + self.assertRaises(FileNotFoundError, QueryValueEx, ck, "test_name_dword") + self.assertRaises(FileNotFoundError, QueryValueEx, ck, "test_name_qword") + + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_queryvalueex_return_value(self): + # Test for Issue #16759, return unsigned int from QueryValueEx. + # Reg2Py, which gets called by QueryValueEx, was returning a value + # generated by PyLong_FromLong. The implementation now uses + # PyLong_FromUnsignedLong to match DWORD's size. + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = 0x80000000 + SetValueEx(ck, "test_name", None, REG_DWORD, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_DWORD) + self.assertEqual(ret_val, test_val) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_setvalueex_crash_with_none_arg(self): + # Test for Issue #21151, segfault when None is passed to SetValueEx + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = None + SetValueEx(ck, "test_name", 0, REG_BINARY, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_BINARY) + self.assertEqual(ret_val, test_val) + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + def test_read_string_containing_null(self): + # Test for issue 25778: REG_SZ should not contain null characters + try: + with CreateKey(HKEY_CURRENT_USER, test_key_name) as ck: + self.assertNotEqual(ck.handle, 0) + test_val = "A string\x00 with a null" + SetValueEx(ck, "test_name", 0, REG_SZ, test_val) + ret_val, ret_type = QueryValueEx(ck, "test_name") + self.assertEqual(ret_type, REG_SZ) + self.assertEqual(ret_val, "A string") + finally: + DeleteKey(HKEY_CURRENT_USER, test_key_name) + + +@unittest.skipUnless(REMOTE_NAME, "Skipping remote registry tests") +class RemoteWinregTests(BaseWinregTests): + + def test_remote_registry_works(self): + remote_key = ConnectRegistry(REMOTE_NAME, HKEY_CURRENT_USER) + self._test_all(remote_key) + + +@unittest.skipUnless(WIN64_MACHINE, "x64 specific registry tests") +class Win64WinregTests(BaseWinregTests): + + def test_named_arguments(self): + self._test_named_args(HKEY_CURRENT_USER, test_key_name) + # Clean up and also exercise the named arguments + DeleteKeyEx(key=HKEY_CURRENT_USER, sub_key=test_key_name, + access=KEY_ALL_ACCESS, reserved=0) + + @unittest.skipIf(win32_edition() in ('WindowsCoreHeadless', 'IoTEdgeOS'), "APIs not available on WindowsCoreHeadless") + def test_reflection_functions(self): + # Test that we can call the query, enable, and disable functions + # on a key which isn't on the reflection list with no consequences. + with OpenKey(HKEY_LOCAL_MACHINE, "Software") as key: + # HKLM\Software is redirected but not reflected in all OSes + self.assertTrue(QueryReflectionKey(key)) + self.assertIsNone(EnableReflectionKey(key)) + self.assertIsNone(DisableReflectionKey(key)) + self.assertTrue(QueryReflectionKey(key)) + + @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection") + def test_reflection(self): + # Test that we can create, open, and delete keys in the 32-bit + # area. Because we are doing this in a key which gets reflected, + # test the differences of 32 and 64-bit keys before and after the + # reflection occurs (ie. when the created key is closed). + try: + with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key: + self.assertNotEqual(created_key.handle, 0) + + # The key should now be available in the 32-bit area + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as key: + self.assertNotEqual(key.handle, 0) + + # Write a value to what currently is only in the 32-bit area + SetValueEx(created_key, "", 0, REG_SZ, "32KEY") + + # The key is not reflected until created_key is closed. + # The 64-bit version of the key should not be available yet. + open_fail = lambda: OpenKey(HKEY_CURRENT_USER, + test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_64KEY) + self.assertRaises(OSError, open_fail) + + # Now explicitly open the 64-bit version of the key + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_64KEY) as key: + self.assertNotEqual(key.handle, 0) + # Make sure the original value we set is there + self.assertEqual("32KEY", QueryValue(key, "")) + # Set a new value, which will get reflected to 32-bit + SetValueEx(key, "", 0, REG_SZ, "64KEY") + + # Reflection uses a "last-writer wins policy, so the value we set + # on the 64-bit key should be the same on 32-bit + with OpenKey(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_32KEY) as key: + self.assertEqual("64KEY", QueryValue(key, "")) + finally: + DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, + KEY_WOW64_32KEY, 0) + + @unittest.skipUnless(HAS_REFLECTION, "OS doesn't support reflection") + def test_disable_reflection(self): + # Make use of a key which gets redirected and reflected + try: + with CreateKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_ALL_ACCESS | KEY_WOW64_32KEY) as created_key: + # QueryReflectionKey returns whether or not the key is disabled + disabled = QueryReflectionKey(created_key) + self.assertEqual(type(disabled), bool) + # HKCU\Software\Classes is reflected by default + self.assertFalse(disabled) + + DisableReflectionKey(created_key) + self.assertTrue(QueryReflectionKey(created_key)) + + # The key is now closed and would normally be reflected to the + # 64-bit area, but let's make sure that didn't happen. + open_fail = lambda: OpenKeyEx(HKEY_CURRENT_USER, + test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_64KEY) + self.assertRaises(OSError, open_fail) + + # Make sure the 32-bit key is actually there + with OpenKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, 0, + KEY_READ | KEY_WOW64_32KEY) as key: + self.assertNotEqual(key.handle, 0) + finally: + DeleteKeyEx(HKEY_CURRENT_USER, test_reflect_key_name, + KEY_WOW64_32KEY, 0) + + def test_exception_numbers(self): + with self.assertRaises(FileNotFoundError) as ctx: + QueryValue(HKEY_CLASSES_ROOT, 'some_value_that_does_not_exist') + + +if __name__ == "__main__": + if not REMOTE_NAME: + print("Remote registry calls can be tested using", + "'test_winreg.py --remote \\\\machine_name'") + unittest.main() diff --git a/crates/vm/Cargo.toml b/crates/vm/Cargo.toml index 1e12454a6b..45c32239c9 100644 --- a/crates/vm/Cargo.toml +++ b/crates/vm/Cargo.toml @@ -111,7 +111,6 @@ num_cpus = "1.17.0" [target.'cfg(windows)'.dependencies] junction = { workspace = true } -winreg = "0.55" [target.'cfg(windows)'.dependencies.windows] version = "0.52.0" diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index d64a05ea40..396f60a2bc 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -4,39 +4,73 @@ use crate::{PyRef, VirtualMachine, builtins::PyModule}; pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { - let module = winreg::make_module(vm); + winreg::make_module(vm) +} - macro_rules! add_constants { - ($($name:ident),*$(,)?) => { - extend_module!(vm, &module, { - $((stringify!($name)) => vm.new_pyobj(::winreg::enums::$name as usize)),* - }) +#[pymodule] +mod winreg { + use crate::builtins::{PyInt, PyTuple, PyTypeRef}; + use crate::common::hash::PyHash; + use crate::common::windows::ToWideString; + use crate::convert::TryFromObject; + use crate::function::FuncArgs; + use crate::object::AsObject; + use crate::protocol::PyNumberMethods; + use crate::types::{AsNumber, Hashable}; + use crate::{Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; + use crossbeam_utils::atomic::AtomicCell; + use malachite_bigint::Sign; + use num_traits::ToPrimitive; + use std::ptr; + use windows_sys::Win32::Foundation::{self, ERROR_MORE_DATA}; + use windows_sys::Win32::System::Registry; + + /// Atomic HKEY handle type for lock-free thread-safe access + type AtomicHKEY = AtomicCell; + + /// Convert byte slice to UTF-16 slice (zero-copy when aligned) + fn bytes_as_wide_slice(bytes: &[u8]) -> &[u16] { + // SAFETY: Windows Registry API returns properly aligned UTF-16 data. + // align_to handles any edge cases safely by returning empty prefix/suffix + // if alignment doesn't match. + let (prefix, u16_slice, suffix) = unsafe { bytes.align_to::() }; + debug_assert!( + prefix.is_empty() && suffix.is_empty(), + "Registry data should be u16-aligned" + ); + u16_slice + } + + // TODO: check if errno.rs can be reused here or not + fn os_error_from_windows_code( + vm: &VirtualMachine, + code: i32, + func_name: &str, + ) -> crate::PyRef { + use windows_sys::Win32::Foundation::{ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND}; + let msg = format!("[WinError {}] {}", code, func_name); + let exc_type = match code as u32 { + ERROR_FILE_NOT_FOUND => vm.ctx.exceptions.file_not_found_error.to_owned(), + ERROR_ACCESS_DENIED => vm.ctx.exceptions.permission_error.to_owned(), + _ => vm.ctx.exceptions.os_error.to_owned(), }; + vm.new_exception_msg(exc_type, msg) } - add_constants!( - HKEY_CLASSES_ROOT, - HKEY_CURRENT_USER, - HKEY_LOCAL_MACHINE, - HKEY_USERS, - HKEY_PERFORMANCE_DATA, - HKEY_CURRENT_CONFIG, - HKEY_DYN_DATA, - ); - module -} + /// Wrapper type for HKEY that can be created from PyHkey or int + struct HKEYArg(Registry::HKEY); -#[pymodule] -mod winreg { - use crate::common::lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}; - use crate::{ - PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine, builtins::PyStrRef, - convert::ToPyException, - }; - use ::winreg::{RegKey, RegValue, enums::RegType}; - use std::mem::ManuallyDrop; - use std::{ffi::OsStr, io}; - use windows_sys::Win32::Foundation; + impl TryFromObject for HKEYArg { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + // Try PyHkey first + if let Some(hkey_obj) = obj.downcast_ref::() { + return Ok(HKEYArg(hkey_obj.hkey.load())); + } + // Then try int + let handle = usize::try_from_object(vm, obj)?; + Ok(HKEYArg(handle as Registry::HKEY)) + } + } // access rights #[pyattr] @@ -57,288 +91,1103 @@ mod winreg { REG_RESOURCE_REQUIREMENTS_LIST, REG_SZ, REG_WHOLE_HIVE_VOLATILE, }; + // Additional constants not in windows-sys + #[pyattr] + const REG_REFRESH_HIVE: u32 = 0x00000002; + #[pyattr] + const REG_NO_LAZY_FLUSH: u32 = 0x00000004; + // REG_LEGAL_OPTION is a mask of all option flags + #[pyattr] + const REG_LEGAL_OPTION: u32 = Registry::REG_OPTION_RESERVED + | Registry::REG_OPTION_NON_VOLATILE + | Registry::REG_OPTION_VOLATILE + | Registry::REG_OPTION_CREATE_LINK + | Registry::REG_OPTION_BACKUP_RESTORE + | Registry::REG_OPTION_OPEN_LINK; + // REG_LEGAL_CHANGE_FILTER is a mask of all notify flags + #[pyattr] + const REG_LEGAL_CHANGE_FILTER: u32 = Registry::REG_NOTIFY_CHANGE_NAME + | Registry::REG_NOTIFY_CHANGE_ATTRIBUTES + | Registry::REG_NOTIFY_CHANGE_LAST_SET + | Registry::REG_NOTIFY_CHANGE_SECURITY; + + // error is an alias for OSError (for backwards compatibility) + #[pyattr] + fn error(vm: &VirtualMachine) -> PyTypeRef { + vm.ctx.exceptions.os_error.to_owned() + } + + #[pyattr(once)] + fn HKEY_CLASSES_ROOT(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_CLASSES_ROOT).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_CURRENT_USER(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_CURRENT_USER).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_LOCAL_MACHINE(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_LOCAL_MACHINE).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_USERS(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_USERS).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_PERFORMANCE_DATA(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_PERFORMANCE_DATA).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_CURRENT_CONFIG(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_CURRENT_CONFIG).into_ref(&vm.ctx) + } + + #[pyattr(once)] + fn HKEY_DYN_DATA(vm: &VirtualMachine) -> PyRef { + PyHkey::new(Registry::HKEY_DYN_DATA).into_ref(&vm.ctx) + } + #[pyattr] - #[pyclass(module = "winreg", name = "HKEYType")] + #[pyclass(name = "HKEYType")] #[derive(Debug, PyPayload)] struct PyHkey { - key: PyRwLock, + hkey: AtomicHKEY, } - type PyHkeyRef = PyRef; - // TODO: fix this + unsafe impl Send for PyHkey {} unsafe impl Sync for PyHkey {} impl PyHkey { - fn new(key: RegKey) -> Self { + fn new(hkey: Registry::HKEY) -> Self { Self { - key: PyRwLock::new(key), + hkey: AtomicHKEY::new(hkey), } } - fn key(&self) -> PyRwLockReadGuard<'_, RegKey> { - self.key.read() + fn unary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) + } + + fn binary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) } - fn key_mut(&self) -> PyRwLockWriteGuard<'_, RegKey> { - self.key.write() + fn ternary_fail(vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error(HKEY_ERR_MSG.to_owned())) } } - #[pyclass] + #[pyclass(with(AsNumber, Hashable))] impl PyHkey { + #[pygetset] + fn handle(&self) -> usize { + self.hkey.load() as usize + } + #[pymethod] - fn Close(&self) { - let null_key = RegKey::predef(0 as ::winreg::HKEY); - let key = std::mem::replace(&mut *self.key_mut(), null_key); - drop(key); + fn Close(&self, vm: &VirtualMachine) -> PyResult<()> { + // Atomically swap the handle with null and get the old value + let old_hkey = self.hkey.swap(std::ptr::null_mut()); + // Already closed - silently succeed + if old_hkey.is_null() { + return Ok(()); + } + let res = unsafe { Registry::RegCloseKey(old_hkey) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("RegCloseKey failed with error code: {res}"))) + } } + #[pymethod] - fn Detach(&self) -> usize { - let null_key = RegKey::predef(0 as ::winreg::HKEY); - let key = std::mem::replace(&mut *self.key_mut(), null_key); - let handle = key.raw_handle(); - std::mem::forget(key); - handle as usize + fn Detach(&self) -> PyResult { + // Atomically swap the handle with null and return the old value + let old_hkey = self.hkey.swap(std::ptr::null_mut()); + Ok(old_hkey as usize) } #[pymethod] fn __bool__(&self) -> bool { - !self.key().raw_handle().is_null() + !self.hkey.load().is_null() + } + + #[pymethod] + fn __enter__(zelf: PyRef, _vm: &VirtualMachine) -> PyResult> { + Ok(zelf) + } + + #[pymethod] + fn __exit__(zelf: PyRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + zelf.Close(vm) } + #[pymethod] - fn __enter__(zelf: PyRef) -> PyRef { - zelf + fn __int__(&self) -> usize { + self.hkey.load() as usize } + #[pymethod] - fn __exit__(&self, _cls: PyObjectRef, _exc: PyObjectRef, _tb: PyObjectRef) { - self.Close(); + fn __str__(&self) -> String { + format!("", self.hkey.load()) } } - enum Hkey { - PyHkey(PyHkeyRef), - Constant(::winreg::HKEY), + impl Drop for PyHkey { + fn drop(&mut self) { + let hkey = self.hkey.swap(std::ptr::null_mut()); + if !hkey.is_null() { + unsafe { Registry::RegCloseKey(hkey) }; + } + } } - impl TryFromObject for Hkey { - fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { - obj.downcast().map(Self::PyHkey).or_else(|o| { - usize::try_from_object(vm, o).map(|i| Self::Constant(i as ::winreg::HKEY)) - }) + + impl Hashable for PyHkey { + // CPython uses PyObject_GenericHash which hashes the object's address + fn hash(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + Ok(zelf.get_id() as PyHash) } } - impl Hkey { - fn with_key(&self, f: impl FnOnce(&RegKey) -> R) -> R { - match self { - Self::PyHkey(py) => f(&py.key()), - Self::Constant(hkey) => { - let k = ManuallyDrop::new(RegKey::predef(*hkey)); - f(&k) - } - } + + pub const HKEY_ERR_MSG: &str = "bad operand type"; + + impl AsNumber for PyHkey { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + subtract: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + multiply: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + remainder: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + divmod: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + power: Some(|_a, _b, _c, vm| PyHkey::ternary_fail(vm)), + negative: Some(|_a, vm| PyHkey::unary_fail(vm)), + positive: Some(|_a, vm| PyHkey::unary_fail(vm)), + absolute: Some(|_a, vm| PyHkey::unary_fail(vm)), + boolean: Some(|a, vm| { + if let Some(a) = a.downcast_ref::() { + Ok(a.__bool__()) + } else { + PyHkey::unary_fail(vm)?; + unreachable!() + } + }), + invert: Some(|_a, vm| PyHkey::unary_fail(vm)), + lshift: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + rshift: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + and: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + xor: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + or: Some(|_a, _b, vm| PyHkey::binary_fail(vm)), + int: Some(|a, vm| { + if let Some(a) = a.downcast_ref::() { + Ok(vm.new_pyobj(a.__int__())) + } else { + PyHkey::unary_fail(vm)?; + unreachable!() + } + }), + float: Some(|_a, vm| PyHkey::unary_fail(vm)), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER } - fn into_key(self) -> RegKey { - let k = match self { - Self::PyHkey(py) => py.key().raw_handle(), - Self::Constant(k) => k, + } + + #[pyfunction] + fn ConnectRegistry( + computer_name: Option, + key: PyRef, + vm: &VirtualMachine, + ) -> PyResult { + if let Some(computer_name) = computer_name { + let mut ret_key = std::ptr::null_mut(); + let wide_computer_name = computer_name.to_wide_with_nul(); + let res = unsafe { + Registry::RegConnectRegistryW( + wide_computer_name.as_ptr(), + key.hkey.load(), + &mut ret_key, + ) + }; + if res == 0 { + Ok(PyHkey::new(ret_key)) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } else { + let mut ret_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegConnectRegistryW(std::ptr::null_mut(), key.hkey.load(), &mut ret_key) }; - RegKey::predef(k) + if res == 0 { + Ok(PyHkey::new(ret_key)) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } } } - #[derive(FromArgs)] - struct OpenKeyArgs { - key: Hkey, - sub_key: Option, + #[pyfunction] + fn CreateKey(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult { + let wide_sub_key = sub_key.to_wide_with_nul(); + let mut out_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegCreateKeyW(key.hkey.load(), wide_sub_key.as_ptr(), &mut out_key) + }; + if res == 0 { + Ok(PyHkey::new(out_key)) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } + + #[derive(FromArgs, Debug)] + struct CreateKeyExArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, #[pyarg(any, default = 0)] - reserved: i32, - #[pyarg(any, default = ::winreg::enums::KEY_READ)] + reserved: u32, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WRITE)] access: u32, } - #[pyfunction(name = "OpenKeyEx")] #[pyfunction] - fn OpenKey(args: OpenKeyArgs, vm: &VirtualMachine) -> PyResult { - let OpenKeyArgs { - key, - sub_key, - reserved, - access, - } = args; + fn CreateKeyEx(args: CreateKeyExArgs, vm: &VirtualMachine) -> PyResult { + let wide_sub_key = args.sub_key.to_wide_with_nul(); + let mut res: Registry::HKEY = core::ptr::null_mut(); + let err = unsafe { + let key = args.key.hkey.load(); + Registry::RegCreateKeyExW( + key, + wide_sub_key.as_ptr(), + args.reserved, + std::ptr::null(), + Registry::REG_OPTION_NON_VOLATILE, + args.access, + std::ptr::null(), + &mut res, + std::ptr::null_mut(), + ) + }; + if err == 0 { + Ok(PyHkey { + #[allow(clippy::arc_with_non_send_sync)] + hkey: AtomicHKEY::new(res), + }) + } else { + Err(vm.new_os_error(format!("error code: {err}"))) + } + } + + #[pyfunction] + fn CloseKey(hkey: PyRef, vm: &VirtualMachine) -> PyResult<()> { + hkey.Close(vm) + } - if reserved != 0 { - // RegKey::open_subkey* doesn't have a reserved param, so this'll do - return Err(vm.new_value_error("reserved param must be 0")); + #[pyfunction] + fn DeleteKey(key: PyRef, sub_key: String, vm: &VirtualMachine) -> PyResult<()> { + let wide_sub_key = sub_key.to_wide_with_nul(); + let res = unsafe { Registry::RegDeleteKeyW(key.hkey.load(), wide_sub_key.as_ptr()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) } + } - let sub_key = sub_key.as_ref().map_or("", |s| s.as_str()); - let key = key - .with_key(|k| k.open_subkey_with_flags(sub_key, access)) - .map_err(|e| e.to_pyexception(vm))?; + #[pyfunction] + fn DeleteValue(key: PyRef, value: Option, vm: &VirtualMachine) -> PyResult<()> { + let wide_value = value.map(|v| v.to_wide_with_nul()); + let value_ptr = wide_value.as_ref().map_or(std::ptr::null(), |v| v.as_ptr()); + let res = unsafe { Registry::RegDeleteValueW(key.hkey.load(), value_ptr) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } - Ok(PyHkey::new(key)) + #[derive(FromArgs, Debug)] + struct DeleteKeyExArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_WOW64_64KEY)] + access: u32, + #[pyarg(any, default = 0)] + reserved: u32, } #[pyfunction] - fn QueryValue(key: Hkey, subkey: Option, vm: &VirtualMachine) -> PyResult { - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - key.with_key(|k| k.get_value(subkey)) - .map_err(|e| e.to_pyexception(vm)) + fn DeleteKeyEx(args: DeleteKeyExArgs, vm: &VirtualMachine) -> PyResult<()> { + let wide_sub_key = args.sub_key.to_wide_with_nul(); + let res = unsafe { + Registry::RegDeleteKeyExW( + args.key.hkey.load(), + wide_sub_key.as_ptr(), + args.access, + args.reserved, + ) + }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } } #[pyfunction] - fn QueryValueEx( - key: Hkey, - subkey: Option, - vm: &VirtualMachine, - ) -> PyResult<(PyObjectRef, usize)> { - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - let regval = key - .with_key(|k| k.get_raw_value(subkey)) - .map_err(|e| e.to_pyexception(vm))?; - #[allow(clippy::redundant_clone)] - let ty = regval.vtype.clone() as usize; - Ok((reg_to_py(regval, vm)?, ty)) + fn EnumKey(key: PyRef, index: i32, vm: &VirtualMachine) -> PyResult { + // The Windows docs claim that the max key name length is 255 + // characters, plus a terminating nul character. However, + // empirical testing demonstrates that it is possible to + // create a 256 character key that is missing the terminating + // nul. RegEnumKeyEx requires a 257 character buffer to + // retrieve such a key name. + let mut tmpbuf = [0u16; 257]; + let mut len = tmpbuf.len() as u32; + let res = unsafe { + Registry::RegEnumKeyExW( + key.hkey.load(), + index as u32, + tmpbuf.as_mut_ptr(), + &mut len, + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + ) + }; + if res != 0 { + return Err(vm.new_os_error(format!("error code: {res}"))); + } + String::from_utf16(&tmpbuf[..len as usize]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}"))) } #[pyfunction] - fn EnumKey(key: Hkey, index: u32, vm: &VirtualMachine) -> PyResult { - key.with_key(|k| k.enum_keys().nth(index as usize)) - .unwrap_or_else(|| { - Err(io::Error::from_raw_os_error( - Foundation::ERROR_NO_MORE_ITEMS as i32, - )) - }) - .map_err(|e| e.to_pyexception(vm)) + fn EnumValue(hkey: PyRef, index: u32, vm: &VirtualMachine) -> PyResult { + // Query registry for the required buffer sizes. + let mut ret_value_size: u32 = 0; + let mut ret_data_size: u32 = 0; + let hkey: Registry::HKEY = hkey.hkey.load(); + let rc = unsafe { + Registry::RegQueryInfoKeyW( + hkey, + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + ptr::null_mut(), + &mut ret_value_size as *mut u32, + &mut ret_data_size as *mut u32, + ptr::null_mut(), + ptr::null_mut(), + ) + }; + if rc != 0 { + return Err(vm.new_os_error(format!("RegQueryInfoKeyW failed with error code {rc}"))); + } + + // Include room for null terminators. + ret_value_size += 1; + ret_data_size += 1; + let mut buf_value_size = ret_value_size; + let mut buf_data_size = ret_data_size; + + // Allocate buffers. + let mut ret_value_buf: Vec = vec![0; ret_value_size as usize]; + let mut ret_data_buf: Vec = vec![0; ret_data_size as usize]; + + // Loop to enumerate the registry value. + loop { + let mut current_value_size = ret_value_size; + let mut current_data_size = ret_data_size; + let mut reg_type: u32 = 0; + let rc = unsafe { + Registry::RegEnumValueW( + hkey, + index, + ret_value_buf.as_mut_ptr(), + &mut current_value_size as *mut u32, + ptr::null_mut(), + &mut reg_type as *mut u32, + ret_data_buf.as_mut_ptr(), + &mut current_data_size as *mut u32, + ) + }; + if rc == ERROR_MORE_DATA { + // Double the buffer sizes. + buf_data_size *= 2; + buf_value_size *= 2; + ret_data_buf.resize(buf_data_size as usize, 0); + ret_value_buf.resize(buf_value_size as usize, 0); + // Reset sizes for next iteration. + ret_value_size = buf_value_size; + ret_data_size = buf_data_size; + continue; + } + if rc != 0 { + return Err(vm.new_os_error(format!("RegEnumValueW failed with error code {rc}"))); + } + + // Convert the registry value name from UTF‑16. + let name_len = ret_value_buf + .iter() + .position(|&c| c == 0) + .unwrap_or(ret_value_buf.len()); + let name = String::from_utf16(&ret_value_buf[..name_len]) + .map_err(|e| vm.new_value_error(format!("UTF16 conversion error: {e}")))?; + + // Slice the data buffer to the actual size returned. + let data_slice = &ret_data_buf[..current_data_size as usize]; + let py_data = reg_to_py(vm, data_slice, reg_type)?; + + // Return tuple (value_name, data, type) + return Ok(vm + .ctx + .new_tuple(vec![ + vm.ctx.new_str(name).into(), + py_data, + vm.ctx.new_int(reg_type).into(), + ]) + .into()); + } + } + + #[pyfunction] + fn FlushKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegFlushKey(key.hkey.load()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } } #[pyfunction] - fn EnumValue( - key: Hkey, - index: u32, + fn LoadKey( + key: PyRef, + sub_key: String, + file_name: String, vm: &VirtualMachine, - ) -> PyResult<(String, PyObjectRef, usize)> { - let (name, value) = key - .with_key(|k| k.enum_values().nth(index as usize)) - .unwrap_or_else(|| { - Err(io::Error::from_raw_os_error( - Foundation::ERROR_NO_MORE_ITEMS as i32, - )) + ) -> PyResult<()> { + let sub_key = sub_key.to_wide_with_nul(); + let file_name = file_name.to_wide_with_nul(); + let res = + unsafe { Registry::RegLoadKeyW(key.hkey.load(), sub_key.as_ptr(), file_name.as_ptr()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } + + #[derive(Debug, FromArgs)] + struct OpenKeyArgs { + #[pyarg(any)] + key: PyRef, + #[pyarg(any)] + sub_key: String, + #[pyarg(any, default = 0)] + reserved: u32, + #[pyarg(any, default = windows_sys::Win32::System::Registry::KEY_READ)] + access: u32, + } + + #[pyfunction] + #[pyfunction(name = "OpenKeyEx")] + fn OpenKey(args: OpenKeyArgs, vm: &VirtualMachine) -> PyResult { + let wide_sub_key = args.sub_key.to_wide_with_nul(); + let mut res: Registry::HKEY = std::ptr::null_mut(); + let err = unsafe { + let key = args.key.hkey.load(); + Registry::RegOpenKeyExW( + key, + wide_sub_key.as_ptr(), + args.reserved, + args.access, + &mut res, + ) + }; + if err == 0 { + Ok(PyHkey { + #[allow(clippy::arc_with_non_send_sync)] + hkey: AtomicHKEY::new(res), }) - .map_err(|e| e.to_pyexception(vm))?; - #[allow(clippy::redundant_clone)] - let ty = value.vtype.clone() as usize; - Ok((name, reg_to_py(value, vm)?, ty)) + } else { + Err(os_error_from_windows_code(vm, err as i32, "RegOpenKeyEx")) + } + } + + #[pyfunction] + fn QueryInfoKey(key: HKEYArg, vm: &VirtualMachine) -> PyResult> { + let key = key.0; + let mut lpcsubkeys: u32 = 0; + let mut lpcvalues: u32 = 0; + let mut lpftlastwritetime: Foundation::FILETIME = unsafe { std::mem::zeroed() }; + let err = unsafe { + Registry::RegQueryInfoKeyW( + key, + std::ptr::null_mut(), + std::ptr::null_mut(), + 0 as _, + &mut lpcsubkeys, + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut lpcvalues, + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut lpftlastwritetime, + ) + }; + + if err != 0 { + return Err(vm.new_os_error(format!("error code: {err}"))); + } + let l: u64 = (lpftlastwritetime.dwHighDateTime as u64) << 32 + | lpftlastwritetime.dwLowDateTime as u64; + let tup: Vec = vec![ + vm.ctx.new_int(lpcsubkeys).into(), + vm.ctx.new_int(lpcvalues).into(), + vm.ctx.new_int(l).into(), + ]; + Ok(vm.ctx.new_tuple(tup)) } #[pyfunction] - fn CloseKey(key: Hkey) { - match key { - Hkey::PyHkey(py) => py.Close(), - Hkey::Constant(hkey) => drop(RegKey::predef(hkey)), + fn QueryValue(key: HKEYArg, sub_key: Option, vm: &VirtualMachine) -> PyResult { + let hkey = key.0; + + if hkey == Registry::HKEY_PERFORMANCE_DATA { + return Err(os_error_from_windows_code( + vm, + Foundation::ERROR_INVALID_HANDLE as i32, + "RegQueryValue", + )); + } + + // Open subkey if provided and non-empty + let child_key = if let Some(ref sk) = sub_key { + if !sk.is_empty() { + let wide_sub_key = sk.to_wide_with_nul(); + let mut out_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegOpenKeyExW( + hkey, + wide_sub_key.as_ptr(), + 0, + Registry::KEY_QUERY_VALUE, + &mut out_key, + ) + }; + if res != 0 { + return Err(os_error_from_windows_code(vm, res as i32, "RegOpenKeyEx")); + } + Some(out_key) + } else { + None + } + } else { + None + }; + + let target_key = child_key.unwrap_or(hkey); + let mut buf_size: u32 = 256; + let mut buffer: Vec = vec![0; buf_size as usize]; + let mut reg_type: u32 = 0; + + // Loop to handle ERROR_MORE_DATA + let result = loop { + let mut size = buf_size; + let res = unsafe { + Registry::RegQueryValueExW( + target_key, + std::ptr::null(), // NULL value name for default value + std::ptr::null_mut(), + &mut reg_type, + buffer.as_mut_ptr(), + &mut size, + ) + }; + if res == ERROR_MORE_DATA { + buf_size *= 2; + buffer.resize(buf_size as usize, 0); + continue; + } + if res == Foundation::ERROR_FILE_NOT_FOUND { + // Return empty string if there's no default value + break Ok(String::new()); + } + if res != 0 { + break Err(os_error_from_windows_code( + vm, + res as i32, + "RegQueryValueEx", + )); + } + if reg_type != Registry::REG_SZ { + break Err(os_error_from_windows_code( + vm, + Foundation::ERROR_INVALID_DATA as i32, + "RegQueryValue", + )); + } + + // Convert UTF-16 to String + let u16_slice = bytes_as_wide_slice(&buffer[..size as usize]); + let len = u16_slice + .iter() + .position(|&c| c == 0) + .unwrap_or(u16_slice.len()); + break String::from_utf16(&u16_slice[..len]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}"))); + }; + + // Close child key if we opened one + if let Some(ck) = child_key { + unsafe { Registry::RegCloseKey(ck) }; } + + result } #[pyfunction] - fn CreateKey(key: Hkey, subkey: Option, vm: &VirtualMachine) -> PyResult { - let k = match subkey { - Some(subkey) => { - let (k, _disp) = key - .with_key(|k| k.create_subkey(subkey.as_str())) - .map_err(|e| e.to_pyexception(vm))?; - k + fn QueryValueEx(key: HKEYArg, name: String, vm: &VirtualMachine) -> PyResult> { + let hkey = key.0; + let wide_name = name.to_wide_with_nul(); + let mut buf_size: u32 = 0; + let res = unsafe { + Registry::RegQueryValueExW( + hkey, + wide_name.as_ptr(), + std::ptr::null_mut(), + std::ptr::null_mut(), + std::ptr::null_mut(), + &mut buf_size, + ) + }; + // Handle ERROR_MORE_DATA by using a default buffer size + if res == ERROR_MORE_DATA || buf_size == 0 { + buf_size = 256; + } else if res != 0 { + return Err(os_error_from_windows_code( + vm, + res as i32, + "RegQueryValueEx", + )); + } + + let mut ret_buf = vec![0u8; buf_size as usize]; + let mut typ = 0; + let mut ret_size: u32; + + // Loop to handle ERROR_MORE_DATA + loop { + ret_size = buf_size; + let res = unsafe { + Registry::RegQueryValueExW( + hkey, + wide_name.as_ptr(), + std::ptr::null_mut(), + &mut typ, + ret_buf.as_mut_ptr(), + &mut ret_size, + ) + }; + + if res != ERROR_MORE_DATA { + if res != 0 { + return Err(os_error_from_windows_code( + vm, + res as i32, + "RegQueryValueEx", + )); + } + break; } - None => key.into_key(), + + // Double buffer size and retry + buf_size *= 2; + ret_buf.resize(buf_size as usize, 0); + } + + // Only pass the bytes actually returned by the API + let obj = reg_to_py(vm, &ret_buf[..ret_size as usize], typ)?; + // Return tuple (value, type) + Ok(vm.ctx.new_tuple(vec![obj, vm.ctx.new_int(typ).into()])) + } + + #[pyfunction] + fn SaveKey(key: PyRef, file_name: String, vm: &VirtualMachine) -> PyResult<()> { + let file_name = file_name.to_wide_with_nul(); + let res = unsafe { + Registry::RegSaveKeyW(key.hkey.load(), file_name.as_ptr(), std::ptr::null_mut()) }; - Ok(PyHkey::new(k)) + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } } #[pyfunction] fn SetValue( - key: Hkey, - subkey: Option, + key: PyRef, + sub_key: String, typ: u32, - value: PyStrRef, + value: String, vm: &VirtualMachine, ) -> PyResult<()> { - if typ != REG_SZ { + if typ != Registry::REG_SZ { return Err(vm.new_type_error("type must be winreg.REG_SZ")); } - let subkey = subkey.as_ref().map_or("", |s| s.as_str()); - key.with_key(|k| k.set_value(subkey, &OsStr::new(value.as_str()))) - .map_err(|e| e.to_pyexception(vm)) - } - #[pyfunction] - fn DeleteKey(key: Hkey, subkey: PyStrRef, vm: &VirtualMachine) -> PyResult<()> { - key.with_key(|k| k.delete_subkey(subkey.as_str())) - .map_err(|e| e.to_pyexception(vm)) + let hkey = key.hkey.load(); + if hkey == Registry::HKEY_PERFORMANCE_DATA { + return Err(os_error_from_windows_code( + vm, + Foundation::ERROR_INVALID_HANDLE as i32, + "RegSetValue", + )); + } + + // Create subkey if sub_key is non-empty + let child_key = if !sub_key.is_empty() { + let wide_sub_key = sub_key.to_wide_with_nul(); + let mut out_key = std::ptr::null_mut(); + let res = unsafe { + Registry::RegCreateKeyExW( + hkey, + wide_sub_key.as_ptr(), + 0, + std::ptr::null(), + 0, + Registry::KEY_SET_VALUE, + std::ptr::null(), + &mut out_key, + std::ptr::null_mut(), + ) + }; + if res != 0 { + return Err(os_error_from_windows_code(vm, res as i32, "RegCreateKeyEx")); + } + Some(out_key) + } else { + None + }; + + let target_key = child_key.unwrap_or(hkey); + // Convert value to UTF-16 for Wide API + let wide_value = value.to_wide_with_nul(); + let res = unsafe { + Registry::RegSetValueExW( + target_key, + std::ptr::null(), // value name is NULL + 0, + typ, + wide_value.as_ptr() as *const u8, + (wide_value.len() * 2) as u32, // byte count + ) + }; + + // Close child key if we created one + if let Some(ck) = child_key { + unsafe { Registry::RegCloseKey(ck) }; + } + + if res == 0 { + Ok(()) + } else { + Err(os_error_from_windows_code(vm, res as i32, "RegSetValueEx")) + } } - fn reg_to_py(value: RegValue, vm: &VirtualMachine) -> PyResult { - macro_rules! bytes_to_int { - ($int:ident, $f:ident, $name:ident) => {{ - let i = if value.bytes.is_empty() { - Ok(0 as $int) + fn reg_to_py(vm: &VirtualMachine, ret_data: &[u8], typ: u32) -> PyResult { + match typ { + REG_DWORD => { + // If there isn’t enough data, return 0. + if ret_data.len() < std::mem::size_of::() { + Ok(vm.ctx.new_int(0).into()) } else { - (&*value.bytes).try_into().map($int::$f).map_err(|_| { - vm.new_value_error(format!("{} value is wrong length", stringify!(name))) - }) - }; - i.map(|i| vm.ctx.new_int(i).into()) - }}; - } - let bytes_to_wide = |b| { - if <[u8]>::len(b) % 2 == 0 { - let (pref, wide, suf) = unsafe { <[u8]>::align_to::(b) }; - assert!(pref.is_empty() && suf.is_empty(), "wide slice is unaligned"); - Some(wide) - } else { - None + let val = u32::from_ne_bytes(ret_data[..4].try_into().unwrap()); + Ok(vm.ctx.new_int(val).into()) + } } - }; - match value.vtype { - RegType::REG_DWORD => bytes_to_int!(u32, from_ne_bytes, REG_DWORD), - RegType::REG_DWORD_BIG_ENDIAN => { - bytes_to_int!(u32, from_be_bytes, REG_DWORD_BIG_ENDIAN) + REG_QWORD => { + if ret_data.len() < std::mem::size_of::() { + Ok(vm.ctx.new_int(0).into()) + } else { + let val = u64::from_ne_bytes(ret_data[..8].try_into().unwrap()); + Ok(vm.ctx.new_int(val).into()) + } } - RegType::REG_QWORD => bytes_to_int!(u64, from_ne_bytes, REG_DWORD), - // RegType::REG_QWORD_BIG_ENDIAN => bytes_to_int!(u64, from_be_bytes, REG_DWORD_BIG_ENDIAN), - RegType::REG_SZ | RegType::REG_EXPAND_SZ => { - let wide_slice = bytes_to_wide(&value.bytes).ok_or_else(|| { - vm.new_value_error("REG_SZ string doesn't have an even byte length") - })?; - let nul_pos = wide_slice + REG_SZ | REG_EXPAND_SZ => { + let u16_slice = bytes_as_wide_slice(ret_data); + // Only use characters up to the first NUL. + let len = u16_slice .iter() - .position(|w| *w == 0) - .unwrap_or(wide_slice.len()); - let s = String::from_utf16_lossy(&wide_slice[..nul_pos]); + .position(|&c| c == 0) + .unwrap_or(u16_slice.len()); + let s = String::from_utf16(&u16_slice[..len]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}")))?; Ok(vm.ctx.new_str(s).into()) } - RegType::REG_MULTI_SZ => { - if value.bytes.is_empty() { - return Ok(vm.ctx.new_list(vec![]).into()); - } - let wide_slice = bytes_to_wide(&value.bytes).ok_or_else(|| { - vm.new_value_error("REG_MULTI_SZ string doesn't have an even byte length") - })?; - let wide_slice = if let Some((0, rest)) = wide_slice.split_last() { - rest + REG_MULTI_SZ => { + if ret_data.is_empty() { + Ok(vm.ctx.new_list(vec![]).into()) } else { - wide_slice - }; - let strings = wide_slice - .split(|c| *c == 0) - .map(|s| vm.new_pyobj(String::from_utf16_lossy(s))) - .collect(); - Ok(vm.ctx.new_list(strings).into()) + let u16_slice = bytes_as_wide_slice(ret_data); + let u16_count = u16_slice.len(); + + // Remove trailing null if present (like countStrings) + let len = if u16_count > 0 && u16_slice[u16_count - 1] == 0 { + u16_count - 1 + } else { + u16_count + }; + + let mut strings: Vec = Vec::new(); + let mut start = 0; + for i in 0..len { + if u16_slice[i] == 0 { + let s = String::from_utf16(&u16_slice[start..i]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}")))?; + strings.push(vm.ctx.new_str(s).into()); + start = i + 1; + } + } + // Handle last string if not null-terminated + if start < len { + let s = String::from_utf16(&u16_slice[start..len]) + .map_err(|e| vm.new_value_error(format!("UTF16 error: {e}")))?; + strings.push(vm.ctx.new_str(s).into()); + } + Ok(vm.ctx.new_list(strings).into()) + } } + // For REG_BINARY and any other unknown types, return a bytes object if data exists. _ => { - if value.bytes.is_empty() { + if ret_data.is_empty() { Ok(vm.ctx.none()) } else { - Ok(vm.ctx.new_bytes(value.bytes).into()) + Ok(vm.ctx.new_bytes(ret_data.to_vec()).into()) + } + } + } + } + + fn py2reg(value: PyObjectRef, typ: u32, vm: &VirtualMachine) -> PyResult>> { + match typ { + REG_DWORD => { + if vm.is_none(&value) { + return Ok(Some(0u32.to_le_bytes().to_vec())); + } + let val = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("value must be an integer".to_string()))?; + let bigint = val.as_bigint(); + // Check for negative value - raise OverflowError + if bigint.sign() == Sign::Minus { + return Err(vm.new_overflow_error("int too big to convert".to_string())); + } + let val = bigint + .to_u32() + .ok_or_else(|| vm.new_overflow_error("int too big to convert".to_string()))?; + Ok(Some(val.to_le_bytes().to_vec())) + } + REG_QWORD => { + if vm.is_none(&value) { + return Ok(Some(0u64.to_le_bytes().to_vec())); } + let val = value + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("value must be an integer".to_string()))?; + let bigint = val.as_bigint(); + // Check for negative value - raise OverflowError + if bigint.sign() == Sign::Minus { + return Err(vm.new_overflow_error("int too big to convert".to_string())); + } + let val = bigint + .to_u64() + .ok_or_else(|| vm.new_overflow_error("int too big to convert".to_string()))?; + Ok(Some(val.to_le_bytes().to_vec())) + } + REG_SZ | REG_EXPAND_SZ => { + if vm.is_none(&value) { + // Return empty string as UTF-16 null terminator + return Ok(Some(vec![0u8, 0u8])); + } + let s = value + .downcast::() + .map_err(|_| vm.new_type_error("value must be a string".to_string()))?; + let wide = s.as_str().to_wide_with_nul(); + // Convert Vec to Vec + let bytes: Vec = wide.iter().flat_map(|&c| c.to_le_bytes()).collect(); + Ok(Some(bytes)) + } + REG_MULTI_SZ => { + if vm.is_none(&value) { + // Empty list = double null terminator + return Ok(Some(vec![0u8, 0u8, 0u8, 0u8])); + } + let list = value.downcast::().map_err(|_| { + vm.new_type_error("value must be a list of strings".to_string()) + })?; + + let mut bytes: Vec = Vec::new(); + for item in list.borrow_vec().iter() { + let s = item + .downcast_ref::() + .ok_or_else(|| { + vm.new_type_error("list items must be strings".to_string()) + })?; + let wide = s.as_str().to_wide_with_nul(); + bytes.extend(wide.iter().flat_map(|&c| c.to_le_bytes())); + } + // Add final null terminator (double null at end) + bytes.extend([0u8, 0u8]); + Ok(Some(bytes)) + } + // REG_BINARY and other types + _ => { + if vm.is_none(&value) { + return Ok(None); + } + // Try to get bytes + if let Some(bytes) = value.downcast_ref::() { + return Ok(Some(bytes.as_bytes().to_vec())); + } + Err(vm.new_type_error(format!( + "Objects of type '{}' can not be used as binary registry values", + value.class().name() + ))) } } } + + #[pyfunction] + fn SetValueEx( + key: PyRef, + value_name: String, + _reserved: PyObjectRef, + typ: u32, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + match py2reg(value, typ, vm) { + Ok(Some(v)) => { + let len = v.len() as u32; + let ptr = v.as_ptr(); + let wide_value_name = value_name.to_wide_with_nul(); + let res = unsafe { + Registry::RegSetValueExW( + key.hkey.load(), + wide_value_name.as_ptr(), + 0, + typ, + ptr, + len, + ) + }; + if res != 0 { + return Err(vm.new_os_error(format!("error code: {res}"))); + } + } + Ok(None) => { + let len = 0; + let ptr = std::ptr::null(); + let wide_value_name = value_name.to_wide_with_nul(); + let res = unsafe { + Registry::RegSetValueExW( + key.hkey.load(), + wide_value_name.as_ptr(), + 0, + typ, + ptr, + len, + ) + }; + if res != 0 { + return Err(vm.new_os_error(format!("error code: {res}"))); + } + } + Err(e) => return Err(e), + } + Ok(()) + } + + #[pyfunction] + fn DisableReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegDisableReflectionKey(key.hkey.load()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } + + #[pyfunction] + fn EnableReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult<()> { + let res = unsafe { Registry::RegEnableReflectionKey(key.hkey.load()) }; + if res == 0 { + Ok(()) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } + + #[pyfunction] + fn QueryReflectionKey(key: PyRef, vm: &VirtualMachine) -> PyResult { + let mut result: windows_sys::Win32::Foundation::BOOL = 0; + let res = unsafe { Registry::RegQueryReflectionKey(key.hkey.load(), &mut result) }; + if res == 0 { + Ok(result != 0) + } else { + Err(vm.new_os_error(format!("error code: {res}"))) + } + } + + #[pyfunction] + fn ExpandEnvironmentStrings(i: String, vm: &VirtualMachine) -> PyResult { + let wide_input = i.to_wide_with_nul(); + + // First call with size=0 to get required buffer size + let required_size = unsafe { + windows_sys::Win32::System::Environment::ExpandEnvironmentStringsW( + wide_input.as_ptr(), + std::ptr::null_mut(), + 0, + ) + }; + if required_size == 0 { + return Err(vm.new_os_error("ExpandEnvironmentStringsW failed".to_string())); + } + + // Allocate buffer with exact size and expand + let mut out = vec![0u16; required_size as usize]; + let r = unsafe { + windows_sys::Win32::System::Environment::ExpandEnvironmentStringsW( + wide_input.as_ptr(), + out.as_mut_ptr(), + required_size, + ) + }; + if r == 0 { + return Err(vm.new_os_error("ExpandEnvironmentStringsW failed".to_string())); + } + + let len = out.iter().position(|&c| c == 0).unwrap_or(out.len()); + String::from_utf16(&out[..len]).map_err(|e| vm.new_value_error(format!("UTF16 error: {e}"))) + } }