diff --git a/compiler/codegen/src/ir.rs b/compiler/codegen/src/ir.rs index 7cd173c821..f08b37e888 100644 --- a/compiler/codegen/src/ir.rs +++ b/compiler/codegen/src/ir.rs @@ -4,8 +4,8 @@ use crate::{IndexMap, IndexSet, error::InternalError}; use rustpython_compiler_core::{ OneIndexed, SourceLocation, bytecode::{ - CodeFlags, CodeObject, CodeUnit, ConstantData, InstrDisplayContext, Instruction, Label, - OpArg, PyCodeLocationInfoKind, + CodeFlags, CodeObject, CodeUnit, CodeUnits, ConstantData, InstrDisplayContext, Instruction, + Label, OpArg, PyCodeLocationInfoKind, }, }; @@ -214,7 +214,7 @@ impl CodeInfo { qualname: qualname.unwrap_or(obj_name), max_stackdepth, - instructions: instructions.into_boxed_slice(), + instructions: CodeUnits::from(instructions), locations: locations.into_boxed_slice(), constants: constants.into_iter().collect(), names: name_cache.into_iter().collect(), diff --git a/compiler/core/src/bytecode.rs b/compiler/core/src/bytecode.rs index 4ab666da99..2cfb70e278 100644 --- a/compiler/core/src/bytecode.rs +++ b/compiler/core/src/bytecode.rs @@ -1,13 +1,16 @@ //! Implement python as a virtual machine with bytecode. This module //! implements bytecode structure. -use crate::{OneIndexed, SourceLocation}; +use crate::{ + marshal::MarshalError, + {OneIndexed, SourceLocation}, +}; use bitflags::bitflags; use itertools::Itertools; use malachite_bigint::BigInt; use num_complex::Complex64; use rustpython_wtf8::{Wtf8, Wtf8Buf}; -use std::{collections::BTreeSet, fmt, hash, marker::PhantomData, mem}; +use std::{collections::BTreeSet, fmt, hash, marker::PhantomData, mem, ops::Deref}; #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] #[repr(i8)] @@ -195,7 +198,7 @@ impl ConstantBag for BasicBag { /// a code object. Also a module has a code object. #[derive(Clone)] pub struct CodeObject { - pub instructions: Box<[CodeUnit]>, + pub instructions: CodeUnits, pub locations: Box<[SourceLocation]>, pub flags: CodeFlags, /// Number of positional-only arguments @@ -257,6 +260,12 @@ impl OpArgByte { } } +impl From for OpArgByte { + fn from(raw: u8) -> Self { + Self(raw) + } +} + impl fmt::Debug for OpArgByte { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) @@ -808,14 +817,14 @@ impl From for u8 { } impl TryFrom for Instruction { - type Error = crate::marshal::MarshalError; + type Error = MarshalError; #[inline] - fn try_from(value: u8) -> Result { + fn try_from(value: u8) -> Result { if value <= u8::from(LAST_INSTRUCTION) { Ok(unsafe { std::mem::transmute::(value) }) } else { - Err(crate::marshal::MarshalError::InvalidBytecode) + Err(MarshalError::InvalidBytecode) } } } @@ -835,6 +844,58 @@ impl CodeUnit { } } +impl TryFrom<&[u8]> for CodeUnit { + type Error = MarshalError; + + fn try_from(value: &[u8]) -> Result { + match value.len() { + 2 => Ok(Self::new(value[0].try_into()?, value[1].into())), + _ => Err(Self::Error::InvalidBytecode), + } + } +} + +#[derive(Clone)] +pub struct CodeUnits(Box<[CodeUnit]>); + +impl TryFrom<&[u8]> for CodeUnits { + type Error = MarshalError; + + fn try_from(value: &[u8]) -> Result { + if !value.len().is_multiple_of(2) { + return Err(Self::Error::InvalidBytecode); + } + + value.chunks_exact(2).map(CodeUnit::try_from).collect() + } +} + +impl From<[CodeUnit; N]> for CodeUnits { + fn from(value: [CodeUnit; N]) -> Self { + Self(Box::from(value)) + } +} + +impl From> for CodeUnits { + fn from(value: Vec) -> Self { + Self(value.into_boxed_slice()) + } +} + +impl FromIterator for CodeUnits { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl Deref for CodeUnits { + type Target = [CodeUnit]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + use self::Instruction::*; bitflags! { diff --git a/compiler/core/src/marshal.rs b/compiler/core/src/marshal.rs index f803317772..39e4807167 100644 --- a/compiler/core/src/marshal.rs +++ b/compiler/core/src/marshal.rs @@ -165,19 +165,6 @@ impl<'a> ReadBorrowed<'a> for &'a [u8] { } } -/// Parses bytecode bytes into CodeUnit instructions. -/// Each instruction is 2 bytes: opcode and argument. -pub fn parse_instructions_from_bytes(bytes: &[u8]) -> Result> { - bytes - .chunks_exact(2) - .map(|cu| { - let op = Instruction::try_from(cu[0])?; - let arg = OpArgByte(cu[1]); - Ok(CodeUnit { op, arg }) - }) - .collect() -} - pub struct Cursor { pub data: B, pub position: usize, @@ -197,8 +184,8 @@ pub fn deserialize_code( bag: Bag, ) -> Result> { let len = rdr.read_u32()?; - let instructions = rdr.read_slice(len * 2)?; - let instructions = parse_instructions_from_bytes(instructions)?; + let raw_instructions = rdr.read_slice(len * 2)?; + let instructions = CodeUnits::try_from(raw_instructions)?; let len = rdr.read_u32()?; let locations = (0..len) diff --git a/vm/src/builtins/code.rs b/vm/src/builtins/code.rs index 2a22993a9e..f9f3429e30 100644 --- a/vm/src/builtins/code.rs +++ b/vm/src/builtins/code.rs @@ -1,12 +1,10 @@ -/*! Infamous code object. The python class `code` - -*/ +//! Infamous code object. The python class `code` use super::{PyBytesRef, PyStrRef, PyTupleRef, PyType, PyTypeRef}; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::PyStrInterned, - bytecode::{self, AsBag, BorrowedConstant, CodeFlags, CodeUnit, Constant, ConstantBag}, + bytecode::{self, AsBag, BorrowedConstant, CodeFlags, Constant, ConstantBag}, class::{PyClassImpl, StaticType}, convert::ToPyObject, frozen, @@ -15,11 +13,7 @@ use crate::{ }; use malachite_bigint::BigInt; use num_traits::Zero; -use rustpython_compiler_core::{ - OneIndexed, - bytecode::PyCodeLocationInfoKind, - marshal::{MarshalError, parse_instructions_from_bytes}, -}; +use rustpython_compiler_core::{OneIndexed, bytecode::CodeUnits, bytecode::PyCodeLocationInfoKind}; use std::{borrow::Borrow, fmt, ops::Deref}; /// State for iterating through code address ranges @@ -457,7 +451,7 @@ impl Constructor for PyCode { // Parse and validate bytecode from bytes let bytecode_bytes = args.co_code.as_bytes(); - let instructions = parse_bytecode(bytecode_bytes) + let instructions = CodeUnits::try_from(bytecode_bytes) .map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))?; // Convert constants @@ -925,7 +919,7 @@ impl PyCode { let instructions = match co_code { OptionalArg::Present(code_bytes) => { // Parse and validate bytecode from bytes - parse_bytecode(code_bytes.as_bytes()) + CodeUnits::try_from(code_bytes.as_bytes()) .map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))? } OptionalArg::Missing => self.code.instructions.clone(), @@ -1033,19 +1027,6 @@ impl ToPyObject for bytecode::CodeObject { } } -/// Validates and parses bytecode bytes into CodeUnit instructions. -/// Returns MarshalError if bytecode is invalid (odd length or contains invalid opcodes). -/// Note: Returning MarshalError is not necessary at this point because this is not a part of marshalling API. -/// However, we (temporarily) reuse MarshalError for simplicity. -fn parse_bytecode(bytecode_bytes: &[u8]) -> Result, MarshalError> { - // Bytecode must have even length (each instruction is 2 bytes) - if !bytecode_bytes.len().is_multiple_of(2) { - return Err(MarshalError::InvalidBytecode); - } - - parse_instructions_from_bytes(bytecode_bytes) -} - // Helper struct for reading linetable struct LineTableReader<'a> { data: &'a [u8],