Lots more plumbing

Focusing mostly on getting main in shape and figuring out how to make
the connections flow. The main is pretty much there, going to be a bit
more when I take a shot at federation.
This commit is contained in:
2026-05-16 13:31:23 -07:00
parent 1bd9555287
commit 42bcebb50c
8 changed files with 920 additions and 52 deletions
+91
View File
@@ -0,0 +1,91 @@
use fedichat::RoomId;
use fedichat::client::TaggedClientMessage;
use fedichat::state::StatePath;
use diesel_async::AsyncPgConnection;
use diesel_async::pooled_connection::deadpool::Pool;
use tokio::sync::{broadcast,mpsc,RwLock};
use tokio::select;
use std::collections::HashSet;
use std::sync::{Arc};
use quinn::{SendStream,RecvStream};
use crate::state::State;
pub enum MessageId {
State(RoomId,StatePath),
Messages(RoomId)
}
pub struct Client {
statehandle: Arc<RwLock<State>>,
// Channel for local messages, all are broadcast to all
// Every client gets to filter based on name
// Same with state changes
// Potentially could be turned into a more efficient localised filter maybe
// Remote messages can come through the same channel maybe?
message_send: broadcast::Sender<TaggedClientMessage>,
// This probably could be a more specific type
// But we should be filtering and sending back only stuff that matters
message_recv: broadcast::Receiver<TaggedClientMessage>,
// Sends back match if anything filtered matched the client filters
// Connections are closed after a period of inactivity so this should be fine,
// if client is connected acks are sent, if inactive connection is closed
message_ack: mpsc::Sender<MessageId>,
db_handle: Pool<AsyncPgConnection>,
// Filled once user is authed
username: Option<String>,
// how do I keep track of activity??? maybe there's a hashmap that gets updated
subscriptions: HashSet<MessageId>,
quic_connection: quinn::Connection,
close_handle: broadcast::Receiver<()>
}
impl Client {
pub fn new(
statehandle: Arc<RwLock<State>>,
(message_send,message_recv): (
broadcast::Sender<TaggedClientMessage>,
broadcast::Receiver<TaggedClientMessage>),
message_ack: mpsc::Sender<MessageId>,
db_handle: Pool<AsyncPgConnection>,
quic_connection: quinn::Connection,
close_handle: broadcast::Receiver<()>
) -> Self {
Client {
statehandle,
message_send,
message_recv,
message_ack,
db_handle,
username: None,
subscriptions: HashSet::new(),
quic_connection,
close_handle
}
}
pub async fn run(mut self) {
let mut chunk_arr = [0u8; 1024];
let mut serde_buf: Vec<u8> = Vec::new();
loop {
select!{
//result = self.quic_recv.read(&mut chunk_arr) => {
// unimplemented!()
//}
result = self.message_recv.recv() => {
unimplemented!()
}
_result = self.close_handle.recv() => {
// Maybe TODO do I need to check the result?
break;
}
}
}
// Do any cleanup that needs done
}
}
+8 -3
View File
@@ -3,21 +3,26 @@ use std::io::{Read,self};
use serde::{Serialize,Deserialize};
use thiserror::Error;
#[derive(Default,Clone,Serialize,Deserialize)]
// TODO: config file defaults instead of having to specify everything
#[derive(Clone,Serialize,Deserialize)]
pub struct Config {
pub hostname: String,
pub port: u16,
pub federation_port: u16,
pub listen_address: String,
pub database: DBConfig,
pub certfile: String,
pub keyfile: String,
pub statefile: String,
pub loglevel: String,
pub media_directory: String
}
#[derive(Default,Clone,Serialize,Deserialize)]
#[derive(Clone,Serialize,Deserialize)]
pub struct DBConfig {
pub url: String,
pub user: String,
pub password: String,
pub num_connections: usize
}
const LOCATIONS: [&'static str; 3] = [
+64
View File
@@ -0,0 +1,64 @@
use crate::client::{Client,MessageId};
use crate::state::State;
use tokio::sync::{mpsc,RwLock,broadcast};
use std::sync::Arc;
use fedichat::client::TaggedClientMessage;
use diesel_async::pooled_connection::deadpool::Pool;
use diesel_async::{AsyncPgConnection};
use tracing::{instrument,warn};
#[instrument(skip_all)]
pub async fn client_handler(
statehandle: Arc<RwLock<State>>,
(message_send,message_recv): (
broadcast::Sender<TaggedClientMessage>,
broadcast::Receiver<TaggedClientMessage>),
message_ack_send: mpsc::Sender<MessageId>,
db_handle: Pool<AsyncPgConnection>,
endpoint: quinn::Endpoint,
close_handle: broadcast::Receiver<()>
) {
let mut thread_handles = Vec::new();
// Start iterating over incoming connections.
while let Some(conn) = endpoint.accept().await {
let connection = match conn.await {
Ok(conn) => conn,
Err(e) => {
warn!("Error while processing client connection");
warn!("{:?}",e);
continue
}
};
let statehandle = statehandle.clone();
let message_send = message_send.clone();
let message_recv = message_recv.resubscribe();
let close_handle = close_handle.resubscribe();
let db_handle = db_handle.clone();
let message_ack_send = message_ack_send.clone();
thread_handles.push(tokio::spawn(async move {
// listen for the first bidirectional stream
// create client
let client = Client::new(
statehandle,
(message_send,message_recv),
message_ack_send,
db_handle,
connection,
close_handle
);
// run client
client.run()
}));
}
}
pub fn federation_handler() {
}
+118 -13
View File
@@ -1,16 +1,29 @@
mod client;
mod config;
mod connection;
mod state;
use diesel::{Connection,PgConnection};
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::deadpool::Pool;
use diesel_async::{AsyncConnection,AsyncPgConnection};
use quinn::rustls::pki_types::{PrivateKeyDer,CertificateDer,pem::PemObject};
use std::net::{IpAddr,SocketAddr};
use std::str::FromStr;
use quinn::Endpoint;
use tracing::{error,instrument};
use std::io;
use std::fs;
use std::net::{IpAddr,SocketAddr};
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::Arc;
use serde::{Deserialize,Serialize};
use tracing::{error,instrument,warn,debug};
use tokio::sync::{RwLock,broadcast,mpsc};
use crate::config::Config;
use crate::state::State;
#[derive(Hash,Eq,PartialEq,Clone,Serialize,Deserialize)]
pub struct Coordinate(Vec<i64>);
#[tokio::main]
#[instrument]
@@ -22,12 +35,19 @@ async fn main() -> ExitCode {
error!("Problem while reading config file");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
// Set up database connection
let db_connection = PgConnection::establish(&config.database.url);
let db_config = AsyncDieselConnectionManager::<diesel_async::AsyncPgConnection>::new(config.database.url.clone());
let db_pool = match Pool::builder(db_config).build() {
Ok(val) => val,
Err(e) => {
error!("Error while creating database connection pool");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
// Read certificate file
@@ -66,14 +86,99 @@ async fn main() -> ExitCode {
}
};
// Load or create new state
let state = match State::load_from_file(&config.statefile) {
Ok(state) => state,
// Create file if it does not exist
Err(e) => {
match e.kind() {
io::ErrorKind::NotFound => {
match fs::File::create(&config.statefile) {
// If the statefile is writable then create an empty state
// and use that
Ok(_) => State::new(),
Err(e) => {
error!("Could not open or create statefile. Check your config.");
error!("{:?}",e);
return ExitCode::FAILURE;
// Start iterating over incoming connections.
while let Some(conn) = endpoint.accept().await {
let connection = conn.await;
}
}
// Save connection somewhere, start transferring, receiving data, see DataTransfer tutorial.
},
_ => {
error!("Could not open or create statefile. Check your config.");
error!("{:?}",e);
return ExitCode::FAILURE;
}
}
},
};
let statehandle = Arc::new(RwLock::new(state));
// Global connections
let (close_send,close_recv) = broadcast::channel(1);
let (message_send,message_recv) = broadcast::channel(128);
let (message_ack_send,message_ack_recv) = mpsc::channel(128);
match ctrlc_async::set_async_handler(
async move { match close_send.send(()) {
Ok(_val) => (),
Err(e) => {
error!("Shutdown handler is broken. Cannot gracefully exit.");
error!("{:?}",e);
}
}
}
) {
Ok(()) => (),
Err(e) => {
error!("Could not set up signal handler");
error!("{:?}",e);
return ExitCode::FAILURE;
}
};
let mut join_handles = Vec::new();
// Create client listener
let statehandle_cloned = statehandle.clone();
join_handles.push(tokio::spawn(async move {
connection::client_handler(
statehandle_cloned,
(message_send.clone(),message_recv.resubscribe()),
message_ack_send,
db_pool,
endpoint,
close_recv.resubscribe()
).await;
} ));
// Wait for child threads to exit
for handle in join_handles {
match handle.await {
Ok(()) => (),
Err(e) => {
warn!("Problem while cleaning up threads");
warn!("{:?}",e);
}
}
}
//Save state
match statehandle.write().await.write_to_file(&config.statefile) {
Ok(()) => debug!("Successfully wrote state to {:?}",config.statefile),
Err(e) => {
error!("Problem while writing to statefile");
error!("{:?}",e);
}
}
// Should never reach this
// How should I gracefully shutdown? Can I trap ctrlc
ExitCode::SUCCESS
+87
View File
@@ -0,0 +1,87 @@
use serde::{Deserialize,Serialize};
use crate::Coordinate;
use tokio::sync::RwLock;
use std::collections::HashMap;
use std::io;
use std::sync::Arc;
// Is there a way to guarantee that creating a state will create a permission
// at the same time in the event of a crash?
#[derive(Serialize,Deserialize)]
pub struct LoadableState {
room_states: HashMap<Coordinate,StateNode>,
perms: HashMap<Coordinate,PermNode>
}
impl LoadableState {
pub fn new(room_states: HashMap<Coordinate,StateNode>, perms: HashMap<Coordinate,PermNode>) -> Self {
LoadableState{
room_states,
perms
}
}
}
// This whole thing needs to be RWLock protected to be able to add new rooms
pub struct State {
room_states: HashMap<Coordinate,Arc<RwLock<StateNode>>>,
perms: HashMap<Coordinate,Arc<RwLock<PermNode>>>
}
impl State {
pub fn new() -> Self {
State {
room_states: HashMap::new(),
perms: HashMap::new()
}
}
pub fn from_loadable_state(other: LoadableState) -> Self {
unimplemented!()
}
pub fn write_to_file(&mut self, path: &str) -> Result<(),io::Error> {
unimplemented!()
}
pub fn load_from_file(path: &str) -> Result<Self,io::Error> {
unimplemented!()
}
pub async fn to_loadable_state(&self) -> LoadableState {
let mut room_states = HashMap::new();
let mut perms = HashMap::new();
//NOTE: This clone can be really heavy, I wonder if I can just lock the whole
//state to avoid copying it
for (coordinate,state) in self.room_states.iter() {
room_states.insert(coordinate.clone(),state.read().await.clone());
}
for (coordinate,perm) in self.perms.iter() {
perms.insert(coordinate.clone(),perm.read().await.clone());
}
LoadableState::new(room_states,perms)
}
// This traverses to the right leaf node and gets all permissions for an object
pub async fn get_required_permissions(&self, path: fedichat::state::StatePath)
-> HashMap<fedichat::state::StatePermissionKey,fedichat::state::StatePermissionValue>
{
unimplemented!()
}
// Make a wrapper for this that allows the client to request a state value
pub async fn get_state_value(&self, path: fedichat::state::StatePath)
-> fedichat::state::StateValue
{
unimplemented!()
}
}
#[derive(Serialize,Deserialize,Clone)]
pub enum StateNode {
Directory(HashMap<String,StateNode>),
Value(fedichat::state::StateValue)
}
#[derive(Serialize,Deserialize,Clone)]
pub enum PermNode {
Directory(HashMap<String,StateNode>),
Permission(HashMap<fedichat::state::StatePermissionKey,fedichat::state::StatePermissionValue>)
}