Skip to content
281 changes: 272 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

24 changes: 23 additions & 1 deletion rkvm-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ version = "0.6.1"
authors = ["Jan Trefil <8711792+htrefil@users.noreply.github.com>"]
edition = "2021"

[[bin]]
name = "rkvm-service"
path = "src/windows-service.rs"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1.89"
tokio = { version = "1.0.1", features = ["macros", "time", "fs", "net", "signal", "rt-multi-thread", "sync"] }
rkvm-input = { path = "../rkvm-input" }
rkvm-net = { path = "../rkvm-net" }
Expand All @@ -19,7 +24,24 @@ thiserror = "1.0.40"
tokio-rustls = "0.24.0"
rustls-pemfile = "1.0.2"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter", "local-time"] }

[target.'cfg(windows)'.dependencies]
bincode = "1.3.3"
windows = { version = "0.62", features = [
"Win32_System_RemoteDesktop",
"Win32_Foundation",
"Win32_Security",
"Win32_Security_Authorization",
"Win32_System_Threading",
"Win32_System_Diagnostics",
"Win32_System_Diagnostics_ToolHelp",
"Win32_System_Environment",
"Win32_System_StationsAndDesktops",
] }
windows-core = "0.62"
windows-sys = { version = "0.61", features = [ "Win32" ]}
windows-service = "0.8"

[package.metadata.rpm]
package = "rkvm-client"
Expand Down
167 changes: 73 additions & 94 deletions rkvm-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,68 @@
use rkvm_input::writer::Writer;
use crate::config::Config;
use crate::stream::{RkvmStream, RkvmWriter};

use rkvm_input::writer::{DeviceWriter};
use rkvm_net::auth::{AuthChallenge, AuthStatus};
use rkvm_net::message::Message;
use rkvm_net::version::Version;
use rkvm_net::{Pong, Update};
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io;
use rkvm_net::Update;
use std::fs::OpenOptions;
use std::io::{self, stdout, BufWriter};
use std::path::Path;
use std::time::Instant;
use thiserror::Error;
use tokio::io::{AsyncWriteExt, BufStream};
use tokio::fs;
use tokio::io::{AsyncRead, AsyncWriteExt, BufStream};
use tokio::net::TcpStream;
use tokio::time;
use tokio_rustls::rustls::ServerName;
use tokio_rustls::rustls::{self, ServerName};
use tokio_rustls::TlsConnector;
use tracing_subscriber::{fmt, Registry,EnvFilter};
use tracing_subscriber::fmt::time::LocalTime;
use tracing_subscriber::prelude::*;
#[cfg(target_os="windows")]
use windows::core;

