diff --git a/sdk-libs/macros/src/light_pdas/account/seed_extraction.rs b/sdk-libs/macros/src/light_pdas/account/seed_extraction.rs index e809df6e59..0d66b06549 100644 --- a/sdk-libs/macros/src/light_pdas/account/seed_extraction.rs +++ b/sdk-libs/macros/src/light_pdas/account/seed_extraction.rs @@ -172,7 +172,7 @@ pub fn extract_from_accounts_struct( } // Check for #[light_account(token, ...)] attribute - let token_attr = extract_light_token_attr(&field.attrs, instruction_args); + let token_attr = extract_light_token_attr(&field.attrs, instruction_args)?; if has_light_account_pda { // Extract inner type from Account<'info, T> or Box> @@ -314,10 +314,11 @@ struct LightTokenAttr { /// Extract #[light_account(token, authority = [...])] attribute /// Variant name is derived from field name, not specified in attribute +/// Returns Err if the attribute exists but has malformed syntax fn extract_light_token_attr( attrs: &[syn::Attribute], instruction_args: &InstructionArgSet, -) -> Option { +) -> syn::Result> { for attr in attrs { if attr.path().is_ident("light_account") { let tokens = match &attr.meta { @@ -332,18 +333,13 @@ fn extract_light_token_attr( .any(|t| matches!(&t, proc_macro2::TokenTree::Ident(ident) if ident == "token")); if has_token { - // Parse authority = [...] if present - if let Ok(parsed) = parse_light_token_list(&tokens, instruction_args) { - return Some(parsed); - } - return Some(LightTokenAttr { - variant_name: None, - authority_seeds: None, - }); + // Parse authority = [...] - propagate errors instead of swallowing them + let parsed = parse_light_token_list(&tokens, instruction_args)?; + return Ok(Some(parsed)); } } } - None + Ok(None) } /// Parse light_account(token, authority = [...]) content diff --git a/sdk-libs/macros/src/light_pdas/accounts/builder.rs b/sdk-libs/macros/src/light_pdas/accounts/builder.rs index 9c5bf9c316..491c4bdc56 100644 --- a/sdk-libs/macros/src/light_pdas/accounts/builder.rs +++ b/sdk-libs/macros/src/light_pdas/accounts/builder.rs @@ -45,6 +45,21 @@ impl LightAccountsBuilder { }) } + /// Get the expression to access CreateAccountsProof. + /// + /// Returns either: + /// - `proof_ident` (direct) if CreateAccountsProof is passed as a direct argument + /// - `params.create_accounts_proof` (nested) if nested inside a params struct + fn get_proof_access(&self) -> Result { + if let Some(ref proof_ident) = self.parsed.direct_proof_arg { + Ok(quote! { #proof_ident }) + } else { + let first_arg = self.get_first_instruction_arg()?; + let params_ident = &first_arg.name; + Ok(quote! { #params_ident.create_accounts_proof }) + } + } + /// Validate constraints (e.g., account count < 255). pub fn validate(&self) -> Result<(), syn::Error> { let total = self.parsed.rentfree_fields.len() @@ -317,13 +332,13 @@ impl LightAccountsBuilder { let rentfree_count = self.parsed.rentfree_fields.len() as u8; let pda_count = self.parsed.rentfree_fields.len(); - let first_arg = self.get_first_instruction_arg()?; - let params_ident = &first_arg.name; + // Get proof access expression (direct arg or nested in params) + let proof_access = self.get_proof_access()?; let first_pda_output_tree = &self.parsed.rentfree_fields[0].output_tree; let mints = &self.parsed.light_mint_fields; - let mint_invocation = LightMintsBuilder::new(mints, params_ident, &self.infra) + let mint_invocation = LightMintsBuilder::new(mints, &proof_access, &self.infra) .with_pda_context(pda_count, quote! { #first_pda_output_tree }) .generate_invocation(); @@ -356,7 +371,7 @@ impl LightAccountsBuilder { use light_sdk::cpi::{InvokeLightSystemProgram, LightCpiInstruction}; light_sdk::cpi::v2::LightSystemProgramCpi::new_cpi( crate::LIGHT_CPI_SIGNER, - #params_ident.create_accounts_proof.proof.clone() + #proof_access.proof.clone() ) .with_new_addresses(&[#(#new_addr_idents),*]) .with_account_infos(&all_compressed_infos) @@ -373,8 +388,8 @@ impl LightAccountsBuilder { generate_pda_compress_blocks(&self.parsed.rentfree_fields); let rentfree_count = self.parsed.rentfree_fields.len() as u8; - let first_arg = self.get_first_instruction_arg()?; - let params_ident = &first_arg.name; + // Get proof access expression (direct arg or nested in params) + let proof_access = self.get_proof_access()?; let fee_payer = &self.infra.fee_payer; let compression_config = &self.infra.compression_config; @@ -397,7 +412,7 @@ impl LightAccountsBuilder { use light_sdk::cpi::{InvokeLightSystemProgram, LightCpiInstruction}; light_sdk::cpi::v2::LightSystemProgramCpi::new_cpi( crate::LIGHT_CPI_SIGNER, - #params_ident.create_accounts_proof.proof.clone() + #proof_access.proof.clone() ) .with_new_addresses(&[#(#new_addr_idents),*]) .with_account_infos(&all_compressed_infos) @@ -407,12 +422,12 @@ impl LightAccountsBuilder { /// Generate mints-only body WITHOUT the Ok(true) return. fn generate_pre_init_mints_only_body(&self) -> Result { - let first_arg = self.get_first_instruction_arg()?; - let params_ident = &first_arg.name; + // Get proof access expression (direct arg or nested in params) + let proof_access = self.get_proof_access()?; let mints = &self.parsed.light_mint_fields; let mint_invocation = - LightMintsBuilder::new(mints, params_ident, &self.infra).generate_invocation(); + LightMintsBuilder::new(mints, &proof_access, &self.infra).generate_invocation(); let fee_payer = &self.infra.fee_payer; diff --git a/sdk-libs/macros/src/light_pdas/accounts/light_account.rs b/sdk-libs/macros/src/light_pdas/accounts/light_account.rs index 7d68857fdf..022dcf3026 100644 --- a/sdk-libs/macros/src/light_pdas/accounts/light_account.rs +++ b/sdk-libs/macros/src/light_pdas/accounts/light_account.rs @@ -289,9 +289,16 @@ fn parse_token_ata_key_values( /// Parse #[light_account(...)] attribute from a field. /// Returns None if no light_account attribute or if it's a mark-only token/ata field. /// Returns Some(LightAccountField) for PDA, Mint, or init Token/Ata fields. +/// +/// # Arguments +/// * `field` - The field to parse +/// * `field_ident` - The field identifier +/// * `direct_proof_arg` - If `Some`, CreateAccountsProof is passed directly as an instruction arg +/// with this name, so defaults should use `.field` instead of `params.create_accounts_proof.field` pub(super) fn parse_light_account_attr( field: &Field, field_ident: &Ident, + direct_proof_arg: &Option, ) -> Result, syn::Error> { for attr in &field.attrs { if attr.path().is_ident("light_account") { @@ -316,10 +323,10 @@ pub(super) fn parse_light_account_attr( return match args.account_type { LightAccountType::Pda => Ok(Some(LightAccountField::Pda(Box::new( - build_pda_field(field, field_ident, &args.key_values)?, + build_pda_field(field, field_ident, &args.key_values, direct_proof_arg)?, )))), LightAccountType::Mint => Ok(Some(LightAccountField::Mint(Box::new( - build_mint_field(field_ident, &args.key_values, attr)?, + build_mint_field(field_ident, &args.key_values, attr, direct_proof_arg)?, )))), LightAccountType::Token => Ok(Some(LightAccountField::TokenAccount(Box::new( build_token_account_field(field_ident, &args.key_values, args.has_init, attr)?, @@ -336,10 +343,14 @@ pub(super) fn parse_light_account_attr( } /// Build a PdaField from parsed key-value pairs. +/// +/// # Arguments +/// * `direct_proof_arg` - If `Some`, use `.field` for defaults instead of `params.create_accounts_proof.field` fn build_pda_field( field: &Field, field_ident: &Ident, key_values: &[KeyValue], + direct_proof_arg: &Option, ) -> Result { let mut address_tree_info: Option = None; let mut output_tree: Option = None; @@ -359,11 +370,21 @@ fn build_pda_field( } } - // Use defaults if not specified - let address_tree_info = address_tree_info - .unwrap_or_else(|| syn::parse_quote!(params.create_accounts_proof.address_tree_info)); - let output_tree = output_tree - .unwrap_or_else(|| syn::parse_quote!(params.create_accounts_proof.output_state_tree_index)); + // Use defaults if not specified - depends on whether CreateAccountsProof is direct arg or nested + let address_tree_info = address_tree_info.unwrap_or_else(|| { + if let Some(proof_ident) = direct_proof_arg { + syn::parse_quote!(#proof_ident.address_tree_info) + } else { + syn::parse_quote!(params.create_accounts_proof.address_tree_info) + } + }); + let output_tree = output_tree.unwrap_or_else(|| { + if let Some(proof_ident) = direct_proof_arg { + syn::parse_quote!(#proof_ident.output_state_tree_index) + } else { + syn::parse_quote!(params.create_accounts_proof.output_state_tree_index) + } + }); // Validate this is an Account type (or Box) let (is_boxed, inner_type) = extract_account_inner_type(&field.ty).ok_or_else(|| { @@ -384,10 +405,14 @@ fn build_pda_field( } /// Build a LightMintField from parsed key-value pairs. +/// +/// # Arguments +/// * `direct_proof_arg` - If `Some`, use `.field` for defaults instead of `params.create_accounts_proof.field` fn build_mint_field( field_ident: &Ident, key_values: &[KeyValue], attr: &syn::Attribute, + direct_proof_arg: &Option, ) -> Result { // Required fields let mut mint_signer: Option = None; @@ -474,9 +499,14 @@ fn build_mint_field( attr, )?; - // address_tree_info defaults to params.create_accounts_proof.address_tree_info - let address_tree_info = address_tree_info - .unwrap_or_else(|| syn::parse_quote!(params.create_accounts_proof.address_tree_info)); + // address_tree_info defaults - depends on whether CreateAccountsProof is direct arg or nested + let address_tree_info = address_tree_info.unwrap_or_else(|| { + if let Some(proof_ident) = direct_proof_arg { + syn::parse_quote!(#proof_ident.address_tree_info) + } else { + syn::parse_quote!(params.create_accounts_proof.address_tree_info) + } + }); Ok(LightMintField { field_ident: field_ident.clone(), @@ -707,7 +737,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -729,7 +759,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -753,7 +783,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -782,7 +812,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -805,7 +835,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); } @@ -817,7 +847,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); } @@ -835,7 +865,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); } @@ -846,7 +876,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); assert!(result.unwrap().is_none()); } @@ -864,7 +894,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); assert!(result.unwrap().is_none()); } @@ -877,7 +907,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -900,7 +930,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!(err.contains("authority")); @@ -919,7 +949,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); assert!(result.unwrap().is_none()); } @@ -932,7 +962,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -954,7 +984,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!(err.contains("owner")); @@ -968,7 +998,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!(err.contains("mint")); @@ -982,7 +1012,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!(err.contains("unknown")); @@ -996,7 +1026,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!(err.contains("unknown")); @@ -1011,7 +1041,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); let result = result.unwrap(); assert!(result.is_some()); @@ -1035,7 +1065,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!( @@ -1054,7 +1084,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!( @@ -1073,7 +1103,7 @@ mod tests { }; let ident = field.ident.clone().unwrap(); - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_err()); let err = result.err().unwrap().to_string(); assert!( @@ -1093,8 +1123,117 @@ mod tests { let ident = field.ident.clone().unwrap(); // Mark-only mode returns Ok(None) - let result = parse_light_account_attr(&field, &ident); + let result = parse_light_account_attr(&field, &ident, &None); assert!(result.is_ok()); assert!(result.unwrap().is_none()); } + + #[test] + fn test_parse_pda_with_direct_proof_arg_uses_proof_ident_for_defaults() { + // When CreateAccountsProof is passed as a direct instruction arg (not nested in params), + // the default address_tree_info and output_tree should reference the proof arg directly. + let field: syn::Field = parse_quote! { + #[light_account(init)] + pub record: Account<'info, MyRecord> + }; + let field_ident = field.ident.clone().unwrap(); + + // Simulate passing CreateAccountsProof as direct arg named "proof" + let proof_ident: Ident = parse_quote!(proof); + let direct_proof_arg = Some(proof_ident.clone()); + + let result = parse_light_account_attr(&field, &field_ident, &direct_proof_arg); + assert!( + result.is_ok(), + "Should parse successfully with direct proof arg" + ); + let result = result.unwrap(); + assert!(result.is_some(), "Should return Some for init PDA"); + + match result.unwrap() { + LightAccountField::Pda(pda) => { + assert_eq!(pda.ident.to_string(), "record"); + + // Verify defaults use the direct proof identifier + // address_tree_info should be: proof.address_tree_info + let addr_tree_info = &pda.address_tree_info; + let addr_tree_str = quote::quote!(#addr_tree_info).to_string(); + assert!( + addr_tree_str.contains("proof"), + "address_tree_info should reference 'proof', got: {}", + addr_tree_str + ); + assert!( + addr_tree_str.contains("address_tree_info"), + "address_tree_info should access .address_tree_info field, got: {}", + addr_tree_str + ); + + // output_tree should be: proof.output_state_tree_index + let output_tree = &pda.output_tree; + let output_tree_str = quote::quote!(#output_tree).to_string(); + assert!( + output_tree_str.contains("proof"), + "output_tree should reference 'proof', got: {}", + output_tree_str + ); + assert!( + output_tree_str.contains("output_state_tree_index"), + "output_tree should access .output_state_tree_index field, got: {}", + output_tree_str + ); + } + _ => panic!("Expected PDA field"), + } + } + + #[test] + fn test_parse_mint_with_direct_proof_arg_uses_proof_ident_for_defaults() { + // When CreateAccountsProof is passed as a direct instruction arg, + // the default address_tree_info should reference the proof arg directly. + let field: syn::Field = parse_quote! { + #[light_account(init, mint, + mint_signer = mint_signer, + authority = authority, + decimals = 9, + mint_seeds = &[b"test"] + )] + pub cmint: UncheckedAccount<'info> + }; + let field_ident = field.ident.clone().unwrap(); + + // Simulate passing CreateAccountsProof as direct arg named "create_proof" + let proof_ident: Ident = parse_quote!(create_proof); + let direct_proof_arg = Some(proof_ident.clone()); + + let result = parse_light_account_attr(&field, &field_ident, &direct_proof_arg); + assert!( + result.is_ok(), + "Should parse successfully with direct proof arg" + ); + let result = result.unwrap(); + assert!(result.is_some(), "Should return Some for init mint"); + + match result.unwrap() { + LightAccountField::Mint(mint) => { + assert_eq!(mint.field_ident.to_string(), "cmint"); + + // Verify default address_tree_info uses the direct proof identifier + // Should be: create_proof.address_tree_info + let addr_tree_info = &mint.address_tree_info; + let addr_tree_str = quote::quote!(#addr_tree_info).to_string(); + assert!( + addr_tree_str.contains("create_proof"), + "address_tree_info should reference 'create_proof', got: {}", + addr_tree_str + ); + assert!( + addr_tree_str.contains("address_tree_info"), + "address_tree_info should access .address_tree_info field, got: {}", + addr_tree_str + ); + } + _ => panic!("Expected Mint field"), + } + } } diff --git a/sdk-libs/macros/src/light_pdas/accounts/mint.rs b/sdk-libs/macros/src/light_pdas/accounts/mint.rs index 61b9375cdc..f15194b2cf 100644 --- a/sdk-libs/macros/src/light_pdas/accounts/mint.rs +++ b/sdk-libs/macros/src/light_pdas/accounts/mint.rs @@ -114,13 +114,14 @@ impl InfraRefs { /// /// Usage: /// ```ignore -/// LightMintsBuilder::new(mints, params_ident, &infra) +/// LightMintsBuilder::new(mints, &proof_access, &infra) /// .with_pda_context(pda_count, quote! { #first_pda_output_tree }) /// .generate_invocation() /// ``` pub(super) struct LightMintsBuilder<'a> { mints: &'a [LightMintField], - params_ident: &'a Ident, + /// TokenStream for accessing CreateAccountsProof (e.g., `proof` or `params.create_accounts_proof`) + proof_access: &'a TokenStream, infra: &'a InfraRefs, /// PDA context: (pda_count, output_tree_expr) for batching with PDAs pda_context: Option<(usize, TokenStream)>, @@ -128,10 +129,14 @@ pub(super) struct LightMintsBuilder<'a> { impl<'a> LightMintsBuilder<'a> { /// Create builder with required fields. - pub fn new(mints: &'a [LightMintField], params_ident: &'a Ident, infra: &'a InfraRefs) -> Self { + pub fn new( + mints: &'a [LightMintField], + proof_access: &'a TokenStream, + infra: &'a InfraRefs, + ) -> Self { Self { mints, - params_ident, + proof_access, infra, pda_context: None, } @@ -161,7 +166,7 @@ impl<'a> LightMintsBuilder<'a> { /// 4. Call invoke() - seeds are extracted from SingleMintParams internally fn generate_mints_invocation(builder: &LightMintsBuilder) -> TokenStream { let mints = builder.mints; - let params_ident = builder.params_ident; + let proof_access = builder.proof_access; let infra = builder.infra; let mint_count = mints.len(); @@ -329,7 +334,7 @@ fn generate_mints_invocation(builder: &LightMintsBuilder) -> TokenStream { #output_tree_setup // Extract proof from instruction params - let __proof: light_token::CompressedProof = #params_ident.create_accounts_proof.proof.0.clone() + let __proof: light_token::CompressedProof = #proof_access.proof.0.clone() .expect("proof is required for mint creation"); // Build SingleMintParams for each mint @@ -354,9 +359,9 @@ fn generate_mints_invocation(builder: &LightMintsBuilder) -> TokenStream { // Output queue for state (compressed accounts) is at tree index 0 // State merkle tree index comes from the proof (set by pack_proof_for_mints) // Address merkle tree index comes from the proof's address_tree_info - let __tree_info = &#params_ident.create_accounts_proof.address_tree_info; + let __tree_info = &#proof_access.address_tree_info; let __output_queue_index: u8 = 0; - let __state_tree_index: u8 = #params_ident.create_accounts_proof.state_tree_index + let __state_tree_index: u8 = #proof_access.state_tree_index .ok_or(anchor_lang::prelude::ProgramError::InvalidArgument)?; let __address_tree_index: u8 = __tree_info.address_merkle_tree_pubkey_index; let __output_queue = cpi_accounts.get_tree_account_info(__output_queue_index as usize)?; diff --git a/sdk-libs/macros/src/light_pdas/accounts/parse.rs b/sdk-libs/macros/src/light_pdas/accounts/parse.rs index 93d34d7fdf..8f30e347a5 100644 --- a/sdk-libs/macros/src/light_pdas/accounts/parse.rs +++ b/sdk-libs/macros/src/light_pdas/accounts/parse.rs @@ -164,6 +164,9 @@ pub(super) struct ParsedLightAccountsStruct { pub instruction_args: Option>, /// Infrastructure fields detected by naming convention. pub infra_fields: InfraFields, + /// If CreateAccountsProof type is passed as a direct instruction arg, stores arg name. + /// Matched by TYPE, not by name - allows any argument name (e.g., `proof`, `create_proof`). + pub direct_proof_arg: Option, } /// A field marked with #[light_account(init)] @@ -193,6 +196,52 @@ impl Parse for InstructionArg { } } +/// Check if a type is `CreateAccountsProof` (match last path segment). +/// Supports both simple `CreateAccountsProof` and fully qualified paths like +/// `light_sdk::CreateAccountsProof`. +fn is_create_accounts_proof_type(ty: &Type) -> bool { + if let Type::Path(type_path) = ty { + if let Some(segment) = type_path.path.segments.last() { + return segment.ident == "CreateAccountsProof"; + } + } + false +} + +/// Find if any instruction argument has type `CreateAccountsProof`. +/// Returns the argument's name (Ident) if found. +/// +/// Returns an error if multiple `CreateAccountsProof` arguments are found, +/// as this would make proof access ambiguous. +fn find_direct_proof_arg( + instruction_args: &Option>, +) -> Result, Error> { + let Some(args) = instruction_args.as_ref() else { + return Ok(None); + }; + + let proof_args: Vec<_> = args + .iter() + .filter(|arg| is_create_accounts_proof_type(&arg.ty)) + .collect(); + + match proof_args.len() { + 0 => Ok(None), + 1 => Ok(Some(proof_args[0].name.clone())), + _ => { + let names: Vec<_> = proof_args.iter().map(|a| a.name.to_string()).collect(); + Err(Error::new_spanned( + &proof_args[1].name, + format!( + "Multiple CreateAccountsProof arguments found: [{}]. \ + Only one CreateAccountsProof argument is allowed per instruction.", + names.join(", ") + ), + )) + } + } +} + /// Parse #[instruction(...)] attribute from struct. /// /// Returns `Ok(None)` if no instruction attribute is present, @@ -220,6 +269,10 @@ pub(super) fn parse_light_accounts_struct( let instruction_args = parse_instruction_attr(&input.attrs)?; + // Check if CreateAccountsProof is passed as a direct instruction argument + // (compute this early so we can use it for field parsing defaults) + let direct_proof_arg = find_direct_proof_arg(&instruction_args)?; + let fields = match &input.data { syn::Data::Struct(data) => match &data.fields { syn::Fields::Named(fields) => &fields.named, @@ -248,7 +301,9 @@ pub(super) fn parse_light_accounts_struct( } // Check for #[light_account(...)] - the unified syntax - if let Some(light_account_field) = parse_light_account_attr(field, &field_ident)? { + if let Some(light_account_field) = + parse_light_account_attr(field, &field_ident, &direct_proof_arg)? + { match light_account_field { LightAccountField::Pda(pda) => rentfree_fields.push((*pda).into()), LightAccountField::Mint(mint) => light_mint_fields.push(*mint), @@ -281,5 +336,6 @@ pub(super) fn parse_light_accounts_struct( ata_fields, instruction_args, infra_fields, + direct_proof_arg, }) } diff --git a/sdk-libs/macros/src/light_pdas/program/instructions.rs b/sdk-libs/macros/src/light_pdas/program/instructions.rs index 339577a372..9a6ef4ca3f 100644 --- a/sdk-libs/macros/src/light_pdas/program/instructions.rs +++ b/sdk-libs/macros/src/light_pdas/program/instructions.rs @@ -627,12 +627,45 @@ pub fn light_program_impl(_args: TokenStream, mut module: ItemMod) -> Result { + if rentfree_struct_names.contains(&context_type) { + // Wrap the function with pre_init/finalize logic + *fn_item = wrap_function_with_light(fn_item, ¶ms_ident, &ctx_ident); + } + } + ExtractResult::MultipleParams { + context_type, + param_names, + } => { + // Only error if this is a rentfree struct that needs wrapping + if rentfree_struct_names.contains(&context_type) { + let fn_name = fn_item.sig.ident.to_string(); + let params_str = param_names.join(", "); + return Err(macro_error!( + fn_item, + format!( + "Function '{}' has multiple instruction arguments ({}) which is not supported by #[rentfree_program].\n\ + Please consolidate these into a single params struct.\n\ + Example: Instead of `fn {}(ctx: Context, {})`,\n\ + use: `fn {}(ctx: Context, params: MyParams)` where MyParams contains all fields.", + fn_name, + params_str, + fn_name, + params_str, + fn_name + ) + )); + } + // Non-rentfree structs with multiple params are fine - just skip wrapping + } + ExtractResult::None => { + // No context/params found, skip this function } } } diff --git a/sdk-libs/macros/src/light_pdas/program/parsing.rs b/sdk-libs/macros/src/light_pdas/program/parsing.rs index f5270eed29..d1c195442c 100644 --- a/sdk-libs/macros/src/light_pdas/program/parsing.rs +++ b/sdk-libs/macros/src/light_pdas/program/parsing.rs @@ -351,13 +351,31 @@ pub fn convert_classified_to_seed_elements_vec( // FUNCTION WRAPPING // ============================================================================= +/// Result from extracting context and params from a function signature. +pub enum ExtractResult { + /// Successfully extracted context type, params ident, and context ident + Success { + context_type: String, + params_ident: Ident, + ctx_ident: Ident, + }, + /// Multiple params arguments detected (format-2 case) - caller decides if this is an error + MultipleParams { + context_type: String, + param_names: Vec, + }, + /// No valid context/params combination found + None, +} + /// Extract the Context type name and context parameter name from a function's parameters. -/// Returns (struct_name, params_ident, ctx_ident) if found. +/// Returns ExtractResult indicating success, multiple params, or none found. /// The ctx_ident is the actual parameter name (e.g., "ctx", "context", "anchor_ctx"). -pub fn extract_context_and_params(fn_item: &ItemFn) -> Option<(String, Ident, Ident)> { +pub fn extract_context_and_params(fn_item: &ItemFn) -> ExtractResult { let mut context_type = None; - let mut params_ident = None; let mut ctx_ident = None; + // Collect ALL potential params arguments to detect multi-arg cases + let mut params_candidates: Vec = Vec::new(); for input in &fn_item.sig.inputs { if let syn::FnArg::Typed(pat_type) = input { @@ -391,18 +409,31 @@ pub fn extract_context_and_params(fn_item: &ItemFn) -> Option<(String, Ident, Id // Track potential params argument (not the context param, not signer-like names) let name = pat_ident.ident.to_string(); if !name.contains("signer") && !name.contains("bump") { - // Prefer "params" but accept others - if name == "params" || params_ident.is_none() { - params_ident = Some(pat_ident.ident.clone()); - } + params_candidates.push(pat_ident.ident.clone()); } } } } - match (context_type, params_ident, ctx_ident) { - (Some(ctx_type), Some(params), Some(ctx_name)) => Some((ctx_type, params, ctx_name)), - _ => None, + match (context_type, ctx_ident) { + (Some(ctx_type), Some(ctx_name)) => { + if params_candidates.len() > 1 { + // Multiple params detected - let caller decide if this is an error + ExtractResult::MultipleParams { + context_type: ctx_type, + param_names: params_candidates.iter().map(|id| id.to_string()).collect(), + } + } else if let Some(params) = params_candidates.into_iter().next() { + ExtractResult::Success { + context_type: ctx_type, + params_ident: params, + ctx_ident: ctx_name, + } + } else { + ExtractResult::None + } + } + _ => ExtractResult::None, } } @@ -640,12 +671,18 @@ mod tests { Ok(()) } }; - let result = extract_context_and_params(&fn_item); - assert!(result.is_some()); - let (ctx_type, params_ident, ctx_ident) = result.unwrap(); - assert_eq!(ctx_type, "MyAccounts"); - assert_eq!(params_ident.to_string(), "params"); - assert_eq!(ctx_ident.to_string(), "ctx"); + match extract_context_and_params(&fn_item) { + ExtractResult::Success { + context_type, + params_ident, + ctx_ident, + } => { + assert_eq!(context_type, "MyAccounts"); + assert_eq!(params_ident.to_string(), "params"); + assert_eq!(ctx_ident.to_string(), "ctx"); + } + _ => panic!("Expected ExtractResult::Success"), + } } #[test] @@ -655,12 +692,18 @@ mod tests { Ok(()) } }; - let result = extract_context_and_params(&fn_item); - assert!(result.is_some()); - let (ctx_type, params_ident, ctx_ident) = result.unwrap(); - assert_eq!(ctx_type, "MyAccounts"); - assert_eq!(params_ident.to_string(), "params"); - assert_eq!(ctx_ident.to_string(), "context"); + match extract_context_and_params(&fn_item) { + ExtractResult::Success { + context_type, + params_ident, + ctx_ident, + } => { + assert_eq!(context_type, "MyAccounts"); + assert_eq!(params_ident.to_string(), "params"); + assert_eq!(ctx_ident.to_string(), "context"); + } + _ => panic!("Expected ExtractResult::Success"), + } } #[test] @@ -670,12 +713,18 @@ mod tests { Ok(()) } }; - let result = extract_context_and_params(&fn_item); - assert!(result.is_some()); - let (ctx_type, params_ident, ctx_ident) = result.unwrap(); - assert_eq!(ctx_type, "MyAccounts"); - assert_eq!(params_ident.to_string(), "data"); - assert_eq!(ctx_ident.to_string(), "anchor_ctx"); + match extract_context_and_params(&fn_item) { + ExtractResult::Success { + context_type, + params_ident, + ctx_ident, + } => { + assert_eq!(context_type, "MyAccounts"); + assert_eq!(params_ident.to_string(), "data"); + assert_eq!(ctx_ident.to_string(), "anchor_ctx"); + } + _ => panic!("Expected ExtractResult::Success"), + } } #[test] @@ -685,11 +734,38 @@ mod tests { Ok(()) } }; - let result = extract_context_and_params(&fn_item); - assert!(result.is_some()); - let (ctx_type, params_ident, ctx_ident) = result.unwrap(); - assert_eq!(ctx_type, "MyAccounts"); - assert_eq!(params_ident.to_string(), "p"); - assert_eq!(ctx_ident.to_string(), "c"); + match extract_context_and_params(&fn_item) { + ExtractResult::Success { + context_type, + params_ident, + ctx_ident, + } => { + assert_eq!(context_type, "MyAccounts"); + assert_eq!(params_ident.to_string(), "p"); + assert_eq!(ctx_ident.to_string(), "c"); + } + _ => panic!("Expected ExtractResult::Success"), + } + } + + #[test] + fn test_extract_context_and_params_multiple_args_detected() { + // Format-2 case: multiple instruction arguments should be detected + let fn_item: syn::ItemFn = syn::parse_quote! { + pub fn handler(ctx: Context, amount: u64, owner: Pubkey) -> Result<()> { + Ok(()) + } + }; + match extract_context_and_params(&fn_item) { + ExtractResult::MultipleParams { + context_type, + param_names, + } => { + assert_eq!(context_type, "MyAccounts"); + assert!(param_names.contains(&"amount".to_string())); + assert!(param_names.contains(&"owner".to_string())); + } + _ => panic!("Expected ExtractResult::MultipleParams"), + } } } diff --git a/sdk-libs/macros/src/light_pdas/shared_utils.rs b/sdk-libs/macros/src/light_pdas/shared_utils.rs index 51d629e11a..a3e079b522 100644 --- a/sdk-libs/macros/src/light_pdas/shared_utils.rs +++ b/sdk-libs/macros/src/light_pdas/shared_utils.rs @@ -85,7 +85,7 @@ pub fn ident_to_type(ident: &Ident) -> Type { /// Wrapper for syn::Expr that implements darling's FromMeta trait. /// /// Enables darling to parse arbitrary expressions in attributes like -/// `#[light_account(init, mint,mint_signer = self.authority)]`. +/// `#[light_account(init, mint, mint_signer = self.authority)]`. #[derive(Clone)] pub struct MetaExpr(Expr);