Files
confetti/src/main.rs
T
uelen 2b86e4443f Setup protocol correctly
The TLS crypto provider is supposed to be selected by default, but it is not. The cause is unclear,
so set the provider manually instead.
2026-06-01 15:53:25 -07:00

308 lines
9.2 KiB
Rust

// Suppress dead-code warnings from paths that are not implemented yet.
// Remove this once the implementation is complete.
#![allow(dead_code)]
mod client;
mod config;
mod connection;
mod db;
mod state;
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::deadpool::Pool;
use quinn::rustls::pki_types::{PrivateKeyDer,CertificateDer,pem::PemObject};
use quinn::Endpoint;
use quinn::crypto::rustls::QuicServerConfig;
use std::io;
use std::fs;
use std::net::{IpAddr,SocketAddr};
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::Arc;
use serde::{Deserialize,Serialize};
use tracing::{error,instrument,warn,debug,info,Level};
use tokio::sync::{RwLock,broadcast,mpsc};
use crate::config::Config;
use crate::state::{State,StateError};
/// A list of integer coordinates, convertible to and from [`fedichat::RoomId`].
#[derive(Hash, Eq, PartialEq, Clone, Serialize, Deserialize, Debug)]
pub struct Coordinate(Vec<i64>);

impl Coordinate {
    /// Consumes this coordinate and returns a [`fedichat::RoomId`] holding the
    /// same coordinate values.
    pub fn to_roomid(self) -> fedichat::RoomId {
        let Coordinate(coordinates) = self;
        fedichat::RoomId { coordinates }
    }
}