#[derive(Error, Debug)]
pub enum Error {
#[error("Network error: {0}")]
Network(io::Error),
#[error("Input error: {0}")]
Input(io::Error),
#[error("Io error: {0}")]
Io(#[from] io::Error),
#[error(transparent)]
Rustls(#[from] rustls::Error),
#[error("Toml error: {0}")]
Toml(#[from] toml::de::Error),
#[cfg(target_os="windows")]
#[error("Windows API error: {0}")]
Windows(#[from] core::Error),
#[error("Incompatible server version (got {server}, expected {client})")]
Version { server: Version, client: Version },
#[error("Invalid password")]
Auth,
}

pub async fn run(
hostname: &ServerName,
port: u16,
connector: TlsConnector,
password: &str,
) -> Result<(), Error> {
pub fn init_tracing<P: AsRef<Path>>(log_level: &String, log_file: &Option<P>) {
let filter = EnvFilter::new(log_level);
if let Some(path) = log_file {
let file = OpenOptions::new().create(true).append(true).open(path).unwrap();
let fmt_layer = fmt::layer().with_ansi(false).with_timer(LocalTime::rfc_3339()).with_writer(move || BufWriter::new(file.try_clone().unwrap()));
let registry = Registry::default().with(filter).with(fmt_layer);
tracing::subscriber::set_global_default(registry).unwrap();
} else {
let fmt_layer = fmt::layer().with_writer(stdout).without_time();
let registry = Registry::default().with(filter).with(fmt_layer);
tracing::subscriber::set_global_default(registry).unwrap();
}
}

pub async fn init_config<P: AsRef<Path> + ?Sized> (path: &P) -> Result<Config,Error> {
let config = fs::read_to_string(path).await?;
let config = toml::from_str::<Config>(&config)?;
return Ok(config);
}

pub async fn init_stream(hostname: &ServerName, port: u16, connector: &TlsConnector, password: &str) -> Result<RkvmStream,Error> {
// Intentionally don't impose any timeout for TCP connect.
let stream = match hostname {
ServerName::DnsName(name) => TcpStream::connect(&(name.as_ref(), port)).await,
Expand Down Expand Up @@ -98,61 +129,32 @@ pub async fn run(
}

tracing::info!("Authenticated successfully");
Ok(RkvmStream::Tcp(stream))
}

let mut start = Instant::now();
pub async fn run<R,W,H>(reader: &mut R, writer: &mut W, mut handler: H) -> Result<(), Error>
where
R: AsyncRead + Send + Unpin,
W: RkvmWriter + Send,
H: DeviceWriter {

let mut interval = time::interval(rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT);
let mut writers = HashMap::new();
let mut start = Instant::now();

// Interval ticks immediately after creation.
interval.tick().await;
let timeout_duration = rkvm_net::PING_INTERVAL + rkvm_net::READ_TIMEOUT;

loop {
let update = tokio::select! {
update = Update::decode(&mut stream) => update.map_err(Error::Network)?,
_ = interval.tick() => return Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timed out"))),
};
let update = match time::timeout(timeout_duration, Update::decode(reader)).await {
Err(_) => Err(Error::Network(io::Error::new(io::ErrorKind::TimedOut, "Ping timeout"))),
Ok(res) => res.map_err(Error::Network)
}?;

match update {
Update::CreateDevice {
id,
name,
vendor,
product,
version,
rel,
abs,
keys,
delay,
period,
} => {
let entry = writers.entry(id);
if let Entry::Occupied(_) = entry {
return Err(Error::Network(io::Error::new(
io::ErrorKind::InvalidData,
"Server created the same device twice",
)));
}

let writer = async {
Writer::builder()?
.name(&name)
.vendor(vendor)
.product(product)
.version(version)
.rel(rel)?
.abs(abs)?
.key(keys)?
.delay(delay)?
.period(period)?
.build()
.await
}
.await
.map_err(Error::Input)?;

entry.or_insert(writer);
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "received {:?}", update);
start = Instant::now();

match update {
Update::CreateDevice { id,name,vendor,product,version,rel,abs,keys,delay,period,} => {
handler.create_device(id, &name, vendor, product, version, rel, abs, keys, delay, period).await?;
tracing::info!(
id = %id,
name = ?name,
Expand All @@ -163,46 +165,23 @@ pub async fn run(
);
}
Update::DestroyDevice { id } => {
if writers.remove(&id).is_none() {
return Err(Error::Network(io::Error::new(
io::ErrorKind::InvalidData,
"Server destroyed a nonexistent device",
)));
}

handler.destroy_device(id).await?;
tracing::info!(id = %id, "Destroyed device");
}
Update::Event { id, event } => {
let writer = writers.get_mut(&id).ok_or_else(|| {
Error::Network(io::Error::new(
io::ErrorKind::InvalidData,
"Server sent an event to a nonexistent device",
))
})?;

writer.write(&event).await.map_err(Error::Input)?;

handler.event(id, event).await?;
tracing::trace!(id = %id, "Wrote an event to device");
}
Update::Ping => {
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Received ping");

start = Instant::now();
interval.reset();

rkvm_net::timeout(rkvm_net::WRITE_TIMEOUT, async {
Pong.encode(&mut stream).await?;
stream.flush().await?;

Ok(())
})
.await
.map_err(Error::Network)?;

writer.send(Update::Pong).await?;
let duration = start.elapsed();
tracing::debug!(duration = ?duration, "Sent pong");
}
Update::Stop => {
tracing::info!("Stoping..");
return Ok(());
}
_ => {}
}
}
}
71 changes: 34 additions & 37 deletions rkvm-client/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,59 @@
mod client;
mod client;
mod config;
mod tls;
mod stream;
#[cfg(target_os="windows")]
mod windows;


use clap::Parser;
use config::Config;
use client::{Error, init_tracing};
use std::path::PathBuf;
use std::process::ExitCode;
use tokio::{fs, signal};
use tracing::subscriber;
use tracing_subscriber::filter::{EnvFilter, LevelFilter};
use tracing_subscriber::fmt;
use tracing_subscriber::prelude::*;
use tokio::io::split;
use tokio::signal;

#[cfg(target_os="windows")]
use windows::{init_stream, init_writers};


#[derive(Parser)]
#[structopt(name = "rkvm-client", about = "The rkvm client application")]
struct Args {
pub struct Args {
#[clap(help = "Path to configuration file")]
config_path: PathBuf,
#[clap(long, default_value = "info", help = "log filter")]
log_level: String,
#[clap(long, help = "output file for the logs")]
log_file: Option<PathBuf>,
}

#[cfg(not(target_os="windows"))]
async fn process_args(args: &Args) -> Result<RkvmStream,Error> {
let config = client::init_config(&args.config_path).await?;
let connector = tls::configure(&config.certificate).await?;
client::init_stream(&config.server.hostname, config.server.port, &connector, &config.password).await
}

#[tokio::main]
async fn main() -> ExitCode {
let filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy();

let registry = tracing_subscriber::registry()
.with(filter)
.with(fmt::layer().without_time());

subscriber::set_global_default(registry).unwrap();

let args = Args::parse();
let config = match fs::read_to_string(&args.config_path).await {
Ok(config) => config,
Err(err) => {
tracing::error!("Error reading config: {}", err);
return ExitCode::FAILURE;
}
};
init_tracing(&args.log_level, &args.log_file);

let config = match toml::from_str::<Config>(&config) {
Ok(config) => config,
Err(err) => {
tracing::error!("Error parsing config: {}", err);
tracing::info!("Client starting...");
let stream = match init_stream(&args).await {
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to open stream {}", e);
return ExitCode::FAILURE;
}
};

let connector = match tls::configure(&config.certificate).await {
Ok(connector) => connector,
Err(err) => {
tracing::error!("Error configuring TLS: {}", err);
return ExitCode::FAILURE;
}
};
let writers = init_writers();

let (mut r, mut w) = split(stream);
tokio::select! {
result = client::run(&config.server.hostname, config.server.port, connector, &config.password) => {
result = client::run(&mut r, &mut w, writers) => {
if let Err(err) = result {
tracing::error!("Error: {}", err);
return ExitCode::FAILURE;
Expand All @@ -76,3 +72,4 @@ async fn main() -> ExitCode {

ExitCode::SUCCESS
}

Loading