This repository has no description.
At `main` · 167 lines · 5.2 kB · view raw
//! Loopback-server OAuth login.
//!
//! A small rouille HTTP server receives the authorization server's redirect on
//! `/oauth/callback` and forwards the query parameters over a channel to
//! [`OAuthClient::login_with_local_server`], which finishes the flow.
#![cfg(feature = "loopback")]

use crate::{
    atproto::AtprotoClientMetadata,
    authstore::ClientAuthStore,
    client::OAuthClient,
    dpop::DpopExt,
    error::{CallbackError, OAuthError},
    resolver::OAuthResolver,
    types::{AuthorizeOptions, CallbackParams},
};
use jacquard_common::{IntoStatic, cowstr::ToCowStr};
use rouille::Server;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tokio::sync::mpsc;
use url::Url;

/// Which TCP port the loopback callback server listens on.
#[derive(Clone, Debug)]
pub enum LoopbackPort {
    /// Listen on exactly this port.
    Fixed(u16),
    /// Let the OS pick a free port; the port actually bound is the one put in
    /// the redirect URI.
    Ephemeral,
}

/// Settings for [`OAuthClient::login_with_local_server`].
#[derive(Clone, Debug)]
pub struct LoopbackConfig {
    /// Host used in the redirect URI (`http://{host}:{port}/oauth/callback`).
    ///
    /// When this is an IP literal (`127.0.0.1`, `::1` or `[::1]`) the server
    /// also binds to that address only; for any other value it binds to the
    /// IPv4 wildcard `0.0.0.0`.
    pub host: String,
    /// Port the callback server listens on.
    pub port: LoopbackPort,
    /// Whether to try opening the authorization URL in the system browser
    /// (this only has an effect with the `browser-open` feature).
    pub open_browser: bool,
    /// How long to wait for the callback, in milliseconds.
    pub timeout_ms: u64,
}

impl Default for LoopbackConfig {
    fn default() -> Self {
        Self {
            host: "127.0.0.1".into(),
            port: LoopbackPort::Fixed(4000),
            open_browser: true,
            // Five minutes.
            timeout_ms: 5 * 60 * 1000,
        }
    }
}

/// Tries to open `url` in the default browser and reports whether that succeeded.
#[cfg(feature = "browser-open")]
fn try_open_in_browser(url: &str) -> bool {
    webbrowser::open(url).is_ok()
}

/// Without the `browser-open` feature no browser is launched; always returns `false`.
#[cfg(not(feature = "browser-open"))]
fn try_open_in_browser(_url: &str) -> bool {
    false
}

/// Routes a request received by the loopback server.
///
/// `GET /oauth/callback` is forwarded to the waiting login flow through `tx`.
/// Every other request gets a 404. Error redirects and malformed callbacks get a
/// 400 response instead of panicking the handler.
pub fn create_callback_router(
    request: &rouille::Request,
    tx: mpsc::Sender<CallbackParams<'_>>,
) -> rouille::Response {
    rouille::router!(request,
        (GET) (/oauth/callback) => {
            handle_oauth_callback(request, &tx)
        },
        _ => rouille::Response::empty_404()
    )
}

/// Converts the query string of an OAuth redirect into [`CallbackParams`] and
/// sends it to the login flow.
fn handle_oauth_callback(
    request: &rouille::Request,
    tx: &mpsc::Sender<CallbackParams<'_>>,
) -> rouille::Response {
    // Authorization failures arrive as `error` / `error_description` and carry
    // no `code` (RFC 6749 section 4.1.2.1).
    if let Some(error) = request.get_param("error") {
        let message = match request.get_param("error_description") {
            Some(description) => format!("Authorization failed: {} ({})", error, description),
            None => format!("Authorization failed: {}", error),
        };
        return rouille::Response::text(message).with_status_code(400);
    }
    let code = match request.get_param("code") {
        Some(code) => code,
        None => {
            return rouille::Response::text("Missing `code` parameter in OAuth callback")
                .with_status_code(400);
        }
    };
    // `state` and `iss` are `Option`s in `CallbackParams`, so pass them through
    // and leave it to the callback handling to reject a missing value
    // (NOTE(review): presumably `OAuthClient::callback` does — confirm).
    let callback_params = CallbackParams {
        state: request
            .get_param("state")
            .map(|state| state.to_cowstr().into_static()),
        code: code.to_cowstr().into_static(),
        iss: request
            .get_param("iss")
            .map(|iss| iss.to_cowstr().into_static()),
    };
    match tx.try_send(callback_params) {
        Ok(()) => rouille::Response::text("Logged in!"),
        // Either the receiver is gone (the login finished or timed out) or the
        // buffer is full of earlier callbacks; this one will not be processed.
        Err(_) => rouille::Response::text("No login is waiting for this callback")
            .with_status_code(503),
    }
}

/// Handles to a running loopback callback server.
struct CallbackHandle {
    // Never read: the thread is stopped through `server_stop` rather than joined
    // (joining would block the async caller). It is kept only to own the handle.
    #[allow(dead_code)]
    server_handle: std::thread::JoinHandle<()>,
    /// Sending `()` asks the server thread to shut down.
    server_stop: std::sync::mpsc::Sender<()>,
    /// Receives the parameters of each accepted callback.
    callback_rx: mpsc::Receiver<CallbackParams<'static>>,
}