impl From<fedichat::RoomId> for Coordinate {
    /// Takes the coordinate list out of a [`fedichat::RoomId`].
    fn from(room: fedichat::RoomId) -> Self {
        Self(room.coordinates)
    }
}
#[tokio::main]
#[instrument]
async fn main() -> ExitCode {
// NOTE: This doesn't work as you can only initialize the global logger once
// Initial logger so we have something during config
//tracing::subscriber::set_global_default(
// tracing_subscriber::fmt().with_max_level(Level::WARN).finish()
//).expect("Failed to setup logger");
// Read in config
let config = match Config::load() {
Ok(c) => c,
Err(e) => {
eprintln!("Problem while reading config file");
eprintln!("{:?}",e);
return ExitCode::FAILURE;
}
};
let level = match config.loglevel {
Some(ref s) => {
let loglevel = s.to_lowercase();
match loglevel.as_str() {
"trace" => Level::TRACE,
"debug" => Level::DEBUG,
"info" => Level::INFO,
"warn" => Level::WARN,
"error" => Level::ERROR,
_ => {
eprintln!("Invalid loglevel in config: {}",&loglevel);
Level::INFO
},
}
},
// Default to info level
None => Level::INFO
};
tracing::subscriber::set_global_default(
tracing_subscriber::fmt().with_max_level(level).finish()
).expect("Failed to setup logger");
// Check to make sure media directory exists
match std::fs::exists(&config.media_directory) {
Ok(true) => {},
// NOTE: maybe shouldnt shadow this error
_ => {
error!("Media directory {} does not exist. Check to make sure it is a directory and is writable.",config.media_directory);
return ExitCode::FAILURE;
}
}
// Set up database connection
let db_string = format!("postgres://{}:{}@{}/{}",config.database.user,config.database.password,config.database.url,config.database.db_name);
let db_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(db_string);
let db_pool = match Pool::builder(db_config)
.max_size(config.database.num_connections)
.build()
{
Ok(val) => val,
Err(e) => {
error!("Error while creating database connection pool");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
// Read certificate file
debug!("Reading certificate file");
let certs = match CertificateDer::pem_file_iter(&config.certfile) {
Ok(certs) => match certs.collect::<Result<Vec<_>,_>>() {
Ok(k) => k,
Err(e) => {
error!("Could not read certificates {}",&config.certfile);
error!("{}",e);
return ExitCode::FAILURE
}
},
Err(e) => {
error!("Could not read certificates.");
error!("{}",e);
return ExitCode::FAILURE
}
};
let key = match PrivateKeyDer::from_pem_file(&config.keyfile){
Ok(val) => val,
Err(e) => {
error!("Could not read key file {}",&config.keyfile);
error!("{}",e);
return ExitCode::FAILURE
}
};
let address = match IpAddr::from_str(&config.listen_address) {
Ok(val) => val,
Err(e) => {
error!("Could not parse IP address: {:?}",e);
return ExitCode::FAILURE;
}
};
let provider = rustls::crypto::aws_lc_rs::default_provider();
let protocol = match rustls::ServerConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13]) {
Ok(prot) => prot,
Err(e) => {
error!("Unable to intialize TLS protocol configuration: {}",e);
return ExitCode::FAILURE;
}
};
let server_crypto = match protocol
.with_no_client_auth()
.with_single_cert(certs, key)
{
Ok(mut val) => {
val.alpn_protocols = vec![b"fedichatv0".to_vec()];
match QuicServerConfig::try_from(val) {
Ok(conf) => conf,
Err(e) => {
error!("Unable to intialize TLS server configuration: {}",e);
return ExitCode::FAILURE;
}
}
},
Err(e) => {
error!("Unable to intialize TLS server configuration: {}",e);
return ExitCode::FAILURE;
}
};
let server_config =
quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
//let quinn_config = match quinn::ServerConfig::with_single_cert(certs, key){
// Ok(val) => val,
// Err(e) => {
// error!("Unable to intialize quinn server configuration: {:?}",e);
// return ExitCode::FAILURE;
// }
//};
// Bind this endpoint to a UDP socket on the given server address.
let endpoint = match Endpoint::server(
server_config,
SocketAddr::new(address,config.port)
) {
Ok(val) => val,
Err(e) => {
error!("Could not create incoming socket");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
debug!("Bound to {address} on port {}",config.port);
// Load or create new state
let state = match State::load_from_file(&config.statefile) {
Ok(state) => state,
// Create file if it does not exist
Err(StateError::IOError(e)) if e.kind() == io::ErrorKind::NotFound => {
match fs::File::create(&config.statefile) {
// If the statefile is writable then create an empty state
// and use that
Ok(_) => {
debug!("Creating fresh state");
State::new()
},
Err(e) => {
error!("Could not open or create statefile. Check your config.");
error!("{:?}",e);
return ExitCode::FAILURE;
}
}
},
Err(e) => {
error!("Could not open or create statefile. Check your config.");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
let statehandle = Arc::new(RwLock::new(state));
// Global connections
let (close_send,close_recv) = broadcast::channel(1);
let (message_send,message_recv) = broadcast::channel(128);
let (message_ack_send,_message_ack_recv) = mpsc::channel(128);
debug!("Setting ctrl-c handler");
match ctrlc_async::set_async_handler(
async move { match close_send.send(()) {
Ok(_val) => debug!("Propogating ctrl-c"),
Err(e) => {
error!("Shutdown handler is broken. Cannot gracefully exit.");
error!("{:?}",e);
}
}
}
) {
Ok(()) => (),
Err(e) => {
error!("Could not set up signal handler");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
let mut join_handles = Vec::new();
// Create client listener
debug!("Setting up client handler");
let statehandle_cloned = statehandle.clone();
let config_cloned = config.clone();
join_handles.push(tokio::spawn(async move {
connection::client_handler(
statehandle_cloned,
(message_send.clone(),message_recv.resubscribe()),
message_ack_send,
db_pool,
endpoint,
close_recv.resubscribe(),
config_cloned
).await;
} ));
info!("Successfully started confetti");
// Wait for child threads to exit
for handle in join_handles {
match handle.await {
Ok(()) => (),
Err(e) => {
warn!("Problem while cleaning up threads");
warn!("{:?}",e);
}
}
}
info!("Shutting down");
//Save state
match statehandle.write().await.write_to_file(&config.statefile).await {
Ok(()) => debug!("Successfully wrote state to {:?}",config.statefile),
Err(e) => {
error!("Problem while writing to statefile");
error!("{:?}",e);
}
}
ExitCode::SUCCESS
}