this repo has no description
1#![cfg(feature = "loopback")]
2
3use crate::{
4 atproto::AtprotoClientMetadata,
5 authstore::ClientAuthStore,
6 client::OAuthClient,
7 dpop::DpopExt,
8 error::{CallbackError, OAuthError},
9 resolver::OAuthResolver,
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{IntoStatic, cowstr::ToCowStr};
13use rouille::Server;
14use std::net::SocketAddr;
15use tokio::sync::mpsc;
16use url::Url;
17
/// Which TCP port the loopback callback server listens on.
#[derive(Debug, Clone)]
pub enum LoopbackPort {
    /// Listen on exactly this port number.
    Fixed(u16),
    /// Request port `0` at bind time so the operating system assigns a free port.
    Ephemeral,
}
23
24#[derive(Clone, Debug)]
25pub struct LoopbackConfig {
26 pub host: String,
27 pub port: LoopbackPort,
28 pub open_browser: bool,
29 pub timeout_ms: u64,
30}
31
32impl Default for LoopbackConfig {
33 fn default() -> Self {
34 Self {
35 host: "127.0.0.1".into(),
36 port: LoopbackPort::Fixed(4000),
37 open_browser: true,
38 timeout_ms: 5 * 60 * 1000,
39 }
40 }
41}
42
/// Best-effort attempt to open `url` using the `webbrowser` crate.
///
/// Returns `true` only when `webbrowser::open` reports success.
#[cfg(feature = "browser-open")]
fn try_open_in_browser(url: &str) -> bool {
    matches!(webbrowser::open(url), Ok(_))
}

/// Stand-in used when the `browser-open` feature is disabled: never opens anything
/// and always reports `false`.
#[cfg(not(feature = "browser-open"))]
fn try_open_in_browser(_url: &str) -> bool {
    false
}
51
52pub fn create_callback_router(
53 request: &rouille::Request,
54 tx: mpsc::Sender<CallbackParams>,
55) -> rouille::Response {
56 rouille::router!(request,
57 (GET) (/oauth/callback) => {
58 let state = request.get_param("state").unwrap();
59 let code = request.get_param("code").unwrap();
60 let iss = request.get_param("iss").unwrap();
61 let callback_params = CallbackParams {
62 state: Some(state.to_cowstr().into_static()),
63 code: code.to_cowstr().into_static(),
64 iss: Some(iss.to_cowstr().into_static()),
65 };
66 tx.try_send(callback_params).unwrap();
67 rouille::Response::text("Logged in!")
68 },
69 _ => rouille::Response::empty_404()
70 )
71}
72
73struct CallbackHandle {
74 #[allow(dead_code)]
75 server_handle: std::thread::JoinHandle<()>,
76 server_stop: std::sync::mpsc::Sender<()>,
77 callback_rx: mpsc::Receiver<CallbackParams<'static>>,
78}
79
80fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) {
81 let (tx, callback_rx) = mpsc::channel(5);
82 let server = Server::new(addr, move |request| {
83 create_callback_router(request, tx.clone())
84 })
85 .expect("Could not start server");
86 let (server_handle, server_stop) = server.stoppable();
87 let handle = CallbackHandle {
88 server_handle,
89 server_stop,
90 callback_rx,
91 };
92 (addr, handle)
93}
94
95impl<T, S> OAuthClient<T, S>
96where
97 T: OAuthResolver + DpopExt + Send + Sync + 'static,
98 S: ClientAuthStore + Send + Sync + 'static,
99{
100 /// Drive the full OAuth flow using a local loopback server.
101 pub async fn login_with_local_server(
102 &self,
103 input: impl AsRef<str>,
104 opts: AuthorizeOptions<'_>,
105 cfg: LoopbackConfig,
106 ) -> crate::error::Result<super::client::OAuthSession<T, S>> {
107 let port = match cfg.port {
108 LoopbackPort::Fixed(p) => p,
109 LoopbackPort::Ephemeral => 0,
110 };
111 // TODO: fix this to it also accepts ipv6 and properly finds a free port
112 let bind_addr: SocketAddr = format!("0.0.0.0:{}", port)
113 .parse()
114 .expect("invalid loopback host/port");
115 let (local_addr, handle) = one_shot_server(bind_addr);
116 println!("Listening on {}", local_addr);
117 // build redirect uri
118 let redirect = Url::parse(&format!(
119 "http://{}:{}/oauth/callback",
120 cfg.host,
121 local_addr.port(),
122 ))
123 .unwrap();
124
125 let scopes = if opts.scopes.is_empty() {
126 Some(self.registry.client_data.config.scopes.clone())
127 } else {
128 Some(opts.scopes.clone().into_static())
129 };
130
131 let client_data = crate::session::ClientData {
132 keyset: self.registry.client_data.keyset.clone(),
133 config: AtprotoClientMetadata::new_localhost(Some(vec![redirect.clone()]), scopes),
134 };
135 // Build client using store and resolver
136 let flow_client = OAuthClient::new_with_shared(
137 self.registry.store.clone(),
138 self.client.clone(),
139 client_data,
140 );
141
142 // Start auth and get authorization URL
143 let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
144 // Print URL for copy/paste
145 println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
146 // Optionally open browser
147 if cfg.open_browser {
148 let _ = try_open_in_browser(&auth_url);
149 }
150
151 // Await callback or timeout
152 let mut callback_rx = handle.callback_rx;
153 let cb = tokio::time::timeout(
154 std::time::Duration::from_millis(cfg.timeout_ms),
155 callback_rx.recv(),
156 )
157 .await;
158 // trigger shutdown
159 let _ = handle.server_stop.send(());
160 if let Ok(Some(cb)) = cb {
161 // Handle callback and create a session
162 Ok(flow_client.callback(cb).await?)
163 } else {
164 Err(OAuthError::Callback(CallbackError::Timeout))
165 }
166 }
167}