A microservice that brings two-factor authentication (2FA) to self-hosted PDSes.
1use serde::Deserialize;
2use std::collections::HashMap;
3use std::path::Path;
4
5#[derive(Debug, Clone, Deserialize)]
6pub struct RbacConfig {
7 pub roles: HashMap<String, RoleDefinition>,
8 pub members: Vec<MemberDefinition>,
9}
10
11#[derive(Debug, Clone, Deserialize)]
12pub struct RoleDefinition {
13 pub description: String,
14 pub endpoints: Vec<String>,
15}
16
17#[derive(Debug, Clone, Deserialize)]
18pub struct MemberDefinition {
19 pub did: String,
20 pub roles: Vec<String>,
21}
22
23impl RbacConfig {
24 /// Load RBAC configuration from a YAML file.
25 pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self, anyhow::Error> {
26 let contents = std::fs::read_to_string(path)?;
27 let config: RbacConfig = serde_yaml::from_str(&contents)?;
28 config.validate()?;
29 Ok(config)
30 }
31
32 /// Validate that all member roles reference defined roles.
33 fn validate(&self) -> Result<(), anyhow::Error> {
34 for member in &self.members {
35 for role in &member.roles {
36 if !self.roles.contains_key(role) {
37 return Err(anyhow::anyhow!(
38 "Member {} references undefined role: {}",
39 member.did,
40 role
41 ));
42 }
43 }
44 }
45 Ok(())
46 }
47
48 /// Check whether a DID is a configured member.
49 pub fn is_member(&self, did: &str) -> bool {
50 self.members.iter().any(|m| m.did == did)
51 }
52
53 /// Get the role names assigned to a DID.
54 pub fn get_member_roles(&self, did: &str) -> Vec<String> {
55 self.members
56 .iter()
57 .find(|m| m.did == did)
58 .map(|m| m.roles.clone())
59 .unwrap_or_default()
60 }
61
62 /// Get all endpoint patterns that a DID is allowed to access (aggregated from all roles).
63 pub fn get_allowed_endpoints(&self, did: &str) -> Vec<String> {
64 let roles = self.get_member_roles(did);
65 let mut endpoints = Vec::new();
66 for role_name in &roles {
67 if let Some(role) = self.roles.get(role_name) {
68 endpoints.extend(role.endpoints.clone());
69 }
70 }
71 endpoints
72 }
73
74 /// Check whether a DID can access a specific endpoint.
75 /// Supports wildcard matching: `com.atproto.admin.*` matches `com.atproto.admin.getAccountInfo`.
76 pub fn can_access_endpoint(&self, did: &str, endpoint: &str) -> bool {
77 let allowed = self.get_allowed_endpoints(did);
78 for pattern in &allowed {
79 if matches_endpoint_pattern(pattern, endpoint) {
80 return true;
81 }
82 }
83 false
84 }
85}
86
/// Decide whether `endpoint` is permitted by a single endpoint `pattern`.
///
/// A pattern that ends in `*` is a prefix match on everything before the `*`
/// (e.g. `com.atproto.admin.*` matches `com.atproto.admin.getAccountInfo`);
/// any other pattern must equal `endpoint` exactly. A `*` anywhere other than the
/// end of the pattern gets no special treatment and is compared literally.
fn matches_endpoint_pattern(pattern: &str, endpoint: &str) -> bool {
    match pattern.strip_suffix('*') {
        // Trailing wildcard: the text before `*` is a required prefix.
        Some(prefix) if endpoint.starts_with(prefix) => true,
        // No wildcard, or the prefix did not match: require exact equality.
        _ => pattern == endpoint,
    }
}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: three roles and three members, one of whom holds two roles.
    const FIXTURE_YAML: &str = r#"
roles:
  pds-admin:
    description: "Full admin"
    endpoints:
      - "com.atproto.admin.*"
      - "com.atproto.server.createInviteCode"
      - "com.atproto.server.createAccount"
  moderator:
    description: "Content moderation"
    endpoints:
      - "com.atproto.admin.getAccountInfo"
      - "com.atproto.admin.getAccountInfos"
      - "com.atproto.admin.getSubjectStatus"
      - "com.atproto.admin.updateSubjectStatus"
      - "com.atproto.admin.getInviteCodes"
  invite-manager:
    description: "Invite management"
    endpoints:
      - "com.atproto.admin.getInviteCodes"
      - "com.atproto.admin.disableInviteCodes"
      - "com.atproto.server.createInviteCode"

members:
  - did: "did:plc:admin123"
    roles:
      - pds-admin
  - did: "did:plc:mod456"
    roles:
      - moderator
  - did: "did:plc:both789"
    roles:
      - moderator
      - invite-manager
"#;

    /// A config whose only member references a role that is not defined.
    const UNDEFINED_ROLE_YAML: &str = r#"
roles:
  admin:
    description: "Admin"
    endpoints: ["com.atproto.admin.*"]
members:
  - did: "did:plc:test"
    roles:
      - nonexistent
