diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index 8acb1468f6..8ccd6d6b64 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -17,6 +17,7 @@ cmpop denom DICTFLAG dictoffset +distpoint elts excepthandler fileutils diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8539db1a27..783e3dfa6b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -147,6 +147,11 @@ jobs: - name: check compilation without threading run: cargo check ${{ env.CARGO_ARGS }} + - name: Test openssl build + run: + cargo build --no-default-features --features ssl-openssl + if: runner.os == 'Linux' + - name: Test example projects run: cargo run --manifest-path example_projects/barebone/Cargo.toml diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index b6d5f5c203..4d420e7d53 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -64,7 +64,7 @@ mod _ssl { }, socket::{self, PySocket}, vm::{ - AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, builtins::{ PyBaseException, PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyWeak, @@ -73,8 +73,8 @@ mod _ssl { convert::ToPyException, exceptions, function::{ - ArgBytesLike, ArgCallable, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, - OptionalArg, PyComparisonValue, + ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg, + PyComparisonValue, }, types::{Comparable, Constructor, PyComparisonOp}, utils::ToCString, @@ -602,6 +602,39 @@ mod _ssl { } } + // Get or create an ex_data index for msg_callback data + fn get_msg_callback_ex_data_index() -> libc::c_int { + use std::sync::LazyLock; + static MSG_CB_EX_DATA_IDX: LazyLock = LazyLock::new(|| unsafe { + sys::SSL_get_ex_new_index( + 0, + std::ptr::null_mut(), + None, + None, + Some(msg_callback_data_free), + ) + }); + *MSG_CB_EX_DATA_IDX + } + + // Free function for msg_callback data - called by OpenSSL when SSL is freed + unsafe extern "C" fn msg_callback_data_free( + _parent: *mut libc::c_void, + ptr: *mut libc::c_void, + _ad: *mut sys::CRYPTO_EX_DATA, + _idx: libc::c_int, + _argl: libc::c_long, + _argp: *mut libc::c_void, + ) { + if !ptr.is_null() { + unsafe { + // Reconstruct PyObjectRef and drop to decrement reference count + let raw = std::ptr::NonNull::new_unchecked(ptr as *mut PyObject); + let _ = PyObjectRef::from_raw(raw); + } + } + } + // SNI callback function called by OpenSSL unsafe extern "C" fn _servername_callback( ssl_ptr: *mut sys::SSL, @@ -732,13 +765,14 @@ mod _ssl { } unsafe { - // Get SSL socket from SSL_get_ex_data (index 0) - let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, 0); + // Get SSL socket from ex_data using the dedicated index + let idx = get_msg_callback_ex_data_index(); + let ssl_socket_ptr = sys::SSL_get_ex_data(ssl_ptr, idx); if ssl_socket_ptr.is_null() { return; } - // ssl_socket_ptr is a pointer to Py, set in _wrap_socket/_wrap_bio + // ssl_socket_ptr is a pointer to Box>, set in _wrap_socket/_wrap_bio let ssl_socket: &Py = &*(ssl_socket_ptr as *const Py); // Get the callback from the context @@ -849,6 +883,9 @@ mod _ssl { post_handshake_auth: PyMutex, sni_callback: PyMutex>, msg_callback: PyMutex>, + psk_client_callback: PyMutex>, + psk_server_callback: PyMutex>, + psk_identity_hint: PyMutex>, } impl fmt::Debug for PySslContext { @@ -960,6 +997,9 @@ mod _ssl { post_handshake_auth: PyMutex::new(false), sni_callback: PyMutex::new(None), msg_callback: PyMutex::new(None), + psk_client_callback: PyMutex::new(None), + psk_server_callback: PyMutex::new(None), + psk_identity_hint: PyMutex::new(None), }) } } @@ -1083,9 +1123,26 @@ mod _ssl { self.ctx.read().options().bits() as _ } #[pygetset(setter)] - fn set_options(&self, opts: libc::c_ulong) { - self.builder() - .set_options(SslOptions::from_bits_truncate(opts as _)); + fn set_options(&self, new_opts: libc::c_ulong) { + let mut ctx = self.builder(); + // Get current options + let current = ctx.options().bits() as libc::c_ulong; + + // Calculate options to clear and set + let clear = current & !new_opts; + let set = !current & new_opts; + + // Clear options first (using raw FFI since openssl crate doesn't expose clear_options) + if clear != 0 { + unsafe { + sys::SSL_CTX_clear_options(ctx.as_ptr(), clear); + } + } + + // Then set new options + if set != 0 { + ctx.set_options(SslOptions::from_bits_truncate(set as _)); + } } #[pygetset] fn protocol(&self) -> i32 { @@ -1312,7 +1369,7 @@ mod _ssl { ssl::select_next_proto(&server, client).ok_or(ssl::AlpnError::NOACK)?; let pos = memchr::memmem::find(client, proto) .expect("selected alpn proto should be present in client protos"); - Ok(&client[pos..proto.len()]) + Ok(&client[pos..pos + proto.len()]) }); Ok(()) } @@ -1324,6 +1381,78 @@ mod _ssl { } } + #[pymethod] + fn set_psk_client_callback( + &self, + callback: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Cannot add PSK client callback to a server context + if self.protocol == SslVersion::TlsServer { + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot add PSK client callback to a PROTOCOL_TLS_SERVER context" + .to_owned(), + ) + .upcast()); + } + + if vm.is_none(&callback) { + *self.psk_client_callback.lock() = None; + unsafe { + sys::SSL_CTX_set_psk_client_callback(self.builder().as_ptr(), None); + } + } else { + if !callback.is_callable() { + return Err(vm.new_type_error("callback must be callable".to_owned())); + } + *self.psk_client_callback.lock() = Some(callback); + // Note: The actual callback will be invoked via SSL app_data mechanism + // when do_handshake is called + } + Ok(()) + } + + #[pymethod] + fn set_psk_server_callback( + &self, + callback: PyObjectRef, + identity_hint: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<()> { + // Cannot add PSK server callback to a client context + if self.protocol == SslVersion::TlsClient { + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + "Cannot add PSK server callback to a PROTOCOL_TLS_CLIENT context" + .to_owned(), + ) + .upcast()); + } + + if vm.is_none(&callback) { + *self.psk_server_callback.lock() = None; + *self.psk_identity_hint.lock() = None; + unsafe { + sys::SSL_CTX_set_psk_server_callback(self.builder().as_ptr(), None); + } + } else { + if !callback.is_callable() { + return Err(vm.new_type_error("callback must be callable".to_owned())); + } + *self.psk_server_callback.lock() = Some(callback); + if let OptionalArg::Present(hint) = identity_hint { + *self.psk_identity_hint.lock() = Some(hint.as_str().to_owned()); + } + // Note: The actual callback will be invoked via SSL app_data mechanism + } + Ok(()) + } + #[pymethod] fn load_verify_locations( &self, @@ -1343,16 +1472,33 @@ mod _ssl { // validate cadata type and load cadata if let Some(cadata) = args.cadata { - let certs = match cadata { + let (certs, is_pem) = match cadata { Either::A(s) => { if !s.as_str().is_ascii() { return Err(invalid_cadata(vm)); } - X509::stack_from_pem(s.as_bytes()) + (X509::stack_from_pem(s.as_bytes()), true) } - Either::B(b) => b.with_ref(x509_stack_from_der), + Either::B(b) => (b.with_ref(x509_stack_from_der), false), }; let certs = certs.map_err(|e| convert_openssl_error(vm, e))?; + + // If no certificates were loaded, raise an error + if certs.is_empty() { + let msg = if is_pem { + "no start line: cadata does not contain a certificate" + } else { + "not enough data: cadata does not contain a certificate" + }; + return Err(vm + .new_os_subtype_error( + PySSLError::class(&vm.ctx).to_owned(), + None, + msg.to_owned(), + ) + .upcast()); + } + let store = ctx.cert_store_mut(); for cert in certs { store @@ -1365,27 +1511,27 @@ mod _ssl { let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?; let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?; // Check file/directory existence before calling OpenSSL to get proper errno - if let Some(ref path) = cafile_path { - if !path.exists() { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", path.display()), - ) - .upcast()); - } + if let Some(ref path) = cafile_path + && !path.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); } - if let Some(ref path) = capath_path { - if !path.exists() { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", path.display()), - ) - .upcast()); - } + if let Some(ref path) = capath_path + && !path.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", path.display()), + ) + .upcast()); } ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref()) .map_err(|e| convert_openssl_error(vm, e))?; @@ -1450,7 +1596,8 @@ mod _ssl { X509_LU_X509 => { x509_count += 1; let x509_ptr = sys::X509_OBJECT_get0_X509(obj_ptr); - if !x509_ptr.is_null() && X509_check_ca(x509_ptr) == 1 { + // X509_check_ca returns non-zero for any CA type + if !x509_ptr.is_null() && X509_check_ca(x509_ptr) != 0 { ca_count += 1; } } @@ -1673,18 +1820,19 @@ mod _ssl { #[pymethod] fn load_cert_chain(&self, args: LoadCertChainArgs, vm: &VirtualMachine) -> PyResult<()> { + use openssl::pkey::PKey; + use std::cell::RefCell; + let LoadCertChainArgs { certfile, keyfile, password, } = args; - // TODO: requires passing a callback to C - if password.is_some() { - return Err(vm.new_not_implemented_error("password arg not yet supported")); - } + let mut ctx = self.builder(); let key_path = keyfile.map(|path| path.to_path_buf(vm)).transpose()?; let cert_path = certfile.to_path_buf(vm)?; + // Check file existence before calling OpenSSL to get proper errno if !cert_path.exists() { return Err(vm @@ -1695,28 +1843,139 @@ mod _ssl { ) .upcast()); } - if let Some(ref kp) = key_path { - if !kp.exists() { - return Err(vm - .new_os_subtype_error( - vm.ctx.exceptions.file_not_found_error.to_owned(), - Some(libc::ENOENT), - format!("No such file or directory: '{}'", kp.display()), - ) - .upcast()); - } + if let Some(ref kp) = key_path + && !kp.exists() + { + return Err(vm + .new_os_subtype_error( + vm.ctx.exceptions.file_not_found_error.to_owned(), + Some(libc::ENOENT), + format!("No such file or directory: '{}'", kp.display()), + ) + .upcast()); } + + // Load certificate chain ctx.set_certificate_chain_file(&cert_path) - .and_then(|()| { - ctx.set_private_key_file( - key_path.as_ref().unwrap_or(&cert_path), - ssl::SslFiletype::PEM, - ) - }) - .and_then(|()| ctx.check_private_key()) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Load private key - handle password if provided + let key_file_path = key_path.as_ref().unwrap_or(&cert_path); + + // PEM_BUFSIZE = 1024 (maximum password length in OpenSSL) + const PEM_BUFSIZE: usize = 1024; + + // Read key file data + let key_data = std::fs::read(key_file_path) + .map_err(|e| crate::vm::convert::ToPyException::to_pyexception(&e, vm))?; + + let pkey = if let Some(ref pw_obj) = password { + if pw_obj.is_callable() { + // Callable password - use callback that calls Python function + // Store any Python error that occurs in the callback + let py_error: RefCell> = RefCell::new(None); + + let result = PKey::private_key_from_pem_callback(&key_data, |buf| { + // Call the Python password callback + let pw_result = pw_obj.call((), vm); + match pw_result { + Ok(result) => { + // Extract password bytes + match Self::extract_password_bytes( + &result, + "password callback must return a string", + vm, + ) { + Ok(pw) => { + // Check password length + if pw.len() > PEM_BUFSIZE { + *py_error.borrow_mut() = + Some(vm.new_value_error(format!( + "password cannot be longer than {} bytes", + PEM_BUFSIZE + ))); + return Err(openssl::error::ErrorStack::get()); + } + let len = std::cmp::min(pw.len(), buf.len()); + buf[..len].copy_from_slice(&pw[..len]); + Ok(len) + } + Err(e) => { + *py_error.borrow_mut() = Some(e); + Err(openssl::error::ErrorStack::get()) + } + } + } + Err(e) => { + *py_error.borrow_mut() = Some(e); + Err(openssl::error::ErrorStack::get()) + } + } + }); + + // Check for Python error first + if let Some(py_err) = py_error.into_inner() { + return Err(py_err); + } + + result.map_err(|e| convert_openssl_error(vm, e))? + } else { + // Direct password (string/bytes) + let pw = Self::extract_password_bytes( + pw_obj, + "password should be a string or bytes", + vm, + )?; + + // Check password length + if pw.len() > PEM_BUFSIZE { + return Err(vm.new_value_error(format!( + "password cannot be longer than {} bytes", + PEM_BUFSIZE + ))); + } + + PKey::private_key_from_pem_passphrase(&key_data, &pw) + .map_err(|e| convert_openssl_error(vm, e))? + } + } else { + // No password - use SSL_CTX_use_PrivateKey_file directly for correct error messages + ctx.set_private_key_file(key_file_path, ssl::SslFiletype::PEM) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Verify key matches certificate and return early + return ctx + .check_private_key() + .map_err(|e| convert_openssl_error(vm, e)); + }; + + ctx.set_private_key(&pkey) + .map_err(|e| convert_openssl_error(vm, e))?; + + // Verify key matches certificate + ctx.check_private_key() .map_err(|e| convert_openssl_error(vm, e)) } + // Helper to extract password bytes from string/bytes/bytearray + fn extract_password_bytes( + obj: &PyObject, + bad_type_error: &str, + vm: &VirtualMachine, + ) -> PyResult> { + use crate::vm::builtins::{PyByteArray, PyBytes, PyStr}; + + if let Some(s) = obj.downcast_ref::() { + Ok(s.as_str().as_bytes().to_vec()) + } else if let Some(b) = obj.downcast_ref::() { + Ok(b.as_bytes().to_vec()) + } else if let Some(ba) = obj.downcast_ref::() { + Ok(ba.borrow_buf().to_vec()) + } else { + Err(vm.new_type_error(bad_type_error.to_owned())) + } + } + // Helper function to create SSL socket // = CPython's newPySSLSocket() fn new_py_ssl_socket( @@ -1855,11 +2114,12 @@ mod _ssl { unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) - // This is safe because ssl_socket owns the SSL object and outlives it - // We store a pointer to Py, which msg_callback can dereference - let py_ptr: *const Py = &*py_ref; - sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + // Clone and store via into_raw() - increments refcount and returns stable pointer + // The refcount will be decremented by msg_callback_data_free when SSL is freed + let cloned: PyObjectRef = py_ref.clone().into(); + let raw_ptr = cloned.into_raw(); + let msg_cb_idx = get_msg_callback_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, msg_cb_idx, raw_ptr.as_ptr() as *mut _); // Set SNI callback data if needed if has_sni_callback { @@ -1869,8 +2129,8 @@ mod _ssl { ssl_context: zelf.clone(), ssl_socket_weak, }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + let sni_idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, sni_idx, Box::into_raw(callback_data) as *mut _); } } @@ -1923,11 +2183,12 @@ mod _ssl { unsafe { let ssl_ptr = py_ref.connection.read().ssl().as_ptr(); - // Store ssl_socket pointer in index 0 for msg_callback (like CPython's SSL_set_app_data) - // This is safe because ssl_socket owns the SSL object and outlives it - // We store a pointer to Py, which msg_callback can dereference - let py_ptr: *const Py = &*py_ref; - sys::SSL_set_ex_data(ssl_ptr, 0, py_ptr as *mut _); + // Clone and store via into_raw() - increments refcount and returns stable pointer + // The refcount will be decremented by msg_callback_data_free when SSL is freed + let cloned: PyObjectRef = py_ref.clone().into(); + let raw_ptr = cloned.into_raw(); + let msg_cb_idx = get_msg_callback_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, msg_cb_idx, raw_ptr.as_ptr() as *mut _); // Set SNI callback data if needed if has_sni_callback { @@ -1937,8 +2198,8 @@ mod _ssl { ssl_context: zelf.clone(), ssl_socket_weak, }); - let idx = get_sni_ex_data_index(); - sys::SSL_set_ex_data(ssl_ptr, idx, Box::into_raw(callback_data) as *mut _); + let sni_idx = get_sni_ex_data_index(); + sys::SSL_set_ex_data(ssl_ptr, sni_idx, Box::into_raw(callback_data) as *mut _); } } @@ -1995,7 +2256,7 @@ mod _ssl { #[pyarg(any, optional)] keyfile: Option, #[pyarg(any, optional)] - password: Option>, + password: Option, } // Err is true if the socket is blocking @@ -2004,7 +2265,6 @@ mod _ssl { enum SelectRet { Nonblocking, TimedOut, - IsBlocking, Closed, Ok, } @@ -2292,12 +2552,13 @@ mod _ssl { #[pymethod] fn get_unverified_chain(&self, vm: &VirtualMachine) -> PyResult> { let stream = self.connection.read(); - let Some(chain) = stream.ssl().peer_cert_chain() else { + let ssl = stream.ssl(); + let Some(chain) = ssl.peer_cert_chain() else { return Ok(None); }; // Return Certificate objects - let certs: Vec = chain + let mut certs: Vec = chain .iter() .map(|cert| unsafe { sys::X509_up_ref(cert.as_ptr()); @@ -2305,6 +2566,16 @@ mod _ssl { cert_to_certificate(vm, owned) }) .collect::>()?; + + // SSL_get_peer_cert_chain does not include peer cert for server-side sockets + // Add it manually at the beginning + if matches!(self.socket_type, SslServerOrClient::Server) + && let Some(peer_cert) = ssl.peer_certificate() + { + let peer_obj = cert_to_certificate(vm, peer_cert)?; + certs.insert(0, peer_obj); + } + Ok(Some(vm.ctx.new_list(certs))) } @@ -2652,7 +2923,7 @@ mod _ssl { return Err(socket_closed_error(vm)); } SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the handshake loop (matches CPython's SOCKET_IS_BLOCKING behavior) if needs.is_some() { @@ -2719,7 +2990,7 @@ mod _ssl { } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the write loop (matches CPython's SOCKET_IS_BLOCKING behavior) if needs.is_some() { @@ -2896,7 +3167,7 @@ mod _ssl { } SelectRet::Closed => return Err(socket_closed_error(vm)), SelectRet::Nonblocking => {} - SelectRet::IsBlocking | SelectRet::Ok => { + SelectRet::Ok => { // For blocking sockets, select() has completed successfully // Continue the read loop (matches CPython's SOCKET_IS_BLOCKING behavior) if needs.is_some() { @@ -3576,30 +3847,33 @@ mod _ssl { let bio = bio::MemBioSlice::new(der)?; let mut certs = vec![]; + let mut was_bio_eof = false; loop { + // Check for EOF before attempting to parse (like CPython's _add_ca_certs) + // BIO_ctrl with BIO_CTRL_EOF returns 1 if EOF, 0 otherwise + if sys::BIO_ctrl(bio.as_ptr(), sys::BIO_CTRL_EOF, 0, std::ptr::null_mut()) != 0 { + was_bio_eof = true; + break; + } + let cert = sys::d2i_X509_bio(bio.as_ptr(), std::ptr::null_mut()); if cert.is_null() { + // Parse error (not just EOF) break; } certs.push(X509::from_ptr(cert)); } - if certs.is_empty() { - // No certificates loaded at all + // If we loaded some certs but didn't reach EOF, there's garbage data + // (like cacert_der + b"A") - this is an error + if !certs.is_empty() && !was_bio_eof { + // Return the error from the last failed parse attempt return Err(ErrorStack::get()); } - // Successfully loaded at least one certificate from DER data. - // Clear any trailing errors from EOF. - // CPython clears errors when: - // - DER: was_bio_eof is set (EOF reached) - // - PEM: PEM_R_NO_START_LINE error (normal EOF) - // Both cases mean successful completion with loaded certs. - eprintln!( - "[x509_stack_from_der] SUCCESS: Clearing errors and returning {} certs", - certs.len() - ); + // Clear any errors (including parse errors when no certs loaded) + // Let the caller decide how to handle empty results sys::ERR_clear_error(); Ok(certs) } diff --git a/crates/stdlib/src/openssl/cert.rs b/crates/stdlib/src/openssl/cert.rs index 1197bf4aa4..b63d824a83 100644 --- a/crates/stdlib/src/openssl/cert.rs +++ b/crates/stdlib/src/openssl/cert.rs @@ -5,16 +5,19 @@ pub(super) use ssl_cert::{PySSLCertificate, cert_to_certificate, cert_to_py, obj #[pymodule(sub)] pub(crate) mod ssl_cert { use crate::{ - common::ascii, + common::{ascii, hash::PyHash}, vm::{ - PyObjectRef, PyPayload, PyResult, VirtualMachine, + Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, + class_or_notimplemented, convert::{ToPyException, ToPyObject}, - function::{FsPath, OptionalArg}, + function::{FsPath, OptionalArg, PyComparisonValue}, + types::{Comparable, Hashable, PyComparisonOp, Representable}, }, }; use foreign_types_shared::ForeignTypeRef; use openssl::{ asn1::Asn1ObjectRef, + nid::Nid, x509::{self, X509, X509Ref}, }; use openssl_sys as sys; @@ -54,7 +57,7 @@ pub(crate) mod ssl_cert { #[pyclass(module = "ssl", name = "Certificate")] #[derive(PyPayload)] pub(crate) struct PySSLCertificate { - cert: X509, + pub(crate) cert: X509, } impl fmt::Debug for PySSLCertificate { @@ -63,7 +66,7 @@ pub(crate) mod ssl_cert { } } - #[pyclass] + #[pyclass(with(Comparable, Hashable, Representable))] impl PySSLCertificate { #[pymethod] fn public_bytes( @@ -83,12 +86,14 @@ pub(crate) mod ssl_cert { Ok(vm.ctx.new_bytes(der).into()) } ENCODING_PEM => { - // PEM encoding + // PEM encoding - returns string let pem = self .cert .to_pem() .map_err(|e| convert_openssl_error(vm, e))?; - Ok(vm.ctx.new_bytes(pem).into()) + let pem_str = String::from_utf8(pem) + .map_err(|_| vm.new_value_error("Invalid UTF-8 in PEM"))?; + Ok(vm.ctx.new_str(pem_str).into()) } _ => Err(vm.new_value_error("Unsupported format")), } @@ -100,6 +105,66 @@ pub(crate) mod ssl_cert { } } + impl Comparable for PySSLCertificate { + fn cmp( + zelf: &Py, + other: &PyObject, + op: PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult { + let other = class_or_notimplemented!(Self, other); + + // Only support equality comparison + if !matches!(op, PyComparisonOp::Eq | PyComparisonOp::Ne) { + return Ok(PyComparisonValue::NotImplemented); + } + + // Compare DER encodings + let self_der = zelf + .cert + .to_der() + .map_err(|e| convert_openssl_error(vm, e))?; + let other_der = other + .cert + .to_der() + .map_err(|e| convert_openssl_error(vm, e))?; + + let eq = self_der == other_der; + Ok(op.eval_ord(eq.cmp(&true)).into()) + } + } + + impl Hashable for PySSLCertificate { + fn hash(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Use subject name hash as certificate hash + let hash = unsafe { sys::X509_subject_name_hash(zelf.cert.as_ptr()) }; + Ok(hash as PyHash) + } + } + + impl Representable for PySSLCertificate { + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + // Build subject string like "CN=localhost, O=Python" + let subject = zelf.cert.subject_name(); + let mut parts: Vec = Vec::new(); + for entry in subject.entries() { + // Use short name (SN) if available, otherwise use OID + let name = match entry.object().nid().short_name() { + Ok(sn) => sn.to_string(), + Err(_) => obj2txt(entry.object(), true).unwrap_or_default(), + }; + if let Ok(value) = entry.data().as_utf8() { + parts.push(format!("{}={}", name, value)); + } + } + if parts.is_empty() { + Ok("".to_string()) + } else { + Ok(format!("", parts.join(", "))) + } + } + } + fn name_to_py(vm: &VirtualMachine, name: &x509::X509NameRef) -> PyResult { let list = name .entries() @@ -133,11 +198,15 @@ pub(crate) mod ssl_cert { .to_bn() .and_then(|bn| bn.to_hex_str()) .map_err(|e| convert_openssl_error(vm, e))?; - dict.set_item( - "serialNumber", - vm.ctx.new_str(serial_num.to_owned()).into(), - vm, - )?; + // Serial number must have even length (each byte = 2 hex chars) + // BigNum::to_hex_str() strips leading zeros, so we need to pad + let serial_str = serial_num.to_string(); + let serial_str = if serial_str.len() % 2 == 1 { + format!("0{}", serial_str) + } else { + serial_str + }; + dict.set_item("serialNumber", vm.ctx.new_str(serial_str).into(), vm)?; dict.set_item( "notBefore", @@ -188,10 +257,23 @@ pub(crate) mod ssl_cert { return vm.new_tuple((ascii!("DirName"), py_name)).into(); } - // TODO: Handle Registered ID (GEN_RID) - // CPython implementation uses i2t_ASN1_OBJECT to convert OID - // This requires accessing GENERAL_NAME union which is complex in Rust - // For now, we return for unhandled types + // Check for Registered ID (GEN_RID) + // Access raw GENERAL_NAME to check type + let ptr = gen_name.as_ptr(); + unsafe { + if (*ptr).type_ == sys::GEN_RID { + // d is ASN1_OBJECT* for GEN_RID + let oid_ptr = (*ptr).d as *const sys::ASN1_OBJECT; + if !oid_ptr.is_null() { + let oid_ref = Asn1ObjectRef::from_ptr(oid_ptr as *mut _); + if let Some(oid_str) = obj2txt(oid_ref, true) { + return vm + .new_tuple((ascii!("Registered ID"), oid_str)) + .into(); + } + } + } + } // For othername and other unsupported types vm.new_tuple((ascii!("othername"), ascii!(""))) @@ -202,6 +284,60 @@ pub(crate) mod ssl_cert { dict.set_item("subjectAltName", vm.ctx.new_tuple(san).into(), vm)?; }; + // Authority Information Access: OCSP URIs + if let Ok(ocsp_list) = cert.ocsp_responders() + && !ocsp_list.is_empty() + { + let uris: Vec = ocsp_list + .iter() + .map(|s| vm.ctx.new_str(s.to_string()).into()) + .collect(); + dict.set_item("OCSP", vm.ctx.new_tuple(uris).into(), vm)?; + } + + // Authority Information Access: CA Issuers URIs + if let Some(aia) = cert.authority_info() { + let ca_issuers: Vec = aia + .iter() + .filter_map(|ad| { + // Check if method is CA Issuers (NID_ad_ca_issuers) + if ad.method().nid() != Nid::AD_CA_ISSUERS { + return None; + } + // Get URI from location + ad.location() + .uri() + .map(|uri| vm.ctx.new_str(uri.to_owned()).into()) + }) + .collect(); + if !ca_issuers.is_empty() { + dict.set_item("caIssuers", vm.ctx.new_tuple(ca_issuers).into(), vm)?; + } + } + + // CRL Distribution Points + if let Some(crl_dps) = cert.crl_distribution_points() { + let mut crl_uris: Vec = Vec::new(); + for dp in crl_dps.iter() { + if let Some(dp_name) = dp.distpoint() + && let Some(fullname) = dp_name.fullname() + { + for gn in fullname.iter() { + if let Some(uri) = gn.uri() { + crl_uris.push(vm.ctx.new_str(uri.to_owned()).into()); + } + } + } + } + if !crl_uris.is_empty() { + dict.set_item( + "crlDistributionPoints", + vm.ctx.new_tuple(crl_uris).into(), + vm, + )?; + } + } + Ok(dict.into()) }