use serde::Deserialize; use std::collections::HashMap; use std::path::Path; #[derive(Debug, Clone, Deserialize)] pub struct RbacConfig { pub roles: HashMap, pub members: Vec, } #[derive(Debug, Clone, Deserialize)] pub struct RoleDefinition { pub description: String, pub endpoints: Vec, } #[derive(Debug, Clone, Deserialize)] pub struct MemberDefinition { pub did: String, pub roles: Vec, } impl RbacConfig { /// Load RBAC configuration from a YAML file. pub fn load_from_file(path: impl AsRef) -> Result { let contents = std::fs::read_to_string(path)?; let config: RbacConfig = serde_yaml::from_str(&contents)?; config.validate()?; Ok(config) } /// Validate that all member roles reference defined roles. fn validate(&self) -> Result<(), anyhow::Error> { for member in &self.members { for role in &member.roles { if !self.roles.contains_key(role) { return Err(anyhow::anyhow!( "Member {} references undefined role: {}", member.did, role )); } } } Ok(()) } /// Check whether a DID is a configured member. pub fn is_member(&self, did: &str) -> bool { self.members.iter().any(|m| m.did == did) } /// Get the role names assigned to a DID. pub fn get_member_roles(&self, did: &str) -> Vec { self.members .iter() .find(|m| m.did == did) .map(|m| m.roles.clone()) .unwrap_or_default() } /// Get all endpoint patterns that a DID is allowed to access (aggregated from all roles). pub fn get_allowed_endpoints(&self, did: &str) -> Vec { let roles = self.get_member_roles(did); let mut endpoints = Vec::new(); for role_name in &roles { if let Some(role) = self.roles.get(role_name) { endpoints.extend(role.endpoints.clone()); } } endpoints } /// Check whether a DID can access a specific endpoint. /// Supports wildcard matching: `com.atproto.admin.*` matches `com.atproto.admin.getAccountInfo`. pub fn can_access_endpoint(&self, did: &str, endpoint: &str) -> bool { let allowed = self.get_allowed_endpoints(did); for pattern in &allowed { if matches_endpoint_pattern(pattern, endpoint) { return true; } } false } } /// Match an endpoint against a pattern. Supports trailing `*` wildcard. /// e.g. `com.atproto.admin.*` matches `com.atproto.admin.getAccountInfo` fn matches_endpoint_pattern(pattern: &str, endpoint: &str) -> bool { if pattern == endpoint { return true; } if let Some(prefix) = pattern.strip_suffix('*') { return endpoint.starts_with(prefix); } false } #[cfg(test)] mod tests { use super::*; fn test_config() -> RbacConfig { let yaml = r#" roles: pds-admin: description: "Full admin" endpoints: - "com.atproto.admin.*" - "com.atproto.server.createInviteCode" - "com.atproto.server.createAccount" moderator: description: "Content moderation" endpoints: - "com.atproto.admin.getAccountInfo" - "com.atproto.admin.getAccountInfos" - "com.atproto.admin.getSubjectStatus" - "com.atproto.admin.updateSubjectStatus" - "com.atproto.admin.getInviteCodes" invite-manager: description: "Invite management" endpoints: - "com.atproto.admin.getInviteCodes" - "com.atproto.admin.disableInviteCodes" - "com.atproto.server.createInviteCode" members: - did: "did:plc:admin123" roles: - pds-admin - did: "did:plc:mod456" roles: - moderator - did: "did:plc:both789" roles: - moderator - invite-manager "#; serde_yaml::from_str(yaml).unwrap() } #[test] fn test_is_member() { let config = test_config(); assert!(config.is_member("did:plc:admin123")); assert!(config.is_member("did:plc:mod456")); assert!(!config.is_member("did:plc:unknown")); } #[test] fn test_get_member_roles() { let config = test_config(); assert_eq!(config.get_member_roles("did:plc:admin123"), vec!["pds-admin"]); assert_eq!( config.get_member_roles("did:plc:both789"), vec!["moderator", "invite-manager"] ); assert!(config.get_member_roles("did:plc:unknown").is_empty()); } #[test] fn test_wildcard_matching() { assert!(matches_endpoint_pattern( "com.atproto.admin.*", "com.atproto.admin.getAccountInfo" )); assert!(matches_endpoint_pattern( "com.atproto.admin.*", "com.atproto.admin.deleteAccount" )); assert!(!matches_endpoint_pattern( "com.atproto.admin.*", "com.atproto.server.createAccount" )); } #[test] fn test_exact_matching() { assert!(matches_endpoint_pattern( "com.atproto.server.createInviteCode", "com.atproto.server.createInviteCode" )); assert!(!matches_endpoint_pattern( "com.atproto.server.createInviteCode", "com.atproto.server.createAccount" )); } #[test] fn test_admin_can_access_all_admin_endpoints() { let config = test_config(); assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.admin.getAccountInfo")); assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.admin.deleteAccount")); assert!(config.can_access_endpoint( "did:plc:admin123", "com.atproto.server.createInviteCode" )); assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.server.createAccount")); } #[test] fn test_moderator_limited_access() { let config = test_config(); assert!(config.can_access_endpoint("did:plc:mod456", "com.atproto.admin.getAccountInfo")); assert!(config.can_access_endpoint( "did:plc:mod456", "com.atproto.admin.updateSubjectStatus" )); assert!(!config.can_access_endpoint("did:plc:mod456", "com.atproto.admin.deleteAccount")); assert!(!config.can_access_endpoint( "did:plc:mod456", "com.atproto.server.createInviteCode" )); } #[test] fn test_combined_roles() { let config = test_config(); // Has moderator + invite-manager assert!(config.can_access_endpoint("did:plc:both789", "com.atproto.admin.getAccountInfo")); assert!(config.can_access_endpoint("did:plc:both789", "com.atproto.admin.getInviteCodes")); assert!(config.can_access_endpoint( "did:plc:both789", "com.atproto.server.createInviteCode" )); assert!(config.can_access_endpoint( "did:plc:both789", "com.atproto.admin.disableInviteCodes" )); // But not delete or create account assert!(!config.can_access_endpoint("did:plc:both789", "com.atproto.admin.deleteAccount")); assert!(!config.can_access_endpoint( "did:plc:both789", "com.atproto.server.createAccount" )); } #[test] fn test_non_member_no_access() { let config = test_config(); assert!(!config.can_access_endpoint( "did:plc:unknown", "com.atproto.admin.getAccountInfo" )); } #[test] fn test_validate_rejects_undefined_role() { let yaml = r#" roles: admin: description: "Admin" endpoints: ["com.atproto.admin.*"] members: - did: "did:plc:test" roles: - nonexistent "#; let config: RbacConfig = serde_yaml::from_str(yaml).unwrap(); assert!(config.validate().is_err()); } }