diff --git a/Cargo.toml b/Cargo.toml index e52310d..233b31a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,15 +1,24 @@ [package] name = "cppshift" -version = "0.1.0" +version = "0.1.1" authors = ["Jérémy HERGAULT", "Enzo PASQUALINI"] description = "CPP parser and transpiler" repository = "https://github.com/worldline/cppshift" edition = "2024" license = "Apache-2.0" +[features] +default = ["transpiler"] +ast = [] +transpiler = ["ast", "dep:serde", "dep:syn", "dep:quote", "dep:proc-macro2"] + [dependencies] miette = { version = "7", features = ["fancy"] } thiserror = "2" +serde = { version = "1", features = ["derive"], optional = true } +syn = { version = "2", features = ["full", "extra-traits", "printing"], optional = true } +quote = { version = "1", optional = true } +proc-macro2 = { version = "1", optional = true } [dev-dependencies] tokio = { version = "1", features = ["macros"] } diff --git a/src/ast/expr.rs b/src/ast/expr.rs index ec71d47..17cf767 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -3,6 +3,7 @@ //! Analogous to `syn::Expr`. use crate::SourceSpan; +use crate::ast::ty::{FundamentalKind, Type}; use super::item::{Ident, Path}; use super::punct::Punctuated; @@ -16,6 +17,49 @@ pub enum LitKind { Char, } +impl LitKind { + /// Check if the fundamental kind matches the literal kind + pub fn match_fundamental(&self, kind: FundamentalKind) -> bool { + match self { + LitKind::Integer => matches!( + kind, + FundamentalKind::Short + | FundamentalKind::Int + | FundamentalKind::Long + | FundamentalKind::LongLong + | FundamentalKind::SignedChar + | FundamentalKind::UnsignedChar + | FundamentalKind::UnsignedShort + | FundamentalKind::UnsignedInt + | FundamentalKind::UnsignedLong + | FundamentalKind::UnsignedLongLong + ), + LitKind::Float => matches!( + kind, + FundamentalKind::Float | FundamentalKind::Double | FundamentalKind::LongDouble + ), + LitKind::Char => matches!( + kind, + FundamentalKind::Char + | FundamentalKind::Wchar + | FundamentalKind::Char8 + | FundamentalKind::Char16 + | FundamentalKind::Char32 + ), + _ => false, + } + } + + /// Check if the literal kind matches the type + pub fn match_type(&self, ty: &Type) -> bool { + match ty { + Type::Fundamental(fund) => self.match_fundamental(fund.kind), + Type::Qualified(qualified) => self.match_type(&qualified.ty), + _ => false, + } + } +} + /// Unary operator. #[derive(Debug, Clone, Copy, PartialEq)] pub enum UnaryOp { diff --git a/src/ast/item.rs b/src/ast/item.rs index 10ac611..5457e4f 100644 --- a/src/ast/item.rs +++ b/src/ast/item.rs @@ -3,6 +3,10 @@ //! Each variant of [`Item`] corresponds to a top-level declaration in a C++ translation unit, //! following the naming conventions of `syn::Item`. +use std::collections::LinkedList; +use std::fmt; +use std::hash::{Hash, Hasher}; + use crate::SourceSpan; use crate::lex::Token; @@ -22,6 +26,42 @@ pub struct Ident<'de> { pub span: SourceSpan<'de>, } +impl<'de> PartialEq<&str> for Ident<'de> { + fn eq(&self, sym: &&str) -> bool { + self.sym == *sym + } +} + +impl<'de> PartialEq> for &str { + fn eq(&self, ident: &Ident<'de>) -> bool { + *self == ident.sym + } +} + +impl<'de> PartialEq for Ident<'de> { + fn eq(&self, sym: &String) -> bool { + self.sym == sym.as_str() + } +} + +impl<'de> PartialEq> for String { + fn eq(&self, ident: &Ident<'de>) -> bool { + self.as_str() == ident.sym + } +} + +impl<'de> Hash for Ident<'de> { + fn hash(&self, state: &mut H) { + self.sym.hash(state); + } +} + +impl<'de> fmt::Display for Ident<'de> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.sym) + } +} + /// Visibility of a declaration. /// /// In C++, visibility applies within class/struct bodies via access specifiers. @@ -39,8 +79,11 @@ pub enum Visibility { /// A C++20 attribute `[[...]]`, analogous to `syn::Attribute`. #[derive(Debug, Clone, PartialEq)] pub struct Attribute<'de> { + /// Source location of the entire `[[...]]`. pub span: SourceSpan<'de>, + /// Attribute name/path: `nodiscard`, `gnu::unused`, etc. pub path: Path<'de>, + /// Attribute arguments as raw tokens (e.g. the `"reason"` in `[[deprecated("reason")]]`). pub args: Vec>, } @@ -49,10 +92,26 @@ pub struct Attribute<'de> { /// Analogous to `syn::Path`. #[derive(Debug, Clone, PartialEq)] pub struct Path<'de> { + /// `true` if the path starts with `::` (absolute/global scope). pub leading_colon: bool, + /// The segments of the path: `std::vector` → `[std, vector]`. pub segments: Vec>, } +impl<'de> fmt::Display for Path<'de> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + self.segments + .iter() + .map(|s| s.ident.sym) + .collect::>() + .join("::") + ) + } +} + /// A single segment of a path, analogous to `syn::PathSegment`. #[derive(Debug, Clone, Copy, PartialEq)] pub struct PathSegment<'de> { @@ -64,8 +123,11 @@ pub struct PathSegment<'de> { /// Example: `public Base`, `virtual protected Interface` #[derive(Debug, Clone, PartialEq)] pub struct BaseSpecifier<'de> { + /// Access specifier for the inheritance: `public`, `protected`, or `private`. pub access: Visibility, + /// `true` for virtual inheritance (diamond problem mitigation). pub virtual_token: bool, + /// The base class name (possibly qualified: `std::Base`). pub path: Path<'de>, } @@ -74,10 +136,17 @@ pub struct BaseSpecifier<'de> { /// Analogous to `syn::Field`. #[derive(Debug, Clone, PartialEq)] pub struct Field<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// Access specifier (`public`, `private`, `protected`). pub vis: Visibility, + /// `true` if declared `static`. + pub static_token: bool, + /// The field's type. pub ty: Type<'de>, + /// Field name. `None` for anonymous fields (e.g. anonymous unions). pub ident: Option>, + /// Default member initializer (C++11): `int x = 42;`. pub default_value: Option>, } @@ -126,42 +195,99 @@ pub enum Member<'de> { } /// A function argument, analogous to `syn::FnArg`. +/// +/// Example: `[[maybe_unused]] int count = 0` #[derive(Debug, Clone, PartialEq)] pub struct FnArg<'de> { + /// C++20 attributes: `[[maybe_unused]]`, etc. pub attrs: Vec>, + /// The parameter's type. pub ty: Type<'de>, + /// Parameter name. `None` for unnamed parameters (e.g. `void foo(int)`). pub ident: Option>, + /// Default argument value: `int count = 0`. pub default_value: Option>, } /// A function signature, analogous to `syn::Signature`. +/// +/// Contains all specifiers, qualifiers, and the parameter list. +/// +/// Example: `constexpr inline virtual int compute(int x) const noexcept override` #[derive(Debug, Clone, PartialEq)] pub struct Signature<'de> { + // --- Leading specifiers --- + /// `true` if declared `constexpr`. pub constexpr_token: bool, + /// `true` if declared `consteval` (C++20, immediate function). pub consteval_token: bool, + /// `true` if declared `inline`. pub inline_token: bool, + /// `true` if declared `virtual`. pub virtual_token: bool, + /// `true` if declared `static`. pub static_token: bool, + /// `true` if declared `explicit` (for conversion operators). pub explicit_token: bool, + /// The return type (e.g. `int`, `void`, `auto`). pub return_type: Type<'de>, + /// Optional qualifying class/namespace path for out-of-line definitions. + /// For `void MyClass::myFunction()`, this is `MyClass`. + pub class_path: Option>, + /// The function name. pub ident: Ident<'de>, + /// Function parameters, comma-separated. pub inputs: Punctuated<'de, FnArg<'de>>, + /// `true` if the function accepts variadic arguments (`...`). pub variadic: bool, - // Trailing qualifiers + // --- Trailing qualifiers --- + /// `true` if the member function is `const`-qualified. pub const_token: bool, + /// `true` if declared `noexcept`. pub noexcept_token: bool, + /// `true` if declared `override` (virtual method override). pub override_token: bool, + /// `true` if declared `final` (prevents further overriding). pub final_token: bool, + /// `true` if `= 0` (pure virtual / abstract method). pub pure_virtual: bool, + /// `true` if `= default` (compiler-generated implementation). pub defaulted: bool, + /// `true` if `= delete` (explicitly deleted function). pub deleted: bool, + /// Member initializer list for out-of-line constructors: `: m_x(x), m_y(y)`. + pub member_init_list: Vec>, +} + +impl<'de> Signature<'de> { + /// Returns true if this is a class constructor. + pub fn is_class_constructor(&self) -> bool { + self.class_path + .as_ref() + .is_some_and(|cp| self.ident == cp.to_string()) + } + + /// Checks if this function can be called with no arguments + /// (either no parameters or all parameters have defaults). + pub fn has_no_required_params(&self) -> bool { + if let Some((fn_arg, _)) = self.inputs.inner.first() + && fn_arg.default_value.is_none() + { + false + } else { + !self.deleted + } + } } /// An enum variant, analogous to `syn::Variant`. #[derive(Debug, Clone, PartialEq)] pub struct Variant<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The variant name (e.g. `Red`). pub ident: Ident<'de>, + /// Explicit discriminant value: `Red = 1`. pub discriminant: Option>, } @@ -201,10 +327,12 @@ pub enum ForeignItem<'de> { /// A member initializer in a constructor initializer list. /// -/// Example: `m_x(x)`, `Base(arg)` +/// Example: `m_x(x)`, `Base(arg)`, `::std::runtime_error(msg)` #[derive(Debug, Clone, PartialEq)] pub struct MemberInit<'de> { - pub member: Ident<'de>, + /// The member or base class being initialized (may be a qualified path). + pub member: Path<'de>, + /// The initializer arguments: `m_x(x)` → args is `[x]`. pub args: Punctuated<'de, Expr<'de>>, } @@ -240,6 +368,8 @@ pub enum Item<'de> { Const(ItemConst<'de>), /// Static variable: `static int count;` Static(ItemStatic<'de>), + /// Variable declaration: `int x = 42;` + Var(ItemVar<'de>), /// Extern block: `extern "C" { ... }` ForeignMod(ItemForeignMod<'de>), /// Template declaration: `template ...` @@ -255,58 +385,100 @@ pub enum Item<'de> { } /// A function declaration or definition, analogous to `syn::ItemFn`. +/// +/// Example: `[[nodiscard]] int add(int a, int b) { return a + b; }` #[derive(Debug, Clone, PartialEq)] pub struct ItemFn<'de> { + /// C++20 attributes: `[[nodiscard]]`, `[[deprecated]]`, etc. pub attrs: Vec>, + /// Access specifier when this function is a class member. pub vis: Visibility, + /// Function signature: name, return type, parameters, and qualifiers. pub sig: Signature<'de>, + /// Function body. `None` for declarations (`;`), `Some` for definitions (`{ ... }`). pub block: Option>, } /// A struct definition, analogous to `syn::ItemStruct`. +/// +/// Example: `struct Point : public Base { int x; int y; };` #[derive(Debug, Clone, PartialEq)] pub struct ItemStruct<'de> { + /// C++20 attributes: `[[deprecated]]`, `[[nodiscard]]`, etc. pub attrs: Vec>, + /// Struct name. `None` for anonymous structs. pub ident: Option>, + /// Template parameters if this is a template specialization: `struct Foo`. pub generics: Option>, + /// Base class specifiers: `public Base, virtual Interface`. pub bases: Vec>, + /// Struct body with members, or `Unit` for forward declarations. pub fields: Fields<'de>, } /// A class definition (C++ specific). +/// +/// Identical to [`ItemStruct`] but with `private` as the default access specifier. +/// +/// Example: `class Widget : public Base { public: void draw(); };` #[derive(Debug, Clone, PartialEq)] pub struct ItemClass<'de> { + /// C++20 attributes: `[[deprecated]]`, `[[nodiscard]]`, etc. pub attrs: Vec>, + /// Class name. `None` for anonymous classes. pub ident: Option>, + /// Template parameters if this is a template specialization: `class Foo`. pub generics: Option>, + /// Base class specifiers: `public Base, virtual Interface`. pub bases: Vec>, + /// Class body with members, or `Unit` for forward declarations. pub fields: Fields<'de>, } /// An enum definition, analogous to `syn::ItemEnum`. +/// +/// Covers both unscoped (`enum Color { ... }`) and scoped (`enum class Color { ... }`) enums. +/// +/// Example: `enum class Color : uint8_t { Red, Green, Blue };` #[derive(Debug, Clone, PartialEq)] pub struct ItemEnum<'de> { + /// C++20 attributes: `[[nodiscard]]`, etc. pub attrs: Vec>, + /// Enum name. `None` for anonymous enums. pub ident: Option>, + /// `true` for `enum class` / `enum struct` (scoped enums, C++11). pub scoped: bool, + /// Explicit underlying type: `enum Color : uint8_t`. pub underlying_type: Option>, + /// Enum variants, comma-separated. pub variants: Punctuated<'de, Variant<'de>>, } /// A union definition, analogous to `syn::ItemUnion`. +/// +/// Example: `union Data { int i; float f; double d; };` #[derive(Debug, Clone, PartialEq)] pub struct ItemUnion<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// Union name. `None` for anonymous unions. pub ident: Option>, + /// Union members (all share the same memory location). pub fields: FieldsNamed<'de>, } /// A namespace declaration, analogous to `syn::ItemMod`. +/// +/// Example: `inline namespace v2 { void foo(); }` #[derive(Debug, Clone, PartialEq)] pub struct ItemNamespace<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// `true` for `inline namespace` (symbols visible in enclosing namespace). pub inline_token: bool, + /// Namespace name. `None` for anonymous namespaces. pub ident: Option>, + /// Items declared inside this namespace. pub content: Vec>, } @@ -332,61 +504,114 @@ pub enum ItemUse<'de> { } /// A type alias (`using X = Y`), analogous to `syn::ItemType`. +/// +/// Example: `template using Vec = std::vector;` #[derive(Debug, Clone, PartialEq)] pub struct ItemType<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The alias name. pub ident: Ident<'de>, + /// Template parameters for alias templates. pub generics: Option>, + /// The aliased type. pub ty: Type<'de>, } /// A typedef declaration (C-style type alias). +/// +/// Example: `typedef unsigned long size_t;` #[derive(Debug, Clone, PartialEq)] pub struct ItemTypedef<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The original type being aliased. pub ty: Type<'de>, + /// The new alias name. pub ident: Ident<'de>, } /// A const or constexpr variable, analogous to `syn::ItemConst`. +/// +/// Example: `constexpr int MAX_SIZE = 1024;` #[derive(Debug, Clone, PartialEq)] pub struct ItemConst<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// `true` for `constexpr` (compile-time evaluated), `false` for plain `const`. pub constexpr_token: bool, + /// The constant's type. pub ty: Type<'de>, + /// The constant's name. pub ident: Ident<'de>, + /// The initializer expression. pub expr: Expr<'de>, } /// A static variable, analogous to `syn::ItemStatic`. +/// +/// Example: `static int instance_count = 0;` #[derive(Debug, Clone, PartialEq)] pub struct ItemStatic<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The variable's type. pub ty: Type<'de>, + /// The variable's name. pub ident: Ident<'de>, + /// Optional initializer. `None` for uninitialized declarations. + pub expr: Option>, +} + +/// A variable declaration (not `static`, not `const`/`constexpr`). +/// +/// Example: `int x = 42;` +#[derive(Debug, Clone, PartialEq)] +pub struct ItemVar<'de> { + /// C++20 attributes. + pub attrs: Vec>, + /// The variable's type. + pub ty: Type<'de>, + /// The variable's name. + pub ident: Ident<'de>, + /// Optional initializer. `None` for uninitialized declarations. pub expr: Option>, } /// An extern block (`extern "C" { ... }`), analogous to `syn::ItemForeignMod`. +/// +/// Example: `extern "C" { void c_function(); }` #[derive(Debug, Clone, PartialEq)] pub struct ItemForeignMod<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The ABI string: `"C"`, `"C++"`, etc. pub abi: &'de str, + /// Declarations inside the extern block. pub items: Vec>, } /// A template declaration (C++ specific). +/// +/// Example: `template class Array { ... };` #[derive(Debug, Clone, PartialEq)] pub struct ItemTemplate<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// Template parameters: ``. pub params: Punctuated<'de, TemplateParam<'de>>, + /// The templated declaration (function, class, struct, etc.). pub item: Box>, } /// A static assertion (C++ specific). +/// +/// Example: `static_assert(sizeof(int) == 4, "int must be 4 bytes");` #[derive(Debug, Clone, PartialEq)] pub struct ItemStaticAssert<'de> { + /// The boolean condition to assert at compile time. pub expr: Expr<'de>, + /// Optional error message string literal (C++17 made this optional). pub message: Option>, } @@ -402,60 +627,109 @@ pub enum IncludePath<'de> { /// A `#include` preprocessor directive: `#include ` or `#include "myfile.h"`. #[derive(Debug, Clone, PartialEq)] pub struct ItemInclude<'de> { + /// Source location of the entire directive. pub span: SourceSpan<'de>, + /// The included path (system `<...>` or local `"..."`). pub path: IncludePath<'de>, } -/// A preprocessor directive. +/// A preprocessor directive (e.g. `#define`, `#ifdef`, `#pragma`). #[derive(Debug, Clone, PartialEq)] pub struct ItemMacro<'de> { + /// Source location of the entire directive. pub span: SourceSpan<'de>, + /// Raw tokens making up the directive body. pub tokens: Vec>, } /// Tokens not interpreted by the parser, analogous to `syn::Item::Verbatim`. -#[derive(Debug, Clone, PartialEq)] +/// +/// Used as a fallback when the parser encounters a construct it cannot fully parse. +#[derive(Debug, Default, Clone, PartialEq)] pub struct ItemVerbatim<'de> { - pub tokens: Vec>, + /// The raw, unparsed tokens. + pub tokens: LinkedList>, } /// A constructor declaration/definition. +/// +/// Example: `explicit Foo(int x) noexcept : m_x(x) { }` #[derive(Debug, Clone, PartialEq)] pub struct ItemConstructor<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// `true` if declared `explicit` (prevents implicit conversions). pub explicit_token: bool, + /// `true` if declared `constexpr`. pub constexpr_token: bool, + /// The class name (must match the enclosing class). pub ident: Ident<'de>, + /// Constructor parameters. pub inputs: Punctuated<'de, FnArg<'de>>, + /// `true` if declared `noexcept`. pub noexcept_token: bool, + /// Member initializer list: `: m_x(x), m_y(y)`. pub member_init_list: Vec>, + /// Constructor body. `None` for declarations. pub block: Option>, + /// `true` if `= default`. pub defaulted: bool, + /// `true` if `= delete`. pub deleted: bool, } +impl<'de> ItemConstructor<'de> { + /// Checks if this constructor is a default constructor (no parameters). + pub fn is_default_constructor(&self) -> bool { + if let Some((fn_arg, _)) = self.inputs.inner.first() + && fn_arg.default_value.is_none() + { + false + } else { + !self.deleted + } + } +} + /// A destructor declaration/definition. +/// +/// Example: `virtual ~Widget() noexcept = default;` #[derive(Debug, Clone, PartialEq)] pub struct ItemDestructor<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// `true` if declared `virtual` (required for polymorphic base classes). pub virtual_token: bool, + /// The class name (the `~ClassName` part, stored without the `~`). pub ident: Ident<'de>, + /// `true` if declared `noexcept`. pub noexcept_token: bool, + /// Destructor body. `None` for declarations. pub block: Option>, + /// `true` if `= default`. pub defaulted: bool, + /// `true` if `= delete`. pub deleted: bool, + /// `true` if `= 0` (pure virtual destructor). pub pure_virtual: bool, } -/// A friend declaration. +/// A friend declaration, granting another class or function access to private members. +/// +/// Example: `friend class OtherClass;` or `friend void helper(Foo&);` #[derive(Debug, Clone, PartialEq)] pub struct ItemFriend<'de> { + /// C++20 attributes. pub attrs: Vec>, + /// The befriended declaration (function or class). pub item: Box>, } -/// Template generics on a class/struct/function. +/// Template generics on a class/struct/function, analogous to `syn::Generics`. +/// +/// Example: the `` in `template class Array`. #[derive(Debug, Clone, PartialEq)] pub struct Generics<'de> { + /// The template parameters, comma-separated. pub params: Punctuated<'de, TemplateParam<'de>>, } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 88b6bd4..9a40617 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -44,6 +44,8 @@ pub fn parse_file<'de>(content: &'de str) -> Result, AstError> { #[cfg(test)] mod tests { + use core::panic; + use crate::ast::expr::{BinaryOp, ExprBinary, ExprIdent, ExprLit, ExprPath, LitKind}; use crate::ast::stmt::{StmtCase, StmtExpr, StmtReturn, StmtSwitch}; use crate::ast::ty::{FundamentalKind, TypeArray, TypePtr}; @@ -73,6 +75,12 @@ mod tests { #define ArgText(x) \ x##TEXT + // class constructor + void MyClass::MyClass() : x(2) {} + + // class method + void MyClass::myFunction() {} + // main function int main(int argc, char* argv[]) { std::cout << "Hello, world" << std::endl; @@ -136,6 +144,54 @@ mod tests { ); } + let class_constructor = main_item_iter.next(); + if let Some(Item::Fn(ItemFn { + attrs, + vis, + sig, + block, + })) = class_constructor + { + assert!(attrs.is_empty()); + assert_eq!(Visibility::Inherited, *vis); + assert_eq!(sig.ident.sym, "MyClass"); + let class_path = sig.class_path.as_ref().expect("expected class qualifier"); + assert_eq!(class_path.to_string(), "MyClass"); + assert!(sig.has_no_required_params()); + assert!(sig.is_class_constructor()); + assert_eq!(sig.member_init_list.len(), 1); + assert_eq!(sig.member_init_list[0].member.to_string(), "x"); + assert!(block.is_some()); + } else { + panic!( + "Wrong item: expected a class function, got {:#?}", + class_constructor + ); + } + + let class_method = main_item_iter.next(); + if let Some(Item::Fn(ItemFn { + attrs, + vis, + sig, + block, + })) = class_method + { + assert!(attrs.is_empty()); + assert_eq!(Visibility::Inherited, *vis); + assert_eq!(sig.ident.sym, "myFunction"); + let class_path = sig.class_path.as_ref().expect("expected class qualifier"); + assert_eq!(class_path.to_string(), "MyClass"); + assert!(sig.has_no_required_params()); + assert!(!sig.is_class_constructor()); + assert!(block.is_some()); + } else { + panic!( + "Wrong item: expected a class function, got {:#?}", + class_method + ); + } + let main_function = main_item_iter.next(); if let Some(Item::Fn(ItemFn { attrs, @@ -148,6 +204,7 @@ mod tests { assert_eq!(Visibility::Inherited, *vis); // Check signature + assert!(sig.class_path.is_none()); assert!(!sig.constexpr_token); assert!(!sig.consteval_token); assert!(!sig.inline_token); @@ -391,4 +448,181 @@ mod tests { assert_eq!(None, main_item_iter.next()); } + + /// Test the ast parser with a simple class definition that includes a constructor, a member function, and a member variable + #[test] + fn class_header_ast() { + let class_header_src = r#" + #include + + /** + * This is a simple class definition for testing the AST parser. + * It includes a constructor, a member function, and a member variable. + */ + class MyClass: public MyMotherClass + { + MACRO_DEF(param1, 1); + typedef MyMotherClass BaseClass; + + public: + static const string STATIC_VALUE; + + private: int member_var; + public: MyClass(int x) : member_var(x) {} + public: void member_function(); + }; + "#; + + let class_header_file = parse_file(class_header_src).unwrap(); + assert!(!class_header_file.items.is_empty()); + let mut class_header_item_iter = class_header_file.items.iter(); + + let include_system_iostream = class_header_item_iter.next(); + if let Some(Item::Include(ItemInclude { span, path })) = include_system_iostream { + assert_eq!(span.src(), "#include "); + if let IncludePath::System(path_span) = path { + assert_eq!(path_span.src(), "iostream"); + } else { + panic!("Expected a system include path, got {:#?}", path); + } + } else { + panic!( + "Wrong item: expected an include directive, got {:#?}", + include_system_iostream + ); + } + + let class_header = class_header_item_iter.next(); + if let Some(Item::Class(ItemClass { + attrs, + ident, + generics, + bases, + fields: Fields::Named(fields_named), + })) = class_header + { + assert_eq!(attrs.len(), 0); + assert_eq!(Some("MyClass"), ident.as_ref().map(|id| id.sym)); + assert_eq!(&None, generics); + + if bases.len() == 1 + && let Some(base) = bases.first() + { + assert_eq!(base.access, Visibility::Public); + assert!(!base.virtual_token); + assert_eq!(base.path.to_string(), "MyMotherClass"); + } else { + panic!( + "Wrong class.bases: expected an inheritance, got {:#?}", + bases + ); + } + + let mut fields_named_iter = fields_named.members.iter(); + + // 1. MACRO_DEF(param1, 1); → Verbatim + let verbatim_item = fields_named_iter.next(); + if let Some(Member::Item(item)) = verbatim_item + && let Item::Verbatim(verbatim) = item.as_ref() + { + assert!(!verbatim.tokens.is_empty()); + } else { + panic!("Expected a macro verbatim, got {:#?}", verbatim_item); + } + + // 2. typedef MyMotherClass BaseClass; → Typedef + let typedef_item = fields_named_iter.next(); + if let Some(Member::Item(item)) = typedef_item + && let Item::Typedef(td) = item.as_ref() + { + assert_eq!(td.ident.sym, "BaseClass"); + } else { + panic!("Expected a typedef, got {:#?}", typedef_item); + } + + // 3. public: → AccessSpecifier + let access_public_static = fields_named_iter.next(); + assert_eq!( + Some(&Member::AccessSpecifier(Visibility::Public)), + access_public_static, + ); + + // 4. static const string STATIC_VALUE; → Field + let static_field = fields_named_iter.next(); + if let Some(Member::Field(field)) = static_field { + assert_eq!(Some("STATIC_VALUE"), field.ident.as_ref().map(|id| id.sym)); + assert!(field.static_token); + assert_eq!(field.default_value, None); + } else { + panic!("Expected a static field, got {:#?}", static_field); + } + + // 5. private: → AccessSpecifier + let access_private = fields_named_iter.next(); + assert_eq!( + Some(&Member::AccessSpecifier(Visibility::Private)), + access_private, + ); + + // 6. int member_var; → Field + let field_member_var = fields_named_iter.next(); + if let Some(Member::Field(field)) = field_member_var { + assert_eq!(Some("member_var"), field.ident.as_ref().map(|id| id.sym)); + assert!(!field.static_token); + assert_eq!(field.default_value, None); + } else { + panic!("Expected a field, got {:#?}", field_member_var); + } + + // 7. public: → AccessSpecifier + let access_public1 = fields_named_iter.next(); + assert_eq!( + Some(&Member::AccessSpecifier(Visibility::Public)), + access_public1, + ); + + // 8. MyClass(int x) : member_var(x) {} → Constructor + let constructor = fields_named_iter.next(); + if let Some(Member::Constructor(ctor)) = constructor { + assert_eq!(ctor.ident.sym, "MyClass"); + assert!(!ctor.explicit_token); + assert!(!ctor.constexpr_token); + assert!(!ctor.noexcept_token); + assert!(!ctor.defaulted); + assert!(!ctor.deleted); + assert_eq!(ctor.inputs.len(), 1); + assert_eq!(ctor.member_init_list.len(), 1); + assert_eq!(ctor.member_init_list[0].member.to_string(), "member_var"); + assert!(ctor.block.is_some()); + } else { + panic!("Expected a constructor, got {:#?}", constructor); + } + + // 9. public: → AccessSpecifier + let access_public2 = fields_named_iter.next(); + assert_eq!( + Some(&Member::AccessSpecifier(Visibility::Public)), + access_public2, + ); + + // 10. void member_function(); → Method + let method = fields_named_iter.next(); + if let Some(Member::Method(m)) = method { + assert_eq!(m.sig.ident.sym, "member_function"); + assert!(m.block.is_none()); + } else { + panic!("Expected a method, got {:#?}", method); + } + + // No more members + assert_eq!(None, fields_named_iter.next()); + } else { + panic!( + "Wrong item: expected a class definition, got {:#?}", + class_header + ); + } + + assert_eq!(None, class_header_item_iter.next()); + } } diff --git a/src/ast/parse.rs b/src/ast/parse.rs index 221c6b8..4712e55 100644 --- a/src/ast/parse.rs +++ b/src/ast/parse.rs @@ -2,6 +2,8 @@ //! //! Not part of the public API. Used exclusively by [`super::parse_file`]. +use std::collections::LinkedList; + use crate::SourceSpan; use crate::lex::{Lexer, Token, TokenKind}; @@ -247,7 +249,7 @@ fn parse_item<'de>(p: &mut Parser<'de>) -> Result, AstError> { // Empty declaration (stray semicolons, e.g., after namespace or class body) if p.peek_kind() == Some(TokenKind::Semicolon) { p.bump()?; - return Ok(Item::Verbatim(ItemVerbatim { tokens: Vec::new() })); + return Ok(Item::Verbatim(ItemVerbatim::default())); } // Parse leading attributes [[...]] @@ -296,7 +298,7 @@ fn parse_item<'de>(p: &mut Parser<'de>) -> Result, AstError> { } } p.expect(TokenKind::Semicolon)?; - Some(Item::Verbatim(ItemVerbatim { tokens: Vec::new() })) + Some(Item::Verbatim(ItemVerbatim::default())) } _ => None, }; @@ -335,6 +337,7 @@ fn set_item_attrs<'de>(mut item: Item<'de>, attrs: Vec>) -> Item< Item::Typedef(t) => t.attrs = attrs, Item::Const(c) => c.attrs = attrs, Item::Static(s) => s.attrs = attrs, + Item::Var(v) => v.attrs = attrs, Item::ForeignMod(f) => f.attrs = attrs, Item::Template(t) => t.attrs = attrs, Item::StaticAssert(_) | Item::Include(_) | Item::Macro(_) | Item::Verbatim(_) => {} @@ -545,8 +548,22 @@ fn parse_item_using<'de>(p: &mut Parser<'de>) -> Result, AstError> { fn parse_item_typedef<'de>(p: &mut Parser<'de>) -> Result, AstError> { p.expect(TokenKind::KeywordTypedef)?; - let ty = parse_type(p)?; + let mut ty = parse_type(p)?; let ident = parse_ident(p)?; + // Handle C-style array typedefs: `typedef char type24[3];` + while p.peek_kind() == Some(TokenKind::LeftBracket) { + p.bump()?; + let size = if p.peek_kind() != Some(TokenKind::RightBracket) { + Some(parse_expr(p)?) + } else { + None + }; + p.expect(TokenKind::RightBracket)?; + ty = Type::Array(TypeArray { + element: Box::new(ty), + size, + }); + } p.expect(TokenKind::Semicolon)?; Ok(ItemTypedef { attrs: Vec::new(), @@ -800,9 +817,7 @@ fn parse_item_foreign_mod<'de>(p: &mut Parser<'de>) -> Result(p: &mut Parser<'de>) -> Result items.push(ForeignItem::Fn(f)), Item::Static(s) => items.push(ForeignItem::Static(s)), - _ => items.push(ForeignItem::Verbatim(ItemVerbatim { tokens: Vec::new() })), + _ => items.push(ForeignItem::Verbatim(ItemVerbatim::default())), } } p.expect(TokenKind::RightBrace)?; @@ -1047,6 +1062,11 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> // Qualified destructor: ClassName::~ClassName() if p.peek_kind() == Some(TokenKind::DoubleColon) && p.peek_nth_kind(1) == Some(TokenKind::Compl) { + let class_path = if let Type::Path(tp) = &return_type { + Some(tp.path.clone()) + } else { + None + }; p.bump()?; // :: p.bump()?; // ~ let dtor_ident = parse_ident(p)?; @@ -1099,6 +1119,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> span: dtor_ident.span, kind: FundamentalKind::Void, }), + class_path, ident: dtor_ident, inputs, variadic: false, @@ -1109,6 +1130,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> pure_virtual: false, defaulted: false, deleted: false, + member_init_list: Vec::new(), }, block, })); @@ -1129,11 +1151,15 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> && tp.path.segments.len() >= 2 { let fn_name = tp.path.segments.last().unwrap().ident; - // Reconstruct the return type from the qualifier minus the last segment - // For A::B::method, the return type doesn't come from the path - // This is actually: the function has no explicit return type (constructor/destructor pattern) - // or the first segments are the qualifier. - // For now, treat the whole thing as a function with the last segment as name + // Build class_path from all segments except the last + let class_path = { + let mut segments = tp.path.segments.clone(); + segments.pop(); + Some(Path { + leading_colon: tp.path.leading_colon, + segments, + }) + }; let inputs = parse_fn_params(p)?; // Trailing qualifiers let mut const_token = false; @@ -1159,14 +1185,33 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> _ => break, } } - // Member initializer list for constructors - if p.eat(TokenKind::Colon).is_some() { - // Skip member init list - while p.peek_kind() != Some(TokenKind::LeftBrace) - && p.peek_kind() != Some(TokenKind::Semicolon) - && !p.is_empty() - { - p.bump()?; + // Member initializer list for constructors only + let mut member_init_list = Vec::new(); + let is_constructor = class_path.as_ref().is_some_and(|cp| { + cp.segments + .last() + .is_some_and(|seg| seg.ident.sym == fn_name.sym) + }); + if is_constructor && p.eat(TokenKind::Colon).is_some() { + loop { + skip_macro_annotations(p)?; + let member = parse_path(p)?; + p.expect(TokenKind::LeftParenthese)?; + let mut args = Punctuated::new(); + while p.peek_kind() != Some(TokenKind::RightParenthese) && !p.is_empty() { + let arg = parse_expr_no_comma(p)?; + if let Some(comma) = p.eat(TokenKind::Comma) { + args.push_pair(arg, comma); + } else { + args.push_value(arg); + break; + } + } + p.expect(TokenKind::RightParenthese)?; + member_init_list.push(MemberInit { member, args }); + if p.eat(TokenKind::Comma).is_none() { + break; + } } } // Trailing return type @@ -1194,6 +1239,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> span: fn_name.span, kind: FundamentalKind::Void, }), + class_path, ident: fn_name, inputs, variadic: false, @@ -1204,6 +1250,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> pure_virtual: false, defaulted: false, deleted: false, + member_init_list, }, block, })); @@ -1211,55 +1258,56 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> return Err(p.error_at_current("expected identifier or function pointer")); } // Treat as macro invocation: consume everything until matching ) and ; - let _macro_ident = match &return_type { + let macro_ident = match &return_type { Type::Path(tp) => tp.path.segments[0].ident, _ => unreachable!(), }; - p.bump()?; // ( + let mut tokens = LinkedList::from([Token::new(macro_ident.span, TokenKind::Ident)]); + tokens.push_back(p.bump()?); // ( let mut depth = 1u32; while depth > 0 && !p.is_empty() { match p.peek_kind() { Some(TokenKind::LeftParenthese) => { depth += 1; - p.bump()?; + tokens.push_back(p.bump()?); } Some(TokenKind::RightParenthese) => { depth -= 1; - p.bump()?; + tokens.push_back(p.bump()?); } _ => { - p.bump()?; + tokens.push_back(p.bump()?); } } } // Optional trailing block { ... } (e.g., MACRO_NAME(suite, name) { body }) if p.peek_kind() == Some(TokenKind::LeftBrace) { - p.bump()?; + tokens.push_back(p.bump()?); let mut brace_depth = 1u32; while brace_depth > 0 && !p.is_empty() { match p.peek_kind() { Some(TokenKind::LeftBrace) => { brace_depth += 1; - p.bump()?; + tokens.push_back(p.bump()?); } Some(TokenKind::RightBrace) => { brace_depth -= 1; - p.bump()?; + tokens.push_back(p.bump()?); } _ => { - p.bump()?; + tokens.push_back(p.bump()?); } } } - } else { - p.eat(TokenKind::Semicolon); // optional trailing semicolon + } else if let Some(semi) = p.eat(TokenKind::Semicolon) { + tokens.push_back(semi); } - return Ok(Item::Verbatim(ItemVerbatim { tokens: Vec::new() })); + return Ok(Item::Verbatim(ItemVerbatim { tokens })); } // Parse name — could be a qualified path (e.g., Foo::bar, A::B::method) // or an operator overload (operator+), or a destructor (~Foo handled elsewhere) - let ident = if p.peek_kind() == Some(TokenKind::KeywordOperator) { + let (class_path, ident) = if p.peek_kind() == Some(TokenKind::KeywordOperator) { // operator overload: use span from 'operator' keyword through operator token let op_tok = p.bump()?; // Parse the operator symbol(s) @@ -1291,10 +1339,13 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> let start_r: core::ops::Range = op_tok.src_span().into(); let end_r: core::ops::Range = end_span.into(); let span = SourceSpan::new(p.src, start_r.start, end_r.end - start_r.start); - Ident { - sym: &p.src[start_r.start..end_r.end], - span, - } + ( + None, + Ident { + sym: &p.src[start_r.start..end_r.end], + span, + }, + ) } else { // Regular ident, possibly qualified: Foo::bar let first = parse_ident(p)?; @@ -1304,7 +1355,8 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> Some(TokenKind::Ident | TokenKind::Compl | TokenKind::KeywordOperator) ) { - // Qualified name: consume all :: segments + // Qualified name: consume all :: segments, tracking qualifier + let mut qualifier_segments = vec![PathSegment { ident: first }]; let mut last = first; while p.eat(TokenKind::DoubleColon).is_some() { if p.peek_kind() == Some(TokenKind::Compl) { @@ -1366,13 +1418,24 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> break; } else if p.peek_kind() == Some(TokenKind::Ident) { last = parse_ident(p)?; + qualifier_segments.push(PathSegment { ident: last }); } else { break; } } - last + // The last segment in qualifier_segments is actually the function name, remove it + qualifier_segments.pop(); + let class_path = if qualifier_segments.is_empty() { + None + } else { + Some(Path { + leading_colon: false, + segments: qualifier_segments, + }) + }; + (class_path, last) } else { - first + (None, first) } }; @@ -1442,13 +1505,33 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> } } - // Member initializer list for constructors: : member(arg), ... - if p.eat(TokenKind::Colon).is_some() { - while p.peek_kind() != Some(TokenKind::LeftBrace) - && p.peek_kind() != Some(TokenKind::Semicolon) - && !p.is_empty() - { - p.bump()?; + // Member initializer list for constructors only: : member(arg), ... + let mut member_init_list = Vec::new(); + let is_constructor = class_path.as_ref().is_some_and(|cp| { + cp.segments + .last() + .is_some_and(|seg| seg.ident.sym == ident.sym) + }); + if is_constructor && p.eat(TokenKind::Colon).is_some() { + loop { + skip_macro_annotations(p)?; + let member = parse_path(p)?; + p.expect(TokenKind::LeftParenthese)?; + let mut args = Punctuated::new(); + while p.peek_kind() != Some(TokenKind::RightParenthese) && !p.is_empty() { + let arg = parse_expr_no_comma(p)?; + if let Some(comma) = p.eat(TokenKind::Comma) { + args.push_pair(arg, comma); + } else { + args.push_value(arg); + break; + } + } + p.expect(TokenKind::RightParenthese)?; + member_init_list.push(MemberInit { member, args }); + if p.eat(TokenKind::Comma).is_none() { + break; + } } } @@ -1460,6 +1543,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> static_token, explicit_token, return_type, + class_path, ident, inputs, variadic: false, @@ -1470,6 +1554,7 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> pure_virtual, defaulted, deleted, + member_init_list, }; // Skip macro annotations between signature and body (e.g., GTEST_LOCK_EXCLUDED_(mutex_)) @@ -1583,8 +1668,8 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> })), })) } else { - // Regular variable — treat as static for now at file scope - Ok(Item::Static(ItemStatic { + // Regular variable declaration (not static, not const) + Ok(Item::Var(ItemVar { attrs: Vec::new(), ty: return_type, ident, @@ -1594,7 +1679,13 @@ fn parse_item_fn_or_var<'de>(p: &mut Parser<'de>) -> Result, AstError> } fn is_const_type(ty: &Type) -> bool { - matches!(ty, Type::Qualified(q) if q.cv.const_token) + match ty { + Type::Qualified(q) => q.cv.const_token, + Type::Ptr(p) => is_const_type(&p.pointee), + Type::Reference(r) => is_const_type(&r.referent), + Type::Array(a) => is_const_type(&a.element), + _ => false, + } } // --------------------------------------------------------------------------- @@ -1823,6 +1914,7 @@ fn parse_fields_named<'de>( ) -> Result, AstError> { p.expect(TokenKind::LeftBrace)?; let mut members = Vec::new(); + let mut current_vis = Visibility::Inherited; while p.peek_kind() != Some(TokenKind::RightBrace) && !p.is_empty() { // Skip attributes inside class body @@ -1935,7 +2027,8 @@ fn parse_fields_named<'de>( if p.eat(TokenKind::Colon).is_some() { loop { - let member = parse_ident(p)?; + skip_macro_annotations(p)?; + let member = parse_path(p)?; p.expect(TokenKind::LeftParenthese)?; let mut args = Punctuated::new(); while p.peek_kind() != Some(TokenKind::RightParenthese) && !p.is_empty() { @@ -1997,18 +2090,21 @@ fn parse_fields_named<'de>( Some(TokenKind::KeywordPublic) => { p.bump()?; p.expect(TokenKind::Colon)?; + current_vis = Visibility::Public; members.push(Member::AccessSpecifier(Visibility::Public)); continue; } Some(TokenKind::KeywordProtected) => { p.bump()?; p.expect(TokenKind::Colon)?; + current_vis = Visibility::Protected; members.push(Member::AccessSpecifier(Visibility::Protected)); continue; } Some(TokenKind::KeywordPrivate) => { p.bump()?; p.expect(TokenKind::Colon)?; + current_vis = Visibility::Private; members.push(Member::AccessSpecifier(Visibility::Private)); continue; } @@ -2047,7 +2143,8 @@ fn parse_fields_named<'de>( Item::Static(s) => { members.push(Member::Field(Field { attrs: Vec::new(), - vis: Visibility::Inherited, + vis: current_vis, + static_token: true, ty: s.ty, ident: Some(s.ident), default_value: s.expr, @@ -2056,12 +2153,23 @@ fn parse_fields_named<'de>( Item::Const(c) => { members.push(Member::Field(Field { attrs: Vec::new(), - vis: Visibility::Inherited, + vis: current_vis, + static_token: false, ty: c.ty, ident: Some(c.ident), default_value: Some(c.expr), })); } + Item::Var(v) => { + members.push(Member::Field(Field { + attrs: Vec::new(), + vis: current_vis, + static_token: false, + ty: v.ty, + ident: Some(v.ident), + default_value: v.expr, + })); + } other => members.push(Member::Item(Box::new(other))), } } @@ -2929,6 +3037,55 @@ fn parse_type_suffix<'de>(p: &mut Parser<'de>, mut ty: Type<'de>) -> Result(p: &mut Parser<'de>) -> Result<(), AstError> { + loop { + match p.peek_kind() { + Some(TokenKind::NumberSign) => { + parse_item_macro(p)?; + } + Some(TokenKind::Ident) => { + let src = p.peek().unwrap().src(); + let is_macro_like = src.len() > 1 + && src.contains('_') + && src + .chars() + .all(|c| c.is_ascii_uppercase() || c == '_' || c.is_ascii_digit()); + if !is_macro_like { + break; + } + p.bump()?; + if p.peek_kind() == Some(TokenKind::LeftParenthese) { + p.bump()?; + let mut depth = 1u32; + while depth > 0 && !p.is_empty() { + match p.peek_kind() { + Some(TokenKind::LeftParenthese) => { + depth += 1; + p.bump()?; + } + Some(TokenKind::RightParenthese) => { + depth -= 1; + p.bump()?; + } + _ => { + p.bump()?; + } + } + } + } + } + _ => break, + } + } + Ok(()) +} + // --------------------------------------------------------------------------- // Path: ident (:: ident)* or :: ident (:: ident)* // --------------------------------------------------------------------------- @@ -4011,6 +4168,41 @@ mod tests { } } + #[test] + fn parse_typedef_array() { + let file = parse("typedef char type24[3];"); + match &file.items[0] { + Item::Typedef(td) => { + assert_eq!(td.ident.sym, "type24"); + match &td.ty { + Type::Array(arr) => { + assert!(matches!(arr.element.as_ref(), Type::Fundamental(_))); + assert!(arr.size.is_some()); + } + other => panic!("expected Array type, got {other:?}"), + } + } + other => panic!("expected Typedef, got {other:?}"), + } + } + + #[test] + fn parse_typedef_array_2d() { + let file = parse("typedef int matrix[3][4];"); + match &file.items[0] { + Item::Typedef(td) => { + assert_eq!(td.ident.sym, "matrix"); + match &td.ty { + Type::Array(outer) => { + assert!(matches!(outer.element.as_ref(), Type::Array(_))); + } + other => panic!("expected Array type, got {other:?}"), + } + } + other => panic!("expected Typedef, got {other:?}"), + } + } + #[test] fn parse_enum_test() { let file = parse("enum Color { Red, Green, Blue };"); @@ -4550,13 +4742,12 @@ mod tests { #[test] fn parse_string_concat() { let file = parse("const char* s = \"hello\" \" \" \"world\";"); - // const char* is a pointer to const char, parsed as Static (not Const) match &file.items[0] { - Item::Static(s) => match s.expr.as_ref().unwrap() { + Item::Const(c) => match &c.expr { Expr::Lit(lit) => assert_eq!(lit.kind, LitKind::String), other => panic!("expected Lit, got {other:?}"), }, - other => panic!("expected Static, got {other:?}"), + other => panic!("expected Const, got {other:?}"), } } @@ -4648,7 +4839,7 @@ mod tests { match &f.members[0] { Member::Constructor(ctor) => { assert_eq!(ctor.member_init_list.len(), 1); - assert_eq!(ctor.member_init_list[0].member.sym, "m_x"); + assert_eq!(ctor.member_init_list[0].member.to_string(), "m_x"); } other => panic!("expected Constructor, got {other:?}"), } diff --git a/src/ast/punct.rs b/src/ast/punct.rs index 68935ea..93a9648 100644 --- a/src/ast/punct.rs +++ b/src/ast/punct.rs @@ -12,7 +12,7 @@ use crate::lex::Token; /// like function arguments: `a, b, c`. #[derive(Debug, Clone, PartialEq)] pub struct Punctuated<'de, T> { - inner: Vec<(T, Option>)>, + pub(crate) inner: Vec<(T, Option>)>, } impl<'de, T> Punctuated<'de, T> { diff --git a/src/ast/ty.rs b/src/ast/ty.rs index 12b03c2..f91d1c5 100644 --- a/src/ast/ty.rs +++ b/src/ast/ty.rs @@ -9,7 +9,7 @@ use super::item::Path; use super::punct::Punctuated; /// The kind of a fundamental (built-in) type. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum FundamentalKind { Void, Bool, @@ -67,6 +67,17 @@ pub enum Type<'de> { Qualified(TypeQualified<'de>), } +impl<'de> Type<'de> { + /// Check if the type is `auto`. + pub fn is_auto(&self) -> bool { + match self { + Type::Auto(_) => true, + Type::Qualified(q) => q.ty.is_auto(), + _ => false, + } + } +} + /// A fundamental (built-in) type. #[derive(Debug, Clone, Copy, PartialEq)] pub struct TypeFundamental<'de> { diff --git a/src/ast/visit.rs b/src/ast/visit.rs index 0faa197..d7e0587 100644 --- a/src/ast/visit.rs +++ b/src/ast/visit.rs @@ -106,6 +106,12 @@ pub fn visit_item<'de, V: Visit<'de> + ?Sized>(v: &mut V, item: &Item<'de>) { v.visit_expr(expr); } } + Item::Var(i) => { + v.visit_type(&i.ty); + if let Some(expr) = &i.expr { + v.visit_expr(expr); + } + } Item::ForeignMod(i) => { for fi in &i.items { match fi { @@ -200,6 +206,9 @@ pub fn visit_member<'de, V: Visit<'de> + ?Sized>(v: &mut V, member: &Member<'de> pub fn visit_signature<'de, V: Visit<'de> + ?Sized>(v: &mut V, sig: &Signature<'de>) { v.visit_type(&sig.return_type); + if let Some(class_path) = &sig.class_path { + v.visit_path(class_path); + } v.visit_ident(&sig.ident); for arg in sig.inputs.iter() { v.visit_type(&arg.ty); @@ -207,6 +216,12 @@ pub fn visit_signature<'de, V: Visit<'de> + ?Sized>(v: &mut V, sig: &Signature<' v.visit_ident(ident); } } + for init in &sig.member_init_list { + v.visit_path(&init.member); + for expr in init.args.iter() { + v.visit_expr(expr); + } + } } pub fn visit_block<'de, V: Visit<'de> + ?Sized>(v: &mut V, block: &Block<'de>) { diff --git a/src/lib.rs b/src/lib.rs index 85b30ef..db147d5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ +#[cfg(feature = "ast")] pub mod ast; pub mod lex; +#[cfg(feature = "transpiler")] +pub mod transpile; use std::fmt; pub use lex::Lexer; @@ -32,6 +35,11 @@ impl<'de> SourceSpan<'de> { pub fn src(&self) -> &'de str { &self.src[core::ops::Range::from(*self)] } + + /// Returns the full backing source string (not just this span's slice). + pub fn full_source(&self) -> &'de str { + self.src + } } impl<'de> fmt::Debug for SourceSpan<'de> { diff --git a/src/transpile/error.rs b/src/transpile/error.rs new file mode 100644 index 0000000..6e0690c --- /dev/null +++ b/src/transpile/error.rs @@ -0,0 +1,39 @@ +//! Transpilation error types with rich diagnostics via miette + +use miette::Diagnostic; +use thiserror::Error; + +/// Errors that can occur during C++ → Rust type mapping +#[derive(Diagnostic, Debug, Error)] +pub enum TranspileError { + /// C++ path has no registered mapping + #[error("No mapping registered for C++ path `{path}`")] + UnmappedPath { + path: String, + #[source_code] + src: String, + #[label = "no mapping for this path"] + err_span: miette::SourceSpan, + }, + /// C++ type variant cannot be mapped to Rust + #[error("{message}")] + UnsupportedType { + message: String, + #[source_code] + src: String, + #[label = "{message}"] + err_span: miette::SourceSpan, + }, + /// C++ expression cannot be transpiled to Rust + #[error("{message}")] + UnsupportedExpr { + message: String, + #[source_code] + src: String, + #[label = "{message}"] + err_span: miette::SourceSpan, + }, + /// Invalid Rust type syntax provided to the builder + #[error("Invalid Rust type syntax `{rust_type}`: {reason}")] + InvalidRustType { rust_type: String, reason: String }, +} diff --git a/src/transpile/expr.rs b/src/transpile/expr.rs new file mode 100644 index 0000000..b0cc3f5 --- /dev/null +++ b/src/transpile/expr.rs @@ -0,0 +1,90 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; + +use crate::ast::expr::{Expr, ExprBool, ExprNullptr, ExprParen, ExprUnary, UnaryOp}; +use crate::transpile::{Transpile, Transpiler}; + +use super::error::TranspileError; + +/// Extract a source span from an expression (best-effort). +pub(crate) fn expr_span<'de>(expr: &Expr<'de>) -> Option> { + match expr { + Expr::Lit(l) => Some(l.span), + Expr::Ident(i) => Some(i.ident.span), + Expr::Path(p) => p.path.segments.first().map(|s| s.ident.span), + _ => None, + } +} + +/// Build a [`TranspileError::UnsupportedExpr`] from an expression. +pub(crate) fn unsupported_from_expr(message: &str, expr: &Expr<'_>) -> TranspileError { + match expr_span(expr) { + Some(span) => TranspileError::UnsupportedExpr { + message: message.to_owned(), + src: span.full_source().to_owned(), + err_span: span.into(), + }, + None => TranspileError::UnsupportedExpr { + message: message.to_owned(), + src: String::new(), + err_span: miette::SourceSpan::new(0.into(), 0), + }, + } +} + +impl<'de> Transpile for Expr<'de> { + #[allow(clippy::only_used_in_recursion)] + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + match self { + Expr::Lit(lit) => { + let rust_expr: syn::Expr = syn::parse_str(lit.span.src()).map_err(|_| { + TranspileError::UnsupportedExpr { + message: "cannot parse literal".to_owned(), + src: lit.span.full_source().to_owned(), + err_span: lit.span.into(), + } + })?; + tokens.extend(quote::quote!(#rust_expr)); + } + Expr::Bool(ExprBool { value, .. }) => { + tokens.extend(quote::quote!(#value)); + } + Expr::Nullptr(ExprNullptr { .. }) => { + tokens.extend(quote::quote!(std::ptr::null())); + } + Expr::Ident(i) => { + i.ident.to_tokens(tokens); + } + Expr::Path(p) => { + let rust_expr = syn::Expr::try_from(p.path.clone()) + .map_err(|_| unsupported_from_expr("cannot transpile path expression", self))?; + tokens.extend(quote::quote!(#rust_expr)); + } + Expr::Unary(ExprUnary { + op: UnaryOp::Negate, + operand, + }) => { + let mut inner_tokens = TokenStream::new(); + operand.transpile(transpiler, &mut inner_tokens)?; + tokens.extend(quote::quote!(- #inner_tokens)); + } + Expr::Paren(ExprParen { expr }) => { + let mut inner_tokens = TokenStream::new(); + expr.transpile(transpiler, &mut inner_tokens)?; + tokens.extend(quote::quote!((#inner_tokens))); + } + other => { + return Err(unsupported_from_expr( + "expression cannot be transpiled to Rust", + other, + )); + } + } + + Ok(()) + } +} diff --git a/src/transpile/item.rs b/src/transpile/item.rs new file mode 100644 index 0000000..1ae8804 --- /dev/null +++ b/src/transpile/item.rs @@ -0,0 +1,258 @@ +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::parse_str; + +use crate::{ + ast::{Field, Ident, ItemEnum, Path, Visibility}, + transpile::{Transpile, TranspileError, Transpiler}, +}; + +impl<'de> From<&Ident<'de>> for syn::Ident { + fn from(ident: &Ident<'de>) -> Self { + syn::Ident::new(ident.sym, proc_macro2::Span::call_site()) + } +} + +impl<'de> From> for syn::Ident { + fn from(ident: Ident<'de>) -> Self { + Self::from(&ident) + } +} + +impl<'de> ToTokens for Ident<'de> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident: syn::Ident = self.into(); + ident.to_tokens(tokens); + } +} + +impl From<&Visibility> for syn::Visibility { + fn from(visibility: &Visibility) -> Self { + match visibility { + Visibility::Public => syn::parse_str("pub").expect("Failed to parse pub visibility"), + Visibility::Protected => { + syn::parse_str("pub(crate)").expect("Failed to parse pub(crate) visibility") + } + Visibility::Inherited | Visibility::Private => syn::Visibility::Inherited, + } + } +} + +impl From for syn::Visibility { + fn from(visibility: Visibility) -> Self { + Self::from(&visibility) + } +} + +impl ToTokens for Visibility { + fn to_tokens(&self, tokens: &mut TokenStream) { + let visibility: syn::Visibility = self.into(); + visibility.to_tokens(tokens); + } +} + +macro_rules! impl_try_from_path { + ($($target:ty),* $(,)?) => { + $( + impl<'de> TryFrom> for $target { + type Error = syn::Error; + + fn try_from(path: Path<'de>) -> Result { + parse_str(&path.to_string()) + } + } + )* + }; +} +impl_try_from_path!(syn::Type, syn::Path, syn::Expr); + +impl<'de> Transpile for Field<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + if let Some(ident) = self.ident { + self.vis.to_tokens(tokens); + ident.to_tokens(tokens); + tokens.extend(quote::quote! { : }); + self.ty.transpile(transpiler, tokens)?; + } + + Ok(()) + } +} + +impl<'de> Transpile for ItemEnum<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + let name: syn::Ident = self + .ident + .as_ref() + .ok_or_else(|| TranspileError::UnsupportedType { + message: "anonymous enums cannot be transpiled".to_owned(), + src: String::new(), + err_span: miette::SourceSpan::new(0.into(), 0), + })? + .into(); + + // Build #[repr(...)] if an underlying type is specified + let repr_attr = match &self.underlying_type { + Some(ty) => { + let rust_ty = transpiler.ty_mapper.map_type(ty)?; + Some(quote::quote! { #[repr(#rust_ty)] }) + } + None => None, + }; + + // Build variant tokens + let mut variant_tokens = TokenStream::new(); + for variant in self.variants.iter() { + let v_name: syn::Ident = (&variant.ident).into(); + if let Some(ref disc) = variant.discriminant { + let mut expr_tokens = TokenStream::new(); + disc.transpile(transpiler, &mut expr_tokens)?; + variant_tokens.extend(quote::quote! { #v_name = #expr_tokens, }); + } else { + variant_tokens.extend(quote::quote! { #v_name, }); + } + } + + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled enum for ", stringify!(#name))] + #repr_attr + pub enum #name { #variant_tokens } + }); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use core::panic; + + use super::*; + use crate::ast::{self, parse_file}; + + // ---- ItemEnum transpilation ---- + + #[test] + fn enum_class_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "enum class Color { Red, Green, Blue };"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Enum(e) => { + assert_eq!( + e.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled enum for \" , stringify ! (Color))] pub enum Color { Red , Green , Blue , }" + ); + } + item => panic!("expected ItemEnum, got {item:?}"), + } + + Ok(()) + } + + #[test] + fn enum_with_underlying_type_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "enum Color : int { Red, Green, Blue };"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Enum(e) => { + assert_eq!( + e.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled enum for \" , stringify ! (Color))] # [repr (i32)] pub enum Color { Red , Green , Blue , }" + ); + } + item => panic!("expected ItemEnum, got {item:?}"), + } + + Ok(()) + } + + #[test] + fn enum_with_discriminants_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "enum class Color : unsigned char { A = 1, B = 2 };"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Enum(e) => { + assert_eq!( + e.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled enum for \" , stringify ! (Color))] # [repr (u8)] pub enum Color { A = 1 , B = 2 , }" + ); + } + item => panic!("expected ItemEnum, got {item:?}"), + } + + Ok(()) + } + + #[test] + fn class_member_variable_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = + "class Color { public: int R; private: unsigned long long G; protected: short B; };"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + ast::Item::Class(ast::ItemClass { + fields: ast::Fields::Named(named_fields), + .. + }) => { + if let ast::Member::Field(field) = &named_fields.members[1] { + assert_eq!( + field + .transpile_token_stream(&transpiler) + .expect("Failed to transpile field[0]") + .to_string(), + "pub R : i32", + ); + } else { + panic!( + "expected field member[0], got {:?}", + &named_fields.members[0] + ); + } + + if let ast::Member::Field(field) = &named_fields.members[3] { + assert_eq!( + field + .transpile_token_stream(&transpiler) + .expect("Failed to transpile field[3]") + .to_string(), + "G : u64", + ); + } else { + panic!( + "expected field member[3], got {:?}", + &named_fields.members[3] + ); + } + + if let ast::Member::Field(field) = &named_fields.members[5] { + assert_eq!( + field + .transpile_token_stream(&transpiler) + .expect("Failed to transpile field[5]") + .to_string(), + "pub (crate) B : i16", + ); + } else { + panic!( + "expected field member[5], got {:?}", + &named_fields.members[5] + ); + } + } + item => panic!("expected ItemClass, got {item:?}"), + } + + Ok(()) + } +} diff --git a/src/transpile/mod.rs b/src/transpile/mod.rs new file mode 100644 index 0000000..9a10c7a --- /dev/null +++ b/src/transpile/mod.rs @@ -0,0 +1,59 @@ +//! Transpiler module to convert C++ ([`crate::ast`]) into Rust ([`syn`]) + +pub mod error; +pub mod expr; +pub mod item; +pub mod ty; + +use std::collections::HashSet; + +pub use error::TranspileError; +use proc_macro2::TokenStream; +use serde::Deserialize; +pub use ty::*; + +/// Transpiler struct, which is the configuration entrypoint for all transpilation operations. +#[derive(Debug, Default, Clone, Deserialize)] +pub struct Transpiler { + /// List of type we don't want to transpile + #[serde(default)] + pub skip_types: HashSet, + /// Type mapper to map C++ types to Rust types + #[serde(default)] + pub ty_mapper: TypeMapper, +} + +pub trait Transpile { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError>; + + /// Convert `self` with a `Transpiler` configuration into a `TokenStream` object. + /// + /// This method is implicitly implemented using `transpile`, and acts as a + /// convenience method for consumers of the `Transpile` trait. + fn transpile_token_stream( + &self, + transpiler: &Transpiler, + ) -> Result { + let mut tokens = TokenStream::new(); + self.transpile(transpiler, &mut tokens)?; + Ok(tokens) + } + + /// Convert `self` with a `Transpiler` configuration into a `TokenStream` object. + /// + /// This method is implicitly implemented using `transpile`, and acts as a + /// convenience method for consumers of the `Transpile` trait. + fn transpile_into_token_stream( + self, + transpiler: &Transpiler, + ) -> Result + where + Self: Sized, + { + self.transpile_token_stream(transpiler) + } +} diff --git a/src/transpile/ty.rs b/src/transpile/ty.rs new file mode 100644 index 0000000..8d638a4 --- /dev/null +++ b/src/transpile/ty.rs @@ -0,0 +1,1132 @@ +use std::collections::HashMap; + +use proc_macro2::TokenStream; +use quote::ToTokens as _; +use serde::Deserialize; +use serde::de::{self, MapAccess, Visitor}; + +use crate::ast::ItemTypedef; +use crate::ast::expr::{Expr, ExprLit, LitKind}; +use crate::ast::item::{ItemConst, ItemStatic, Path}; +use crate::ast::ty::{FundamentalKind, TemplateArg, Type}; +use crate::transpile::expr::expr_span; +use crate::transpile::{Transpile, Transpiler}; + +use super::error::TranspileError; + +impl From for syn::Type { + fn from(kind: FundamentalKind) -> Self { + use FundamentalKind::*; + let s = match kind { + Void => "()", + Bool => "bool", + Char | Char8 | Char16 | Char32 | Wchar => "char", + UnsignedChar => "u8", + UnsignedShort => "u16", + UnsignedInt => "u32", + Short => "i16", + Int => "i32", + Long | LongLong => "i64", + Float => "f32", + Double | LongDouble => "f64", + SignedChar => "i8", + UnsignedLong | UnsignedLongLong => "u64", + }; + syn::parse_str(s).unwrap() + } +} + +/// Configurable mapper from C++ AST types to `syn::Type`. +/// +/// Built via [`TypeMapper::builder()`] or [`TypeMapper::new()`] (defaults only). +/// +/// ``` +/// use cppshift::transpile::TypeMapper; +/// use cppshift::ast::ty::{Type, TypeFundamental, FundamentalKind}; +/// use cppshift::SourceSpan; +/// +/// let mapper = TypeMapper::new(); +/// let src = "int"; +/// let ty = Type::Fundamental(TypeFundamental { +/// span: SourceSpan::new(src, 0, 3), +/// kind: FundamentalKind::Int, +/// }); +/// let rust_ty = mapper.map_type(&ty).expect("fundamental types always map"); +/// assert_eq!(quote::quote!(#rust_ty).to_string(), "i32"); +/// ``` +#[derive(Debug, Clone)] +pub struct TypeMapper { + paths: HashMap, +} + +/// Builder for [`TypeMapper`]. +pub struct TypeMapperBuilder { + paths: HashMap, +} + +impl TypeMapper { + /// Create a builder for configuring type mappings. + pub fn builder() -> TypeMapperBuilder { + TypeMapperBuilder { + paths: HashMap::new(), + } + } + + /// Create a mapper with default fundamental type mappings only. + pub fn new() -> Self { + Self::builder().build() + } + + /// Map a C++ AST type to a `syn::Type`. + /// + /// # Errors + /// + /// Returns [`TranspileError`] if the type cannot be mapped (e.g. unknown path, + /// `auto`, `decltype`, unsized array). + pub fn map_type(&self, ty: &Type<'_>) -> Result { + match ty { + Type::Fundamental(f) => Ok(syn::Type::from(f.kind)), + Type::Path(p) => self.resolve_path(&p.path), + Type::Ptr(p) => { + let inner = self.map_type(&p.pointee)?; + if p.cv.const_token { + Ok(syn::parse_quote!(*const #inner)) + } else { + Ok(syn::parse_quote!(*mut #inner)) + } + } + Type::Reference(r) => { + let inner = self.map_type(&r.referent)?; + if r.cv.const_token { + Ok(syn::parse_quote!(&#inner)) + } else { + Ok(syn::parse_quote!(&mut #inner)) + } + } + Type::RvalueReference(r) => self.map_type(&r.referent), + Type::Array(a) => { + let inner = self.map_type(&a.element)?; + match &a.size { + Some(Expr::Lit(lit)) if lit.kind == LitKind::Integer => { + let n: usize = + lit.span.src().parse().map_err(|_| { + unsupported_from_type("invalid array size literal", ty) + })?; + let lit_n = + syn::LitInt::new(&n.to_string(), proc_macro2::Span::call_site()); + Ok(syn::parse_quote!([#inner; #lit_n])) + } + _ => Err(unsupported_from_type("unsized or dynamic array", ty)), + } + } + Type::FnPtr(f) => { + let ret = self.map_type(&f.return_type)?; + let params: Result, _> = + f.params.iter().map(|p| self.map_type(p)).collect(); + let params = params?; + Ok(syn::parse_quote!(fn(#(#params),*) -> #ret)) + } + Type::Qualified(q) => self.map_type(&q.ty), + Type::TemplateInst(t) => { + let base_ty = self.resolve_path(&t.path)?; + let mapped_args: Result, _> = t + .args + .iter() + .map(|arg| match arg { + TemplateArg::Type(ty) => self.map_type(ty), + TemplateArg::Expr(_) => { + Err(unsupported_from_type("expression template argument", ty)) + } + }) + .collect(); + let mapped_args = mapped_args?; + + let mut result = base_ty; + if let syn::Type::Path(ref mut type_path) = result + && let Some(last_seg) = type_path.path.segments.last_mut() + { + last_seg.arguments = + syn::PathArguments::AngleBracketed(syn::AngleBracketedGenericArguments { + colon2_token: None, + lt_token: syn::token::Lt::default(), + args: mapped_args + .into_iter() + .map(syn::GenericArgument::Type) + .collect(), + gt_token: syn::token::Gt::default(), + }); + } + Ok(result) + } + Type::Auto(_) => Err(unsupported_from_type( + "auto type cannot be mapped to Rust", + ty, + )), + Type::Decltype(_) => Err(unsupported_from_type( + "decltype cannot be mapped to Rust", + ty, + )), + } + } + + fn resolve_path(&self, path: &Path<'_>) -> Result { + let key = path.to_string(); + if let Some(ty) = self.paths.get(&key) { + Ok(ty.clone()) + } else { + syn::Type::try_from(path.clone()).map_err(|_| unmapped_path_error(&key, path)) + } + } +} + +impl Default for TypeMapper { + fn default() -> Self { + Self::new() + } +} + +impl<'de> Deserialize<'de> for TypeMapper { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct TypeMapperVisitor; + + impl<'de> Visitor<'de> for TypeMapperVisitor { + type Value = TypeMapper; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("a map of C++ type paths to Rust type strings") + } + + fn visit_map(self, mut access: M) -> Result + where + M: MapAccess<'de>, + { + let mut builder = TypeMapper::builder(); + while let Some((cpp_path, rust_type)) = access.next_entry::()? { + builder = builder + .map_path(&cpp_path, &rust_type) + .map_err(de::Error::custom)?; + } + Ok(builder.build()) + } + } + + deserializer.deserialize_map(TypeMapperVisitor) + } +} + +impl TypeMapperBuilder { + /// Register a C++ path → Rust type mapping. + /// + /// The `rust_type` string is parsed via `syn::parse_str`. + /// + /// # Errors + /// + /// Returns [`TranspileError::InvalidRustType`] if `rust_type` is not valid Rust syntax. + pub fn map_path(mut self, cpp_path: &str, rust_type: &str) -> Result { + let ty: syn::Type = + syn::parse_str(rust_type).map_err(|e| TranspileError::InvalidRustType { + rust_type: rust_type.to_owned(), + reason: e.to_string(), + })?; + self.paths.insert(cpp_path.to_owned(), ty); + Ok(self) + } + + /// Register a C++ path → Rust type mapping with a pre-built `syn::Type`. + pub fn map_path_to_type(mut self, cpp_path: &str, ty: syn::Type) -> Self { + self.paths.insert(cpp_path.to_owned(), ty); + self + } + + /// Build the [`TypeMapper`]. + pub fn build(self) -> TypeMapper { + TypeMapper { paths: self.paths } + } +} + +/// Build an [`TranspileError::UnmappedPath`] from a path string and AST path. +fn unmapped_path_error(path_str: &str, path: &Path<'_>) -> TranspileError { + let span = path.segments.first().map(|s| s.ident.span); + match span { + Some(span) => TranspileError::UnmappedPath { + path: path_str.to_owned(), + src: span.full_source().to_owned(), + err_span: span.into(), + }, + None => TranspileError::UnmappedPath { + path: path_str.to_owned(), + src: String::new(), + err_span: miette::SourceSpan::new(0.into(), 0), + }, + } +} + +/// Build an [`TranspileError::UnsupportedType`] by extracting the best span from a [`Type`]. +fn unsupported_from_type(message: &str, ty: &Type<'_>) -> TranspileError { + match type_span(ty) { + Some(span) => TranspileError::UnsupportedType { + message: message.to_owned(), + src: span.full_source().to_owned(), + err_span: span.into(), + }, + None => TranspileError::UnsupportedType { + message: message.to_owned(), + src: String::new(), + err_span: miette::SourceSpan::new(0.into(), 0), + }, + } +} + +/// Extract the most relevant source span from a type, if available. +fn type_span<'de>(ty: &Type<'de>) -> Option> { + match ty { + Type::Fundamental(f) => Some(f.span), + Type::Path(p) => p.path.segments.first().map(|s| s.ident.span), + Type::Auto(a) => Some(a.span), + Type::Decltype(d) => expr_span(&d.expr), + Type::Ptr(p) => type_span(&p.pointee), + Type::Reference(r) => type_span(&r.referent), + Type::RvalueReference(r) => type_span(&r.referent), + Type::Array(a) => type_span(&a.element), + Type::FnPtr(f) => type_span(&f.return_type), + Type::Qualified(q) => type_span(&q.ty), + Type::TemplateInst(t) => t.path.segments.first().map(|s| s.ident.span), + } +} + +impl<'de> Transpile for Type<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + transpiler.ty_mapper.map_type(self)?.to_tokens(tokens); + Ok(()) + } +} + +impl<'de> Transpile for ItemTypedef<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + let name = self.ident; + // char array typedefs → &str + if let Type::Array(arr) = &self.ty + && is_char_element_type(&arr.element) + { + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled type for ", stringify!(#name))] + pub type #name = &str; + }); + } else if let Type::Path(p) = &self.ty { + let rust_ty = transpiler.ty_mapper.resolve_path(&p.path)?; + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled type for ", stringify!(#name))] + pub type #name = #rust_ty; + }); + } else { + let rust_ty = transpiler.ty_mapper.map_type(&self.ty)?; + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled type for ", stringify!(#name))] + pub type #name = #rust_ty; + }); + } + + Ok(()) + } +} + +/// Returns `true` if `ty` is a byte-sized char type (char, char8_t, unsigned char, signed char), +/// stripping CV-qualifiers. +fn is_char_element_type(ty: &Type<'_>) -> bool { + use FundamentalKind::*; + match ty { + Type::Fundamental(f) => matches!(f.kind, Char | Char8 | Char16 | Char32 | Wchar), + Type::Qualified(q) => is_char_element_type(&q.ty), + _ => false, + } +} + +/// Try to transpile an unsized char array initialised with a string literal. +/// +/// C++: `const char foo[] = "ALPN";` / `static char foo[] = "ALPN";` +/// Rust: `pub static foo: &str = "ALPN";` +/// +/// `keyword` is the Rust storage keyword to emit (`const` or `static`). +/// Returns `None` if the pattern doesn't match and normal mapping should proceed. +fn try_transpile_char_array_from_str_lit<'de>( + name: crate::ast::item::Ident<'de>, + element: &Type<'de>, + expr: &Expr<'de>, + transpiler: &Transpiler, + keyword: &str, + tokens: &mut TokenStream, +) -> Option> { + if !is_char_element_type(element) { + return None; + } + let Expr::Lit(ExprLit { + kind: LitKind::String, + .. + }) = expr + else { + return None; + }; + + let mut expr_tokens = TokenStream::new(); + if let Err(e) = expr.transpile(transpiler, &mut expr_tokens) { + return Some(Err(e)); + } + + let keyword_tok: proc_macro2::TokenStream = keyword.parse().unwrap(); + tokens.extend(quote::quote! { + pub #keyword_tok #name: &str = #expr_tokens; + }); + Some(Ok(())) +} + +impl<'de> Transpile for ItemConst<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + let name = self.ident; + + // Special case: `const char foo[] = "ALPN";` → `pub const foo: [u8; 5] = *b"ALPN\0";` + if let Type::Array(arr) = &self.ty + && arr.size.is_none() + && let Some(result) = try_transpile_char_array_from_str_lit( + name, + &arr.element, + &self.expr, + transpiler, + "const", + tokens, + ) + { + return result; + } + + let mut expr_tokens = TokenStream::new(); + self.expr.transpile(transpiler, &mut expr_tokens)?; + + match &self.expr { + // C++ string constants with string literal init → `&str` + Expr::Lit(ExprLit { + kind: LitKind::String, + .. + }) => { + tokens.extend(quote::quote! { + #[doc = " Auto-transpiled &str const"] + pub const #name: &str = #expr_tokens; + }); + } + Expr::Lit(ExprLit { kind, .. }) => { + let rust_ty = transpiler.ty_mapper.map_type(&self.ty)?; + if kind.match_type(&self.ty) { + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled const literal ", stringify!(#rust_ty))] + pub const #name: #rust_ty = #expr_tokens; + }); + } else { + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled const literal ", stringify!(#rust_ty))] + pub const #name: #rust_ty = #expr_tokens as #rust_ty; + }); + } + } + _ => { + let rust_ty = transpiler.ty_mapper.map_type(&self.ty)?; + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled const ", stringify!(#rust_ty))] + pub const #name: #rust_ty = #expr_tokens; + }); + } + } + + Ok(()) + } +} + +impl<'de> Transpile for ItemStatic<'de> { + fn transpile( + &self, + transpiler: &Transpiler, + tokens: &mut TokenStream, + ) -> Result<(), TranspileError> { + let name = self.ident; + let expr = self + .expr + .as_ref() + .ok_or_else(|| TranspileError::UnsupportedExpr { + message: "Rust statics require an initializer".to_owned(), + src: name.span.full_source().to_owned(), + err_span: name.span.into(), + })?; + + // Special case: `static char foo[] = "ALPN";` → `pub static foo: [u8; 5] = *b"ALPN\0";` + if let Type::Array(arr) = &self.ty + && arr.size.is_none() + && let Some(result) = try_transpile_char_array_from_str_lit( + name, + &arr.element, + expr, + transpiler, + "static", + tokens, + ) + { + return result; + } + + let mut expr_tokens = TokenStream::new(); + expr.transpile(transpiler, &mut expr_tokens)?; + + match expr { + // C++ string statics with string literal init → `&str` + Expr::Lit(ExprLit { + kind: LitKind::String, + .. + }) => { + tokens.extend(quote::quote! { + #[doc = " Auto-transpiled &str static"] + pub static #name: &str = #expr_tokens; + }); + } + Expr::Lit(ExprLit { kind, .. }) => { + let rust_ty = transpiler.ty_mapper.map_type(&self.ty)?; + if kind.match_type(&self.ty) { + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled static literal ", stringify!(#rust_ty))] + pub static #name: #rust_ty = #expr_tokens; + }); + } else { + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled static literal ", stringify!(#rust_ty))] + pub static #name: #rust_ty = #expr_tokens as #rust_ty; + }); + } + } + _ => { + let rust_ty = transpiler.ty_mapper.map_type(&self.ty)?; + tokens.extend(quote::quote! { + #[doc = concat!(" Auto-transpiled static ", stringify!(#rust_ty))] + pub static #name: #rust_ty = #expr_tokens; + }); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quote::quote; + + use crate::SourceSpan; + use crate::ast::expr::ExprLit; + use crate::ast::item::{Ident, PathSegment}; + use crate::ast::punct::Punctuated; + use crate::ast::{parse_file, ty::*}; + + fn ty_str(ty: &syn::Type) -> String { + quote!(#ty).to_string() + } + + fn make_fundamental(src: &str, kind: FundamentalKind) -> Type<'_> { + Type::Fundamental(TypeFundamental { + span: SourceSpan::new(src, 0, src.len()), + kind, + }) + } + + fn make_path<'a>(src: &'a str, segments: &[&'a str]) -> Type<'a> { + Type::Path(TypePath { + path: make_raw_path(src, segments), + }) + } + + fn make_raw_path<'a>(src: &'a str, segments: &[&'a str]) -> Path<'a> { + Path { + leading_colon: false, + segments: segments + .iter() + .map(|s| { + let offset = s.as_ptr() as usize - src.as_ptr() as usize; + PathSegment { + ident: Ident { + sym: s, + span: SourceSpan::new(src, offset, s.len()), + }, + } + }) + .collect(), + } + } + + #[test] + fn typedef_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler { + ty_mapper: TypeMapper::builder() + .map_path("std::string", "BytesMut")? + .build(), + ..Default::default() + }; + + let typedef_header = r#" + typedef Custom::int16 MyInt16; + typedef std::string MyString; + typedef char type24[3]; + "#; + + let typedef_file = parse_file(typedef_header).unwrap(); + let mut typedef_iter = typedef_file.items.iter(); + + match typedef_iter.next() { + Some(crate::ast::Item::Typedef(t)) => { + assert_eq!( + "# [doc = concat ! (\" Auto-transpiled type for \" , stringify ! (MyInt16))] pub type MyInt16 = Custom :: int16 ;", + t.transpile_token_stream(&transpiler)?.to_string() + ); + } + t => panic!("unexpected typedef {t:?}"), + }; + + match typedef_iter.next() { + Some(crate::ast::Item::Typedef(t)) => { + assert_eq!( + "# [doc = concat ! (\" Auto-transpiled type for \" , stringify ! (MyString))] pub type MyString = BytesMut ;", + t.transpile_token_stream(&transpiler)?.to_string() + ); + } + t => panic!("unexpected typedef {t:?}"), + }; + + match typedef_iter.next() { + Some(crate::ast::Item::Typedef(t)) => { + assert_eq!( + "# [doc = concat ! (\" Auto-transpiled type for \" , stringify ! (type24))] pub type type24 = & str ;", + t.transpile_token_stream(&transpiler)?.to_string() + ); + } + t => panic!("unexpected typedef {t:?}"), + }; + + Ok(()) + } + + // ---- Fundamental types (table-driven) ---- + + #[test] + fn fundamental_defaults() -> Result<(), TranspileError> { + use FundamentalKind::*; + let mapper = TypeMapper::new(); + let cases: &[(FundamentalKind, &str, &str)] = &[ + (Void, "void", "()"), + (Bool, "bool", "bool"), + (Char, "char", "char"), + (Char8, "char8_t", "char"), + (Char16, "char16_t", "char"), + (Char32, "char32_t", "char"), + (Wchar, "wchar_t", "char"), + (Short, "short", "i16"), + (Int, "int", "i32"), + (Long, "long", "i64"), + (LongLong, "long long", "i64"), + (Float, "float", "f32"), + (Double, "double", "f64"), + (LongDouble, "long double", "f64"), + (SignedChar, "signed char", "i8"), + (UnsignedChar, "unsigned char", "u8"), + (UnsignedShort, "unsigned short", "u16"), + (UnsignedInt, "unsigned int", "u32"), + (UnsignedLong, "unsigned long", "u64"), + (UnsignedLongLong, "unsigned long long", "u64"), + ]; + for &(kind, src, expected) in cases { + let ty = make_fundamental(src, kind); + let result = mapper.map_type(&ty)?; + assert_eq!(ty_str(&result), expected, "failed for {kind:?}"); + } + Ok(()) + } + + // ---- Path mappings ---- + + #[test] + fn path_mapping_custom() -> Result<(), TranspileError> { + let mapper = TypeMapper::builder() + .map_path("Custom::int32", "i32")? + .map_path("std::string", "String")? + .build(); + + let src = "Custom::int32"; + let ty = make_path(src, &[&src[..6], &src[8..]]); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "i32"); + + let src2 = "std::string"; + let ty2 = make_path(src2, &[&src2[..3], &src2[5..]]); + assert_eq!(ty_str(&mapper.map_type(&ty2)?), "String"); + Ok(()) + } + + #[test] + fn path_unknown_passes_through() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "Unknown"; + let ty = make_path(src, &[src]); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "Unknown"); + Ok(()) + } + + // ---- Composite types ---- + + #[test] + fn const_ptr() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Ptr(TypePtr { + cv: CvQualifiers { + const_token: true, + volatile_token: false, + }, + pointee: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "* const i32"); + Ok(()) + } + + #[test] + fn mut_ptr() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Ptr(TypePtr { + cv: CvQualifiers::default(), + pointee: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "* mut i32"); + Ok(()) + } + + #[test] + fn reference_mut() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Reference(TypeReference { + cv: CvQualifiers::default(), + referent: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "& mut i32"); + Ok(()) + } + + #[test] + fn reference_const() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Reference(TypeReference { + cv: CvQualifiers { + const_token: true, + volatile_token: false, + }, + referent: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "& i32"); + Ok(()) + } + + #[test] + fn rvalue_reference() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::RvalueReference(TypeRvalueReference { + referent: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "i32"); + Ok(()) + } + + #[test] + fn array_with_size() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src_elem = "int"; + let src_size = "10"; + let inner = make_fundamental(src_elem, FundamentalKind::Int); + let ty = Type::Array(TypeArray { + element: Box::new(inner), + size: Some(Expr::Lit(ExprLit { + span: SourceSpan::new(src_size, 0, 2), + kind: LitKind::Integer, + })), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "[i32 ; 10]"); + Ok(()) + } + + #[test] + fn array_without_size() { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Array(TypeArray { + element: Box::new(inner), + size: None, + }); + assert!(mapper.map_type(&ty).is_err()); + } + + // ---- CV-qualified ---- + + #[test] + fn cv_qualified_strips() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "int"; + let inner = make_fundamental(src, FundamentalKind::Int); + let ty = Type::Qualified(TypeQualified { + cv: CvQualifiers { + const_token: true, + volatile_token: false, + }, + ty: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "i32"); + Ok(()) + } + + // ---- Template instantiation ---- + + #[test] + fn template_inst_with_mapped_args() -> Result<(), TranspileError> { + let mapper = TypeMapper::builder() + .map_path("std::vector", "Vec")? + .build(); + + let path_src = "std::vector"; + let inner_src = "int"; + let inner = make_fundamental(inner_src, FundamentalKind::Int); + let ty = Type::TemplateInst(TypeTemplateInst { + path: make_raw_path(path_src, &[&path_src[..3], &path_src[5..]]), + args: vec![TemplateArg::Type(inner)], + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "Vec < i32 >"); + Ok(()) + } + + #[test] + fn template_inst_unknown_path_passes_through() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let path_src = "std::deque"; + let inner_src = "int"; + let inner = make_fundamental(inner_src, FundamentalKind::Int); + let ty = Type::TemplateInst(TypeTemplateInst { + path: make_raw_path(path_src, &[&path_src[..3], &path_src[5..]]), + args: vec![TemplateArg::Type(inner)], + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "std :: deque < i32 >"); + Ok(()) + } + + #[test] + fn template_inst_unmapped_arg_passes_through() -> Result<(), TranspileError> { + let mapper = TypeMapper::builder() + .map_path("std::vector", "Vec")? + .build(); + + let path_src = "std::vector"; + let inner_src = "Unknown"; + let inner = make_path(inner_src, &[inner_src]); + let ty = Type::TemplateInst(TypeTemplateInst { + path: make_raw_path(path_src, &[&path_src[..3], &path_src[5..]]), + args: vec![TemplateArg::Type(inner)], + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "Vec < Unknown >"); + Ok(()) + } + + // ---- auto / decltype ---- + + #[test] + fn auto_returns_err() { + let mapper = TypeMapper::new(); + let src = "auto"; + let ty = Type::Auto(TypeAuto { + span: SourceSpan::new(src, 0, 4), + }); + assert!(mapper.map_type(&ty).is_err()); + } + + #[test] + fn decltype_returns_err() { + let mapper = TypeMapper::new(); + let src = "x"; + let ty = Type::Decltype(TypeDecltype { + expr: Expr::Ident(crate::ast::expr::ExprIdent { + ident: Ident { + sym: src, + span: SourceSpan::new(src, 0, 1), + }, + }), + }); + assert!(mapper.map_type(&ty).is_err()); + } + + // ---- Inner type unknown in composite ---- + + #[test] + fn ptr_unknown_inner_passes_through() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let src = "Unknown"; + let inner = make_path(src, &[src]); + let ty = Type::Ptr(TypePtr { + cv: CvQualifiers::default(), + pointee: Box::new(inner), + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "* mut Unknown"); + Ok(()) + } + + // ---- Function pointer ---- + + #[test] + fn fn_ptr() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let ret_src = "int"; + let p1_src = "double"; + let p2_src = "float"; + + let ret = make_fundamental(ret_src, FundamentalKind::Int); + let p1 = make_fundamental(p1_src, FundamentalKind::Double); + let p2 = make_fundamental(p2_src, FundamentalKind::Float); + + let mut params = Punctuated::new(); + params.push_value(p1); + params.push_value(p2); + + let ty = Type::FnPtr(TypeFnPtr { + return_type: Box::new(ret), + params, + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "fn (f64 , f32) -> i32"); + Ok(()) + } + + #[test] + fn fn_ptr_unmapped_param_passes_through() -> Result<(), TranspileError> { + let mapper = TypeMapper::new(); + let ret_src = "int"; + let p_src = "Unknown"; + + let ret = make_fundamental(ret_src, FundamentalKind::Int); + let p = make_path(p_src, &[p_src]); + + let mut params = Punctuated::new(); + params.push_value(p); + + let ty = Type::FnPtr(TypeFnPtr { + return_type: Box::new(ret), + params, + }); + assert_eq!(ty_str(&mapper.map_type(&ty)?), "fn (Unknown) -> i32"); + Ok(()) + } + + // ---- Builder error ---- + + #[test] + fn builder_invalid_rust_type() { + let result = TypeMapper::builder().map_path("foo", "not a {{ valid type"); + assert!(result.is_err()); + } + + // ---- Error diagnostics ---- + + #[test] + fn error_is_diagnostic() { + let mapper = TypeMapper::new(); + let src = "auto"; + let ty = Type::Auto(TypeAuto { + span: SourceSpan::new(src, 0, 4), + }); + let err = mapper.map_type(&ty).unwrap_err(); + // TranspileError implements miette::Diagnostic + let diagnostic: &dyn miette::Diagnostic = &err; + assert!(diagnostic.source_code().is_some()); + assert!(diagnostic.labels().is_some()); + } + + // ---- ItemConst / ItemStatic transpilation ---- + + #[test] + fn const_int_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "const int MAX = 100;"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled const literal \" , stringify ! (i32))] pub const MAX : i32 = 100 ;" + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + Ok(()) + } + + #[test] + fn constexpr_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "constexpr double PI = 3.14;"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled const literal \" , stringify ! (f64))] pub const PI : f64 = 3.14 ;" + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + Ok(()) + } + + #[test] + fn const_bool_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "const bool FLAG = true;"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled const \" , stringify ! (bool))] pub const FLAG : bool = true ;" + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + Ok(()) + } + + #[test] + fn const_char_literal_casts() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = " + constexpr char CONST_CHAR_VALUE = 'W'; + constexpr CustomType CONST_CUSTOM_VALUE = 'W';"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled const literal \" , stringify ! (char))] pub const CONST_CHAR_VALUE : char = 'W' ;" + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + match &file.items[1] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled const literal \" , stringify ! (CustomType))] pub const CONST_CUSTOM_VALUE : CustomType = 'W' as CustomType ;" + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + Ok(()) + } + + #[test] + fn const_string_uses_str_ref() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = r#"constexpr string WRONG_RETURN_CODE = "404";"#; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + assert_eq!( + c.transpile_token_stream(&transpiler)?.to_string(), + r#"# [doc = " Auto-transpiled &str const"] pub const WRONG_RETURN_CODE : & str = "404" ;"# + ); + } + item => panic!("expected ItemConst, got {item:?}"), + } + Ok(()) + } + + #[test] + fn static_int_transpiles() -> Result<(), TranspileError> { + let transpiler = Transpiler::default(); + let src = "static int count = 0;"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Static(s) => { + assert_eq!( + s.transpile_token_stream(&transpiler)?.to_string(), + "# [doc = concat ! (\" Auto-transpiled static literal \" , stringify ! (i32))] pub static count : i32 = 0 ;" + ); + } + item => panic!("expected ItemStatic, got {item:?}"), + } + Ok(()) + } + + #[test] + fn char_array_from_string_literal() -> Result<(), TranspileError> { + // `const char` is parsed as Const with a const-qualified element type. + // The transpiler should infer the size (4 chars + null terminator = 5). + let transpiler = Transpiler::default(); + let src = r#"const char listOfChars[] = "ALPN";"#; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + let out = c.transpile_token_stream(&transpiler)?.to_string(); + assert_eq!(out, r#"pub const listOfChars : & str = "ALPN" ;"#); + } + item => panic!("expected ItemConst, got {item:?}"), + } + + Ok(()) + } + + #[test] + fn char_array_with_escape_sequence() -> Result<(), TranspileError> { + // `\n` counts as one character: size = 3 + 1 = 4. + let transpiler = Transpiler::default(); + let src = r#"const char nl[] = "a\nb";"#; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Const(c) => { + let out = c.transpile_token_stream(&transpiler)?.to_string(); + assert_eq!(out, r#"pub const nl : & str = "a\nb" ;"#); + } + item => panic!("expected ItemConst, got {item:?}"), + } + + Ok(()) + } + + #[test] + fn static_no_init_errors() { + let transpiler = Transpiler::default(); + let src = "static int count;"; + let file = parse_file(src).unwrap(); + match &file.items[0] { + crate::ast::Item::Static(s) => { + assert!(s.transpile_token_stream(&transpiler).is_err()); + } + item => panic!("expected ItemStatic, got {item:?}"), + } + } +}