"#;

    /// Parse the fixture and run it through `validate`, so a broken fixture fails
    /// loudly here instead of as a confusing assertion failure in another test.
    fn test_config() -> RbacConfig {
        let config: RbacConfig =
            serde_yaml::from_str(FIXTURE_YAML).expect("fixture should be valid YAML");
        config
            .validate()
            .expect("fixture should pass RBAC validation");
        config
    }

    /// Write `contents` to a temp file unique to this test name and process.
    fn write_temp_config(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rbac-test-{}-{}.yaml",
            name,
            std::process::id()
        ));
        std::fs::write(&path, contents).expect("should be able to write temp config");
        path
    }

    #[test]
    fn test_is_member() {
        let config = test_config();
        assert!(config.is_member("did:plc:admin123"));
        assert!(config.is_member("did:plc:mod456"));
        assert!(!config.is_member("did:plc:unknown"));
    }

    #[test]
    fn test_get_member_roles() {
        let config = test_config();
        assert_eq!(config.get_member_roles("did:plc:admin123"), vec!["pds-admin"]);
        assert_eq!(
            config.get_member_roles("did:plc:both789"),
            vec!["moderator", "invite-manager"]
        );
        assert!(config.get_member_roles("did:plc:unknown").is_empty());
    }

    #[test]
    fn test_get_allowed_endpoints_aggregates_roles_in_order() {
        let config = test_config();
        assert_eq!(
            config.get_allowed_endpoints("did:plc:both789"),
            vec![
                "com.atproto.admin.getAccountInfo",
                "com.atproto.admin.getAccountInfos",
                "com.atproto.admin.getSubjectStatus",
                "com.atproto.admin.updateSubjectStatus",
                "com.atproto.admin.getInviteCodes",
                "com.atproto.admin.getInviteCodes",
                "com.atproto.admin.disableInviteCodes",
                "com.atproto.server.createInviteCode",
            ]
        );
        assert!(config.get_allowed_endpoints("did:plc:unknown").is_empty());
    }

    #[test]
    fn test_wildcard_matching() {
        assert!(matches_endpoint_pattern(
            "com.atproto.admin.*",
            "com.atproto.admin.getAccountInfo"
        ));
        assert!(matches_endpoint_pattern(
            "com.atproto.admin.*",
            "com.atproto.admin.deleteAccount"
        ));
        assert!(!matches_endpoint_pattern(
            "com.atproto.admin.*",
            "com.atproto.server.createAccount"
        ));
    }

    #[test]
    fn test_wildcard_respects_segment_boundary() {
        // The `.` before `*` is part of the prefix, so a longer sibling segment
        // must not match.
        assert!(!matches_endpoint_pattern(
            "com.atproto.admin.*",
            "com.atproto.administrator"
        ));
    }

    #[test]
    fn test_bare_wildcard_matches_everything() {
        assert!(matches_endpoint_pattern("*", "com.atproto.sync.getRepo"));
        assert!(matches_endpoint_pattern("*", ""));
    }

    #[test]
    fn test_exact_matching() {
        assert!(matches_endpoint_pattern(
            "com.atproto.server.createInviteCode",
            "com.atproto.server.createInviteCode"
        ));
        assert!(!matches_endpoint_pattern(
            "com.atproto.server.createInviteCode",
            "com.atproto.server.createAccount"
        ));
    }

    #[test]
    fn test_admin_can_access_all_admin_endpoints() {
        let config = test_config();
        assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.admin.getAccountInfo"));
        assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.admin.deleteAccount"));
        assert!(config.can_access_endpoint(
            "did:plc:admin123",
            "com.atproto.server.createInviteCode"
        ));
        assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.server.createAccount"));
    }

    #[test]
    fn test_moderator_limited_access() {
        let config = test_config();
        assert!(config.can_access_endpoint("did:plc:mod456", "com.atproto.admin.getAccountInfo"));
        assert!(config.can_access_endpoint(
            "did:plc:mod456",
            "com.atproto.admin.updateSubjectStatus"
        ));
        assert!(!config.can_access_endpoint("did:plc:mod456", "com.atproto.admin.deleteAccount"));
        assert!(!config.can_access_endpoint(
            "did:plc:mod456",
            "com.atproto.server.createInviteCode"
        ));
    }

    #[test]
    fn test_combined_roles() {
        let config = test_config();
        // Has moderator + invite-manager.
        assert!(config.can_access_endpoint("did:plc:both789", "com.atproto.admin.getAccountInfo"));
        assert!(config.can_access_endpoint("did:plc:both789", "com.atproto.admin.getInviteCodes"));
        assert!(config.can_access_endpoint(
            "did:plc:both789",
            "com.atproto.server.createInviteCode"
        ));
        assert!(config.can_access_endpoint(
            "did:plc:both789",
            "com.atproto.admin.disableInviteCodes"
        ));
        // But not delete or create account.
        assert!(!config.can_access_endpoint("did:plc:both789", "com.atproto.admin.deleteAccount"));
        assert!(!config.can_access_endpoint(
            "did:plc:both789",
            "com.atproto.server.createAccount"
        ));
    }

    #[test]
    fn test_non_member_no_access() {
        let config = test_config();
        assert!(!config.can_access_endpoint(
            "did:plc:unknown",
            "com.atproto.admin.getAccountInfo"
        ));
    }

    #[test]
    fn test_validate_rejects_undefined_role() {
        let config: RbacConfig = serde_yaml::from_str(UNDEFINED_ROLE_YAML).unwrap();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_load_from_file_accepts_valid_config() {
        let path = write_temp_config("valid", FIXTURE_YAML);
        let result = RbacConfig::load_from_file(&path);
        let _ = std::fs::remove_file(&path);
        let config = result.expect("fixture should load from disk");
        assert!(config.can_access_endpoint("did:plc:admin123", "com.atproto.admin.deleteAccount"));
    }

    #[test]
    fn test_load_from_file_rejects_invalid_config() {
        let path = write_temp_config("undefined-role", UNDEFINED_ROLE_YAML);
        let result = RbacConfig::load_from_file(&path);
        let _ = std::fs::remove_file(&path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_from_file_missing_file_errors() {
        let path = std::env::temp_dir().join(format!(
            "rbac-test-missing-{}.yaml",
            std::process::id()
        ));
        assert!(RbacConfig::load_from_file(&path).is_err());
    }
}