impl Drop for CallbackHandle {
    /// Stops the server on every exit path: success, an early `?` return,
    /// timeout, or the login future being dropped.
    fn drop(&mut self) {
        // rouille's stoppable server keeps serving until it receives a message
        // on this channel. An `Err` only means the thread has already exited.
        let _ = self.server_stop.send(());
    }
}

/// Parses `host` as an IP address literal, accepting IPv6 with or without
/// surrounding brackets (`::1` or `[::1]`). Returns `None` for hostnames.
fn parse_ip_literal(host: &str) -> Option<IpAddr> {
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    unbracketed.parse().ok()
}

/// Starts the callback server on `addr`.
///
/// Returns the address actually bound, which carries the OS-assigned port when
/// `addr` uses port 0, together with the handles needed to receive callbacks
/// and stop the server.
///
/// # Panics
///
/// Panics if the server cannot bind to `addr` (for example, the port is in use).
fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) {
    let (tx, callback_rx) = mpsc::channel(5);
    let server = Server::new(addr, move |request| {
        create_callback_router(request, tx.clone())
    })
    .unwrap_or_else(|err| panic!("Could not start server on {}: {}", addr, err));
    // Read the bound address before `stoppable` consumes the server. For an
    // ephemeral bind, `addr` still says port 0.
    let local_addr = server.server_addr();
    let (server_handle, server_stop) = server.stoppable();
    let handle = CallbackHandle {
        server_handle,
        server_stop,
        callback_rx,
    };
    (local_addr, handle)
}

impl<T, S> OAuthClient<T, S>
where
    T: OAuthResolver + DpopExt + Send + Sync + 'static,
    S: ClientAuthStore + Send + Sync + 'static,
{
    /// Drive the full OAuth flow using a local loopback server.
    ///
    /// The flow runs in these steps:
    /// 1. Start the callback server (see [`LoopbackConfig`]).
    /// 2. Build a localhost client whose redirect URI is
    ///    `http://{host}:{port}/oauth/callback`.
    /// 3. Print the authorization URL and, if configured, open it in a browser.
    /// 4. Wait up to `cfg.timeout_ms` for the redirect, then finish the flow
    ///    with the received parameters.
    ///
    /// # Errors
    ///
    /// Propagates errors from starting authorization or from handling the
    /// callback. Returns `OAuthError::Callback(CallbackError::Timeout)` if no
    /// callback arrives in time.
    ///
    /// # Panics
    ///
    /// Panics if the callback server cannot bind, or if `cfg.host` does not
    /// form a valid redirect URL.
    pub async fn login_with_local_server(
        &self,
        input: impl AsRef<str>,
        opts: AuthorizeOptions<'_>,
        cfg: LoopbackConfig,
    ) -> crate::error::Result<super::client::OAuthSession<T, S>> {
        let port = match cfg.port {
            LoopbackPort::Fixed(p) => p,
            // Port 0 asks the OS for any free port.
            LoopbackPort::Ephemeral => 0,
        };
        let host_ip = parse_ip_literal(&cfg.host);
        // For an IP literal, bind only to that address so the callback is not
        // exposed on every interface. Hostnames keep the IPv4 wildcard bind.
        let bind_ip = host_ip.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED));
        let (local_addr, mut handle) = one_shot_server(SocketAddr::new(bind_ip, port));
        println!("Listening on {}", local_addr);

        // IPv6 literals must be bracketed inside a URL authority.
        let redirect_host = match host_ip {
            Some(IpAddr::V6(ip)) => format!("[{}]", ip),
            Some(IpAddr::V4(ip)) => ip.to_string(),
            None => cfg.host.clone(),
        };
        let redirect = Url::parse(&format!(
            "http://{}:{}/oauth/callback",
            redirect_host,
            local_addr.port(),
        ))
        .expect("loopback host must form a valid redirect URL");

        // Fall back to the configured client scopes when none were requested.
        let scopes = if opts.scopes.is_empty() {
            Some(self.registry.client_data.config.scopes.clone())
        } else {
            Some(opts.scopes.clone().into_static())
        };

        let client_data = crate::session::ClientData {
            keyset: self.registry.client_data.keyset.clone(),
            config: AtprotoClientMetadata::new_localhost(Some(vec![redirect]), scopes),
        };
        // A dedicated client for this flow that shares this client's store and
        // HTTP client but carries the loopback redirect URI.
        let flow_client = OAuthClient::new_with_shared(
            self.registry.store.clone(),
            self.client.clone(),
            client_data,
        );

        // If this fails, dropping `handle` stops the server.
        let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
        println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
        if cfg.open_browser {
            // Best effort: the URL has already been printed for manual use.
            let _ = try_open_in_browser(&auth_url);
        }

        let outcome = tokio::time::timeout(
            std::time::Duration::from_millis(cfg.timeout_ms),
            handle.callback_rx.recv(),
        )
        .await;
        // Stop the server now rather than while the callback is processed. The
        // `Drop` on `handle` sends the same message again harmlessly.
        let _ = handle.server_stop.send(());
        match outcome {
            Ok(Some(cb)) => Ok(flow_client.callback(cb).await?),
            // Timed out, or the channel closed without delivering a callback.
            Ok(None) | Err(_) => Err(OAuthError::Callback(CallbackError::Timeout)),
        }
    }
}