Unit tests part 1

This commit is contained in:
Mauro D 2023-01-17 17:55:13 +00:00
parent b1e1067fc9
commit cc722576ea
38 changed files with 1904 additions and 341 deletions

16
Cargo.lock generated
View file

@ -521,9 +521,9 @@ dependencies = [
[[package]]
name = "ed25519"
version = "1.5.2"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9c280362032ea4203659fc489832d0204ef09f247a0506f170dafcac08c369"
checksum = "91cff35c70bba8a626e3185d8cd48cc11b5437e1a5bcd15b9b5fa3c64b6dfee7"
dependencies = [
"signature",
]
@ -634,7 +634,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "197676987abd2f9cadff84926f410af1c183608d36641465df73ae8211dc65d6"
dependencies = [
"futures-core",
"futures-io",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@ -1091,9 +1093,9 @@ dependencies = [
[[package]]
name = "nom"
version = "7.1.2"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5507769c4919c998e69e49c839d9dc6e693ede4cc4290d6ad8b41d4f09c548c"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
@ -1354,9 +1356,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "proc-macro2"
version = "1.0.49"
version = "1.0.50"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57a8eca9f9c4ffde41714334dee777596264c7825420f521abc92b5b5deb63a5"
checksum = "6ef7d57beacfaf2d8aee5937dab7b7f28de3cb8b1828479bb5de2a7106f2bae2"
dependencies = [
"unicode-ident",
]
@ -1788,8 +1790,10 @@ dependencies = [
"mail-builder 0.2.4",
"mail-parser",
"mail-send",
"num_cpus",
"parking_lot",
"rand 0.8.5",
"rayon",
"regex",
"reqwest",
"rustls",

View file

@ -16,6 +16,7 @@ rustls-pemfile = "1.0"
tokio = { version = "1.23", features = ["full"] }
tokio-rustls = { version = "0.23"}
webpki-roots = { version = "0.22"}
rayon = "1.5"
tracing = "0.1"
tracing-subscriber = "0.3"
parking_lot = "0.12"
@ -25,11 +26,13 @@ blake3 = "1.3"
lru-cache = "0.1.2"
rand = "0.8.5"
x509-parser = "0.14.0"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "blocking"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
num_cpus = "1.15.0"
[dev-dependencies]
mail-auth = { path = "/home/vagrant/code/mail-auth", features = ["test"] }
criterion = "0.4.0"
[[bench]]

View file

@ -2,8 +2,7 @@
Stalwart SMTP Server
# TODO
- Analyze reports
- Null MX
- Dashmap cleanup
- RBL
- Sieve
- Spam filter

View file

@ -38,12 +38,12 @@ struct TestEnvelope {
}
impl Envelope for TestEnvelope {
fn local_ip(&self) -> &IpAddr {
&self.local_ip
fn local_ip(&self) -> IpAddr {
self.local_ip
}
fn remote_ip(&self) -> &IpAddr {
&self.remote_ip
fn remote_ip(&self) -> IpAddr {
self.remote_ip
}
fn sender_domain(&self) -> &str {

View file

@ -263,7 +263,11 @@ host = "lmtp"
path = "/var/spool/report"
hash = 16
submitter = "mx.domain.org"
analyze = ["dmarc@*", "abuse@*"]
[report.analysis]
addresses = ["dmarc@*", "abuse@*"]
forward = true
store = "/var/spool/report"
[report.dsn]
from-name = "Mail Delivery Subsystem"

View file

@ -91,7 +91,22 @@ impl Config {
err
)
})?;
let (signer, sealer) = self.parse_signature(id, key.clone(), key)?;
let key_clone = RsaKey::<Sha256>::from_pkcs1_pem(
&String::from_utf8(self.file_contents((
"signature",
id,
"public-key",
))?)
.unwrap_or_default(),
)
.map_err(|err| {
format!(
"Failed to build RSA key for {}: {}",
("signature", id, "public-key",).as_key(),
err
)
})?;
let (signer, sealer) = self.parse_signature(id, key_clone, key)?;
(DkimSigner::RsaSha256(signer), ArcSealer::RsaSha256(sealer))
}
Algorithm::Ed25519Sha256 => {

View file

@ -17,7 +17,7 @@ use std::{
collections::BTreeMap,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
path::PathBuf,
sync::Arc,
sync::{atomic::AtomicU64, Arc},
time::Duration,
};
@ -239,8 +239,9 @@ pub struct Connect {
pub struct Ehlo {
pub script: IfBlock<Option<Arc<Script>>>,
pub require: IfBlock<bool>,
}
// Capabilities
pub struct Extensions {
pub pipelining: IfBlock<bool>,
pub chunking: IfBlock<bool>,
pub requiretls: IfBlock<bool>,
@ -307,6 +308,7 @@ pub struct SessionConfig {
pub mail: Mail,
pub rcpt: Rcpt,
pub data: Data,
pub extensions: Extensions,
}
pub struct SessionThrottle {
@ -359,7 +361,7 @@ pub struct ReportConfig {
pub path: IfBlock<PathBuf>,
pub hash: IfBlock<u64>,
pub submitter: IfBlock<String>,
pub analyze: Vec<String>,
pub analysis: ReportAnalysis,
pub dkim: Report,
pub spf: Report,
@ -368,6 +370,19 @@ pub struct ReportConfig {
pub tls: AggregateReport,
}
pub struct ReportAnalysis {
pub addresses: Vec<AddressMatch>,
pub forward: bool,
pub store: Option<PathBuf>,
pub report_id: AtomicU64,
}
pub enum AddressMatch {
StartsWith(String),
EndsWith(String),
Equals(String),
}
pub struct Dsn {
pub name: IfBlock<String>,
pub address: IfBlock<String>,

View file

@ -1,7 +1,7 @@
use super::{
utils::{AsKey, ParseValue},
AggregateFrequency, AggregateReport, Config, ConfigContext, EnvelopeKey, IfBlock, Report,
ReportConfig,
AddressMatch, AggregateFrequency, AggregateReport, Config, ConfigContext, EnvelopeKey, IfBlock,
Report, ReportAnalysis, ReportConfig,
};
impl Config {
@ -23,6 +23,10 @@ impl Config {
EnvelopeKey::LocalIp,
EnvelopeKey::RecipientDomain,
];
let mut addresses = Vec::new();
for address in self.properties::<AddressMatch>("report.analysis.addresses") {
addresses.push(address?.1);
}
let default_hostname = self.value_require("server.hostname")?;
Ok(ReportConfig {
@ -45,10 +49,12 @@ impl Config {
hash: self
.parse_if_block("report.hash", ctx, &sender_envelope_keys)?
.unwrap_or_else(|| IfBlock::new(32)),
analyze: self
.values("report.analyze")
.map(|(_, v)| v.to_string())
.collect(),
analysis: ReportAnalysis {
addresses,
forward: self.property("report.analysis.forward")?.unwrap_or(false),
store: self.property("report.analysis.store")?,
report_id: 0.into(),
},
})
}
@ -169,3 +175,24 @@ impl ParseValue for AggregateFrequency {
}
}
}
impl ParseValue for AddressMatch {
fn parse_value(key: impl AsKey, value: &str) -> super::Result<Self> {
if let Some(value) = value.strip_prefix('*').map(|v| v.trim()) {
if !value.is_empty() {
return Ok(AddressMatch::EndsWith(value.to_lowercase()));
}
} else if let Some(value) = value.strip_suffix('*').map(|v| v.trim()) {
if !value.is_empty() {
return Ok(AddressMatch::StartsWith(value.to_lowercase()));
}
} else if value.contains('@') {
return Ok(AddressMatch::Equals(value.trim().to_lowercase()));
}
Err(format!(
"Invalid address match value {:?} for key {:?}.",
value,
key.as_key()
))
}
}

View file

@ -103,6 +103,7 @@ impl Config {
mail: self.parse_session_mail(ctx)?,
rcpt: self.parse_session_rcpt(ctx)?,
data: self.parse_session_data(ctx)?,
extensions: self.parse_extensions(ctx)?,
})
}
@ -120,21 +121,17 @@ impl Config {
})
}
fn parse_session_ehlo(&self, ctx: &ConfigContext) -> super::Result<Ehlo> {
fn parse_extensions(&self, ctx: &ConfigContext) -> super::Result<Extensions> {
let available_keys = [
EnvelopeKey::Listener,
EnvelopeKey::RemoteIp,
EnvelopeKey::LocalIp,
EnvelopeKey::Sender,
EnvelopeKey::SenderDomain,
EnvelopeKey::AuthenticatedAs,
];
Ok(Ehlo {
script: self
.parse_if_block::<Option<String>>("session.ehlo.script", ctx, &available_keys)?
.unwrap_or_default()
.map_if_block(&ctx.scripts, "session.ehlo.script", "script")?,
require: self
.parse_if_block("session.ehlo.require", ctx, &available_keys)?
.unwrap_or_else(|| IfBlock::new(true)),
Ok(Extensions {
pipelining: self
.parse_if_block("session.extensions.pipelining", ctx, &available_keys)?
.unwrap_or_else(|| IfBlock::new(true)),
@ -159,6 +156,24 @@ impl Config {
})
}
fn parse_session_ehlo(&self, ctx: &ConfigContext) -> super::Result<Ehlo> {
let available_keys = [
EnvelopeKey::Listener,
EnvelopeKey::RemoteIp,
EnvelopeKey::LocalIp,
];
Ok(Ehlo {
script: self
.parse_if_block::<Option<String>>("session.ehlo.script", ctx, &available_keys)?
.unwrap_or_default()
.map_if_block(&ctx.scripts, "session.ehlo.script", "script")?,
require: self
.parse_if_block("session.ehlo.require", ctx, &available_keys)?
.unwrap_or_else(|| IfBlock::new(true)),
})
}
fn parse_session_auth(&self, ctx: &ConfigContext) -> super::Result<Auth> {
let available_keys = [
EnvelopeKey::Listener,

View file

@ -324,10 +324,10 @@ impl ParseValue for Duration {
(num, 60 * 60 * 1000)
} else if let Some(num) = duration.strip_suffix('m') {
(num, 60 * 1000)
} else if let Some(num) = duration.strip_suffix('s') {
(num, 1000)
} else if let Some(num) = duration.strip_suffix("ms") {
(num, 1)
} else if let Some(num) = duration.strip_suffix('s') {
(num, 1000)
} else {
(duration.as_str(), 1)
};

View file

@ -45,8 +45,8 @@ impl Conditions {
}
ConditionValue::IpAddrMask(value) => {
let ctx_value = match key {
EnvelopeKey::RemoteIp => *envelope.remote_ip(),
EnvelopeKey::LocalIp => *envelope.local_ip(),
EnvelopeKey::RemoteIp => envelope.remote_ip(),
EnvelopeKey::LocalIp => envelope.local_ip(),
_ => IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
};
@ -133,7 +133,7 @@ impl IpAddrMask {
IpAddrMask::V4 { addr, mask } => {
if *mask == u32::MAX {
match remote {
IpAddr::V4(addr) => addr == remote,
IpAddr::V4(remote) => addr == remote,
IpAddr::V6(remote) => {
if let Some(remote) = remote.to_ipv4_mapped() {
addr == &remote
@ -160,7 +160,7 @@ impl IpAddrMask {
if mask == &u128::MAX {
match remote {
IpAddr::V6(remote) => remote == addr,
IpAddr::V4(addr) => &addr.to_ipv6_mapped() == remote,
IpAddr::V4(remote) => &remote.to_ipv6_mapped() == addr,
}
} else {
u128::from_be_bytes(match remote {
@ -198,12 +198,12 @@ mod tests {
}
impl Envelope for TestEnvelope {
fn local_ip(&self) -> &IpAddr {
&self.local_ip
fn local_ip(&self) -> IpAddr {
self.local_ip
}
fn remote_ip(&self) -> &IpAddr {
&self.remote_ip
fn remote_ip(&self) -> IpAddr {
self.remote_ip
}
fn sender_domain(&self) -> &str {

View file

@ -39,8 +39,10 @@ use self::throttle::{
pub mod if_block;
pub mod params;
pub mod throttle;
pub mod worker;
pub struct Core {
pub worker_pool: rayon::ThreadPool,
pub session: SessionCore,
pub queue: QueueCore,
pub resolvers: Resolvers,
@ -128,6 +130,7 @@ pub struct SessionData {
pub priority: i16,
pub delivery_by: u64,
pub notify_by: i64,
pub future_release: u64,
pub valid_until: Instant,
@ -197,6 +200,7 @@ impl SessionData {
messages_sent: 0,
bytes_left: 0,
delivery_by: 0,
notify_by: 0,
future_release: 0,
iprev: None,
spf_ehlo: None,
@ -212,8 +216,8 @@ impl Default for State {
}
pub trait Envelope {
fn local_ip(&self) -> &IpAddr;
fn remote_ip(&self) -> &IpAddr;
fn local_ip(&self) -> IpAddr;
fn remote_ip(&self) -> IpAddr;
fn sender_domain(&self) -> &str;
fn sender(&self) -> &str;
fn rcpt_domain(&self) -> &str;

View file

@ -227,7 +227,7 @@ impl Throttle {
hasher.update(e.mx().as_bytes());
}
if (self.keys & THROTTLE_REMOTE_IP) != 0 {
match &e.local_ip() {
match &e.remote_ip() {
IpAddr::V4(ip) => {
hasher.update(&ip.octets()[..]);
}
@ -237,7 +237,7 @@ impl Throttle {
}
}
if (self.keys & THROTTLE_LOCAL_IP) != 0 {
match &e.remote_ip() {
match &e.local_ip() {
IpAddr::V4(ip) => {
hasher.update(&ip.octets()[..]);
}
@ -325,10 +325,14 @@ impl<T: AsyncRead + AsyncWrite> Session<T> {
}
limiter
});
let rate = t
.rate
.as_ref()
.map(|rate| RateLimiter::new(rate.requests, rate.period.as_secs()));
let rate = t.rate.as_ref().map(|rate| {
let mut r = RateLimiter::new(
rate.requests,
std::cmp::min(rate.period.as_secs(), 1),
);
r.is_allowed();
r
});
e.insert(Limiter { rate, concurrency });
}

29
src/core/worker.rs Normal file
View file

@ -0,0 +1,29 @@
use tokio::sync::oneshot;
use super::Core;
impl Core {
pub async fn spawn_worker<U, V>(&self, f: U) -> Option<V>
where
U: FnOnce() -> V + Send + 'static,
V: Sync + Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.worker_pool.spawn(move || {
tx.send(f()).ok();
});
match rx.await {
Ok(result) => Some(result),
Err(err) => {
tracing::warn!(
context = "worker-pool",
event = "error",
reason = %err,
);
None
}
}
}
}

View file

@ -1,5 +1,6 @@
use std::{
path::PathBuf,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
@ -14,42 +15,39 @@ use tokio::io::{AsyncRead, AsyncWrite};
use crate::{
core::{Session, SessionAddress},
queue::{self, Message, SimpleEnvelope},
reporting::analysis::AnalyzeReport,
};
use super::IsTls;
impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
pub async fn queue_message(&mut self) -> Result<(), ()> {
pub async fn queue_message(&mut self) -> &'static [u8] {
// Authenticate message
let dc = &self.core.session.config.data;
let ac = &self.core.mail_auth;
let rc = &self.core.report.config;
let raw_message = Arc::new(std::mem::take(&mut self.data.message));
let auth_message =
if let Some(auth_message) = AuthenticatedMessage::parse(&self.data.message) {
auth_message
} else {
tracing::info!(parent: &self.span,
let auth_message = if let Some(auth_message) = AuthenticatedMessage::parse(&raw_message) {
auth_message
} else {
tracing::info!(parent: &self.span,
event = "parse-failed",
context = "data",
size = self.data.message.len());
size = raw_message.len());
self.reset();
return self.write(b"550 5.7.7 Failed to parse message.\r\n").await;
};
return &b"550 5.7.7 Failed to parse message.\r\n"[..];
};
// Loop detection
if auth_message.received_headers_count() > *dc.max_received_headers.eval(self).await {
tracing::info!(parent: &self.span,
event = "loop-detected",
context = "data",
rfc5321_from = self.data.mail_from.as_ref().unwrap().address,
rfc5322_from = auth_message.from(),
return_path = self.data.mail_from.as_ref().unwrap().address,
from = auth_message.from(),
received_headers = auth_message.received_headers_count());
self.reset();
return self
.write(b"450 4.4.6 Too many Received headers. Possible loop detected.\r\n")
.await;
return b"450 4.4.6 Too many Received headers. Possible loop detected.\r\n";
}
// Verify DKIM
@ -74,7 +72,15 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
if rejected {
// This violates the advice of Section 6.1 of RFC6376
let message = if dkim_output
tracing::info!(parent: &self.span,
event = "failed",
context = "dkim",
return_path = self.data.mail_from.as_ref().unwrap().address,
from = auth_message.from(),
result = ?dkim_output.iter().map(|d| d.result().to_string()).collect::<Vec<_>>(),
"No passing DKIM signatures found.");
return if dkim_output
.iter()
.any(|d| matches!(d.result(), DkimResult::TempError(_)))
{
@ -82,23 +88,12 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
} else {
&b"550 5.7.20 No passing DKIM signatures found.\r\n"[..]
};
tracing::info!(parent: &self.span,
event = "auth-failed",
context = "dkim",
rfc5321_from = self.data.mail_from.as_ref().unwrap().address,
rfc5322_from = auth_message.from(),
result = ?dkim_output.iter().map(|d| d.result().to_string()).collect::<Vec<_>>(),
"No passing DKIM signatures found.");
self.reset();
return self.write(message).await;
} else {
tracing::debug!(parent: &self.span,
event = "verify",
context = "dkim",
rfc5321_from = self.data.mail_from.as_ref().unwrap().address,
rfc5322_from = auth_message.from(),
return_path = self.data.mail_from.as_ref().unwrap().address,
from = auth_message.from(),
result = ?dkim_output.iter().map(|d| d.result().to_string()).collect::<Vec<_>>());
}
dkim_output
@ -115,28 +110,25 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
if arc.is_strict()
&& !matches!(arc_output.result(), DkimResult::Pass | DkimResult::None)
{
let message = if matches!(arc_output.result(), DkimResult::TempError(_)) {
tracing::info!(parent: &self.span,
event = "auth-failed",
context = "arc",
return_path = self.data.mail_from.as_ref().unwrap().address,
from = auth_message.from(),
result = %arc_output.result(),
"ARC validation failed.");
return if matches!(arc_output.result(), DkimResult::TempError(_)) {
&b"451 4.7.29 ARC validation failed.\r\n"[..]
} else {
&b"550 5.7.29 ARC validation failed.\r\n"[..]
};
tracing::info!(parent: &self.span,
event = "auth-failed",
context = "arc",
rfc5321_from = self.data.mail_from.as_ref().unwrap().address,
rfc5322_from = auth_message.from(),
result = %arc_output.result(),
"ARC validation failed.");
self.reset();
return self.write(message).await;
} else {
tracing::debug!(parent: &self.span,
event = "verify",
context = "arc",
rfc5321_from = self.data.mail_from.as_ref().unwrap().address,
rfc5322_from = auth_message.from(),
return_path = self.data.mail_from.as_ref().unwrap().address,
from = auth_message.from(),
result = %arc_output.result());
}
arc_output.into()
@ -214,16 +206,16 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
tracing::debug!(parent: &self.span,
event = "verify",
context = "dmarc",
rfc5321_from = mail_from.address,
rfc5322_from = auth_message.from(),
return_path = mail_from.address,
from = auth_message.from(),
dkim_result = %dmarc_output.dkim_result(),
spf_result = %dmarc_output.spf_result());
} else {
tracing::info!(parent: &self.span,
event = "auth-failed",
context = "dmarc",
rfc5321_from = mail_from.address,
rfc5322_from = auth_message.from(),
return_path = mail_from.address,
from = auth_message.from(),
dkim_result = %dmarc_output.dkim_result(),
spf_result = %dmarc_output.spf_result());
}
@ -242,14 +234,20 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
}
if rejected {
self.reset();
return self
.write(if is_temp_fail {
&b"451 4.7.1 Email temporarily rejected per DMARC policy.\r\n"[..]
} else {
&b"550 5.7.1 Email rejected per DMARC policy.\r\n"[..]
})
.await;
return if is_temp_fail {
&b"451 4.7.1 Email temporarily rejected per DMARC policy.\r\n"[..]
} else {
&b"550 5.7.1 Email rejected per DMARC policy.\r\n"[..]
};
}
}
// Analyze reports
if self.is_report() {
self.core.analyze_report(raw_message.clone());
if !rc.analysis.forward {
self.data.messages_sent += 1;
return b"250 2.0.0 Message queued for delivery.\r\n";
}
}
@ -302,8 +300,8 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
tracing::info!(parent: &self.span,
event = "seal-failed",
context = "arc",
rfc5321_from = message.return_path,
rfc5322_from = auth_message.from(),
return_path = message.return_path,
from = auth_message.from(),
"Failed to seal message: {}", err);
}
}
@ -331,7 +329,7 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
// DKIM sign
for signer in ac.dkim.sign.eval(self).await.iter() {
match signer.sign_chained(&[headers.as_ref(), &self.data.message]) {
match signer.sign_chained(&[headers.as_ref(), &raw_message]) {
Ok(signature) => {
signature.write_header(&mut headers);
}
@ -339,45 +337,32 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
tracing::info!(parent: &self.span,
event = "sign-failed",
context = "dkim",
rfc5321_from = message.return_path,
rfc5322_from = auth_message.from(),
return_path = message.return_path,
from = auth_message.from(),
"Failed to sign message: {}", err);
}
}
}
// Update size
message.size = self.data.message.len() + headers.len();
message.size = raw_message.len() + headers.len();
message.size_headers = auth_message.body_offset() + headers.len();
// Verify queue quota
if self.core.queue.has_quota(&mut message).await {
let id = message.id;
let queue_success = {
let _span = self.span.enter();
self.core
.queue
.queue_message(
message,
vec![headers, std::mem::take(&mut self.data.message)],
&self.span,
)
.queue_message(message, Some(&headers), &raw_message, &self.span)
.await
};
if queue_success {
self.data.messages_sent += 1;
self.write(
format!(
"250 2.0.0 Message queued for delivery with ID {:X}.\r\n",
id
)
.as_bytes(),
)
.await?;
b"250 2.0.0 Message queued for delivery.\r\n"
} else {
self.write(b"451 4.3.5 Unable to accept message at this time.\r\n")
.await?;
b"451 4.3.5 Unable to accept message at this time.\r\n"
}
} else {
tracing::warn!(
@ -387,12 +372,8 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
from = message.return_path,
"Queue quota exceeded, rejecting message."
);
self.write(b"452 4.3.1 Mail system full, try again later.\r\n")
.await?;
b"452 4.3.1 Mail system full, try again later.\r\n"
}
self.reset();
Ok(())
}
async fn build_message(
@ -439,26 +420,37 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
};
// Set expiration time
let config = &self.core.queue.config;
let expires = Instant::now()
+ if self.data.delivery_by == 0 {
*self.core.queue.config.expire.eval(&envelope).await
*config.expire.eval(&envelope).await
} else {
Duration::from_secs(self.data.delivery_by)
};
// Set delayed notification time
let notify = queue::Schedule::later(
future_release
+ *self
.core
.queue
.config
.notify
.eval(&envelope)
.await
.first()
.unwrap(),
);
let notify_intervals = config.notify.eval(&envelope).await;
let notify_time = future_release
+ match self.data.notify_by.cmp(&0) {
std::cmp::Ordering::Equal => *notify_intervals.first().unwrap(),
std::cmp::Ordering::Greater => std::cmp::min(
*notify_intervals.last().unwrap(),
Duration::from_secs(self.data.notify_by as u64),
),
std::cmp::Ordering::Less => {
let notify_at = -self.data.notify_by as u64;
let last_notify = notify_intervals.last().unwrap().as_secs();
if last_notify > notify_at {
Duration::from_secs(last_notify - notify_at)
} else {
*notify_intervals.first().unwrap()
}
}
};
let mut notify = queue::Schedule::later(notify_time);
if self.data.notify_by != 0 {
notify.inner = (notify_intervals.len() - 1) as u32;
}
message.domains.push(queue::Domain {
retry,

View file

@ -64,7 +64,7 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
if !self.stream.is_tls() {
response.capabilities |= EXT_START_TLS;
}
let ec = &self.core.session.config.ehlo;
let ec = &self.core.session.config.extensions;
let rc = &self.core.session.config.rcpt;
let ac = &self.core.session.config.auth;
let dc = &self.core.session.config.data;

View file

@ -1,5 +1,5 @@
use mail_auth::{IprevOutput, IprevResult, SpfOutput, SpfResult};
use smtp_proto::MailFrom;
use smtp_proto::{MailFrom, MAIL_BY_NOTIFY, MAIL_BY_RETURN, MAIL_BY_TRACE, MAIL_REQUIRETLS};
use tokio::io::{AsyncRead, AsyncWrite};
use crate::core::{Session, SessionAddress};
@ -82,6 +82,32 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
}
.into();
// Validate parameters
let config = &self.core.session.config.extensions;
if (from.flags & MAIL_REQUIRETLS) != 0 && !*config.requiretls.eval(self).await {
//todo
}
if (from.flags & (MAIL_BY_NOTIFY | MAIL_BY_RETURN | MAIL_BY_TRACE)) != 0 {
if let Some(duration) = config.deliver_by.eval(self).await {
if (from.flags & MAIL_BY_RETURN) != 0 {
if from.by > 0 {
let deliver_by = from.by as u64;
if deliver_by <= duration.as_secs() {
self.data.delivery_by = deliver_by;
} else {
// err
}
} else {
// err
}
} else {
self.data.notify_by = from.by;
}
} else {
// err
}
}
if self.is_allowed().await {
// Verify SPF
if self.params.spf_mail_from.verify() {

View file

@ -205,7 +205,9 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
State::Data(receiver) => {
if self.data.message.len() + bytes.len() < self.params.max_message_size {
if receiver.ingest(&mut iter, &mut self.data.message) {
self.queue_message().await?;
let message = self.queue_message().await;
self.write(message).await?;
self.reset();
state = State::default();
} else {
break 'outer;
@ -218,7 +220,9 @@ impl<T: AsyncWrite + AsyncRead + IsTls + Unpin> Session<T> {
if receiver.ingest(&mut iter, &mut self.data.message) {
if self.can_send_data().await? {
if receiver.is_last {
self.queue_message().await?;
let message = self.queue_message().await;
self.write(message).await?;
self.reset();
} else {
self.write(b"250 2.6.0 Chunk accepted.\r\n").await?;
}
@ -293,6 +297,7 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
self.data.message = Vec::with_capacity(0);
self.data.priority = 0;
self.data.delivery_by = 0;
self.data.notify_by = 0;
self.data.future_release = 0;
}
@ -342,13 +347,13 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
impl<T: AsyncRead + AsyncWrite> Envelope for Session<T> {
#[inline(always)]
fn local_ip(&self) -> &IpAddr {
&self.data.local_ip
fn local_ip(&self) -> IpAddr {
self.data.local_ip
}
#[inline(always)]
fn remote_ip(&self) -> &IpAddr {
&self.data.remote_ip
fn remote_ip(&self) -> IpAddr {
self.data.remote_ip
}
#[inline(always)]

View file

@ -188,13 +188,13 @@ impl Session<TcpStream> {
}
pub async fn handle_conn(
self,
mut self,
tls_acceptor: Option<TlsAcceptor>,
shutdown_rx: watch::Receiver<bool>,
) {
if let Some((session, shutdown_rx)) = self.handle_conn_(shutdown_rx).await {
if let Some(shutdown_rx) = self.handle_conn_(shutdown_rx).await {
if let Some(tls_acceptor) = tls_acceptor {
if let Ok(session) = session.into_tls(tls_acceptor).await {
if let Ok(session) = self.into_tls(tls_acceptor).await {
session.handle_conn(shutdown_rx).await;
}
}
@ -203,16 +203,16 @@ impl Session<TcpStream> {
}
impl Session<TlsStream<TcpStream>> {
pub async fn handle_conn(self, shutdown_rx: watch::Receiver<bool>) {
pub async fn handle_conn(mut self, shutdown_rx: watch::Receiver<bool>) {
self.handle_conn_(shutdown_rx).await;
}
}
impl<T: AsyncRead + AsyncWrite + IsTls + Unpin> Session<T> {
pub async fn handle_conn_(
mut self,
&mut self,
mut shutdown_rx: watch::Receiver<bool>,
) -> Option<(Session<T>, watch::Receiver<bool>)> {
) -> Option<watch::Receiver<bool>> {
let mut buf = vec![0; 8192];
loop {
@ -228,7 +228,7 @@ impl<T: AsyncRead + AsyncWrite + IsTls + Unpin> Session<T> {
match self.ingest(&buf[..bytes_read]).await {
Ok(true) => (),
Ok(false) => {
return (self, shutdown_rx).into();
return (shutdown_rx).into();
}
Err(_) => {
break;

View file

@ -5,5 +5,7 @@ pub mod outbound;
pub mod queue;
pub mod remote;
pub mod reporting;
#[cfg(test)]
pub mod tests;
pub static USER_AGENT: &str = concat!("StalwartSMTP/", env!("CARGO_PKG_VERSION"),);

View file

@ -47,6 +47,16 @@ async fn main() -> std::io::Result<()> {
let (queue_tx, queue_rx) = mpsc::channel(1024);
let (report_tx, report_rx) = mpsc::channel(1024);
let core = Arc::new(Core {
worker_pool: rayon::ThreadPoolBuilder::new()
.num_threads(
config
.property::<usize>("global.thread-pool")
.failed("Failed to parse thread pool size")
.filter(|v| *v > 0)
.unwrap_or_else(num_cpus::get),
)
.build()
.unwrap(),
resolvers: config.build_resolvers().failed("Failed to build resolvers"),
session: SessionCore {
config: session_config,

View file

@ -249,54 +249,71 @@ impl DeliveryAttempt {
// Obtain remote hosts list
let mx_list;
let remote_hosts =
if let Some(next_hop) = queue_config.next_hop.eval(&envelope).await {
vec![RemoteHost::Relay(next_hop)]
} else {
// Lookup MX
mx_list = match core.resolvers.dns.mx_lookup(&domain.domain).await {
Ok(mx) => mx,
Err(err) => {
tracing::info!(
parent: &span,
context = "dns",
event = "mx-lookup-failed",
reason = %err,
);
domain.set_status(err, queue_config.retry.eval(&envelope).await);
continue 'next_domain;
}
};
let remote_hosts = if let Some(next_hop) =
queue_config.next_hop.eval(&envelope).await
{
vec![RemoteHost::Relay(next_hop)]
} else {
// Lookup MX
mx_list = match core.resolvers.dns.mx_lookup(&domain.domain).await {
Ok(mx) => mx,
Err(err) => {
tracing::info!(
parent: &span,
context = "dns",
event = "mx-lookup-failed",
reason = %err,
);
domain.set_status(err, queue_config.retry.eval(&envelope).await);
continue 'next_domain;
}
};
if !mx_list.is_empty() {
// Obtain max number of MX hosts to process
let max_mx = *queue_config.max_mx.eval(&envelope).await;
let mut remote_hosts = Vec::with_capacity(max_mx);
if !mx_list.is_empty() {
// Obtain max number of MX hosts to process
let max_mx = *queue_config.max_mx.eval(&envelope).await;
let mut remote_hosts = Vec::with_capacity(max_mx);
for mx in mx_list.iter() {
if mx.exchanges.len() > 1 {
let mut slice = mx.exchanges.iter().collect::<Vec<_>>();
slice.shuffle(&mut rand::thread_rng());
for remote_host in slice {
remote_hosts.push(RemoteHost::MX(remote_host.as_str()));
if remote_hosts.len() == max_mx {
break;
}
}
} else if let Some(remote_host) = mx.exchanges.first() {
for mx in mx_list.iter() {
if mx.exchanges.len() > 1 {
let mut slice = mx.exchanges.iter().collect::<Vec<_>>();
slice.shuffle(&mut rand::thread_rng());
for remote_host in slice {
remote_hosts.push(RemoteHost::MX(remote_host.as_str()));
if remote_hosts.len() == max_mx {
break;
}
}
} else if let Some(remote_host) = mx.exchanges.first() {
// Check for Null MX
if mx.preference == 0 && remote_host == "." {
tracing::info!(
parent: &span,
context = "dns",
event = "null-mx",
reason = "Domain does not accept messages (mull MX)",
);
domain.set_status(
Status::PermanentFailure(Error::DnsError(
"Domain does not accept messages (null MX)".to_string(),
)),
queue_config.retry.eval(&envelope).await,
);
continue 'next_domain;
}
remote_hosts.push(RemoteHost::MX(remote_host.as_str()));
if remote_hosts.len() == max_mx {
break;
}
}
remote_hosts
} else {
// If an empty list of MXs is returned, the address is treated as if it was
// associated with an implicit MX RR with a preference of 0, pointing to that host.
vec![RemoteHost::MX(domain.domain.as_str())]
}
};
remote_hosts
} else {
// If an empty list of MXs is returned, the address is treated as if it was
// associated with an implicit MX RR with a preference of 0, pointing to that host.
vec![RemoteHost::MX(domain.domain.as_str())]
}
};
// Try delivering message
let max_multihomed = *queue_config.max_multihomed.eval(&envelope).await;

View file

@ -34,11 +34,11 @@ impl QueueCore {
.await;
// Sign message
let message_bytes = attempt
let signature = attempt
.message
.sign(&self.config.dsn.sign, dsn, &attempt.span)
.sign(&self.config.dsn.sign, &dsn, &attempt.span)
.await;
self.queue_message(dsn_message, message_bytes, &attempt.span)
self.queue_message(dsn_message, signature.as_deref(), &dsn, &attempt.span)
.await;
}
} else {
@ -697,7 +697,7 @@ mod test {
message,
in_flight: vec![],
};
let config = QueueConfig::default();
let config = QueueConfig::test();
// Disabled DSN
assert!(attempt.build_dsn(&config).await.is_none());

View file

@ -1,6 +1,6 @@
use std::{
fmt::Display,
net::IpAddr,
net::{IpAddr, Ipv4Addr},
path::PathBuf,
sync::{atomic::AtomicUsize, Arc},
time::{Duration, Instant, SystemTime},
@ -211,12 +211,12 @@ impl<'x> SimpleEnvelope<'x> {
}
impl<'x> Envelope for SimpleEnvelope<'x> {
fn local_ip(&self) -> &std::net::IpAddr {
unreachable!()
fn local_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn remote_ip(&self) -> &std::net::IpAddr {
unreachable!()
fn remote_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn sender_domain(&self) -> &str {
@ -265,12 +265,12 @@ pub struct QueueEnvelope<'x> {
}
impl<'x> Envelope for QueueEnvelope<'x> {
fn local_ip(&self) -> &std::net::IpAddr {
&self.local_ip
fn local_ip(&self) -> IpAddr {
self.local_ip
}
fn remote_ip(&self) -> &std::net::IpAddr {
&self.remote_ip
fn remote_ip(&self) -> IpAddr {
self.remote_ip
}
fn sender_domain(&self) -> &str {
@ -311,12 +311,12 @@ impl<'x> Envelope for QueueEnvelope<'x> {
}
impl Envelope for Message {
fn local_ip(&self) -> &IpAddr {
unreachable!()
fn local_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn remote_ip(&self) -> &IpAddr {
unreachable!()
fn remote_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn sender_domain(&self) -> &str {
@ -357,12 +357,12 @@ impl Envelope for Message {
}
impl Envelope for &str {
fn local_ip(&self) -> &IpAddr {
unreachable!()
fn local_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn remote_ip(&self) -> &IpAddr {
unreachable!()
fn remote_ip(&self) -> IpAddr {
IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))
}
fn sender_domain(&self) -> &str {
@ -517,59 +517,3 @@ impl Display for Status<HostResponse<String>, HostResponse<ErrorDetails>> {
}
}
}
#[cfg(test)]
impl Default for crate::config::QueueConfig {
fn default() -> Self {
use crate::config::{
Dsn, IfBlock, QueueOutboundSourceIp, QueueOutboundTimeout, QueueOutboundTls,
QueueQuotas, QueueThrottle,
};
Self {
path: Default::default(),
hash: Default::default(),
retry: IfBlock::new(vec![Duration::from_secs(10)]),
notify: IfBlock::new(vec![Duration::from_secs(20)]),
expire: IfBlock::new(Duration::from_secs(10)),
hostname: IfBlock::new("mx.example.org".to_string()),
next_hop: Default::default(),
max_mx: IfBlock::new(5),
max_multihomed: IfBlock::new(5),
source_ip: QueueOutboundSourceIp {
ipv4: IfBlock::new(vec![]),
ipv6: IfBlock::new(vec![]),
},
tls: QueueOutboundTls {
dane: IfBlock::new(crate::config::RequireOptional::Optional),
mta_sts: IfBlock::new(crate::config::RequireOptional::Optional),
start: IfBlock::new(crate::config::RequireOptional::Optional),
},
dsn: Dsn {
name: IfBlock::new("Mail Delivery Subsystem".to_string()),
address: IfBlock::new("MAILER-DAEMON@example.org".to_string()),
sign: IfBlock::default(),
},
timeout: QueueOutboundTimeout {
connect: IfBlock::new(Duration::from_secs(1)),
greeting: IfBlock::new(Duration::from_secs(1)),
tls: IfBlock::new(Duration::from_secs(1)),
ehlo: IfBlock::new(Duration::from_secs(1)),
mail: IfBlock::new(Duration::from_secs(1)),
rcpt: IfBlock::new(Duration::from_secs(1)),
data: IfBlock::new(Duration::from_secs(1)),
mta_sts: IfBlock::new(Duration::from_secs(1)),
},
throttle: QueueThrottle {
sender: vec![],
rcpt: vec![],
host: vec![],
},
quota: QueueQuotas {
sender: vec![],
rcpt: vec![],
rcpt_domain: vec![],
},
}
}
}

View file

@ -23,7 +23,8 @@ impl QueueCore {
pub async fn queue_message(
&self,
mut message: Box<Message>,
mut message_bytes: Vec<Vec<u8>>,
raw_headers: Option<&[u8]>,
raw_message: &[u8],
span: &tracing::Span,
) -> bool {
// Generate id
@ -48,7 +49,7 @@ impl QueueCore {
message.path.push(file);
// Serialize metadata
message_bytes.push(message.serialize());
let metadata = message.serialize();
// Save message
let mut file = match fs::File::create(&message.path).await {
@ -65,17 +66,26 @@ impl QueueCore {
return false;
}
};
for bytes in message_bytes {
if let Err(err) = file.write_all(&bytes).await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to file {}: {}",
message.path.display(),
err
);
return false;
let iter = if let Some(raw_headers) = raw_headers {
[raw_headers, raw_message, &metadata].into_iter()
} else {
[raw_message, &metadata, b""].into_iter()
};
for bytes in iter {
if !bytes.is_empty() {
if let Err(err) = file.write_all(bytes).await {
tracing::error!(
parent: span,
context = "queue",
event = "error",
"Failed to write to file {}: {}",
message.path.display(),
err
);
return false;
}
}
}
if let Err(err) = file.flush().await {

463
src/reporting/analysis.rs Normal file
View file

@ -0,0 +1,463 @@
use std::{
borrow::Cow,
collections::hash_map::Entry,
io::{Cursor, Read},
sync::{atomic::Ordering, Arc},
time::SystemTime,
};
use ahash::AHashMap;
use mail_auth::{
flate2::read::GzDecoder,
report::{tlsrpt::TlsReport, ActionDisposition, DmarcResult, Feedback, Report},
zip,
};
use mail_parser::{DateTime, HeaderValue, Message, MimeHeaders, PartType};
use crate::core::Core;
enum Compression {
None,
Gzip,
Zip,
}
enum Format {
Dmarc,
Tls,
Arf,
}
struct ReportData<'x> {
compression: Compression,
format: Format,
data: &'x [u8],
}
pub trait AnalyzeReport {
fn analyze_report(&self, message: Arc<Vec<u8>>);
}
impl AnalyzeReport for Arc<Core> {
fn analyze_report(&self, message: Arc<Vec<u8>>) {
let core = self.clone();
self.worker_pool.spawn(move || {
let message = if let Some(message) = Message::parse(&message) {
message
} else {
return;
};
let from = match message.from() {
HeaderValue::Address(addr) => addr.address.as_ref().map(|a| a.as_ref()),
HeaderValue::AddressList(addr_list) => addr_list
.last()
.and_then(|a| a.address.as_ref())
.map(|a| a.as_ref()),
_ => None,
}
.unwrap_or("unknown");
let mut reports = Vec::new();
for part in &message.parts {
match &part.body {
PartType::Text(report) => {
if part
.content_type()
.and_then(|ct| ct.subtype())
.map_or(false, |t| t.eq_ignore_ascii_case("xml"))
|| part
.attachment_name()
.and_then(|n| n.rsplit_once('.'))
.map_or(false, |(_, e)| e.eq_ignore_ascii_case("xml"))
{
reports.push(ReportData {
compression: Compression::None,
format: Format::Dmarc,
data: report.as_bytes(),
});
} else if part.is_content_type("message", "feedback-report") {
reports.push(ReportData {
compression: Compression::None,
format: Format::Arf,
data: report.as_bytes(),
});
}
}
PartType::Binary(report) | PartType::InlineBinary(report) => {
if part.is_content_type("message", "feedback-report") {
reports.push(ReportData {
compression: Compression::None,
format: Format::Arf,
data: report.as_ref(),
});
continue;
}
let subtype = part
.content_type()
.and_then(|ct| ct.subtype())
.unwrap_or("");
let attachment_name = part.attachment_name();
let ext = attachment_name
.and_then(|f| f.rsplit_once('.'))
.map_or("", |(_, e)| e);
let tls_parts = subtype.rsplit_once('+');
let compression = match (tls_parts.map(|(_, c)| c).unwrap_or(subtype), ext)
{
("gzip", _) => Compression::Gzip,
("zip", _) => Compression::Zip,
(_, "gz") => Compression::Gzip,
(_, "zip") => Compression::Zip,
_ => Compression::None,
};
let format = match (tls_parts.map(|(c, _)| c).unwrap_or(subtype), ext) {
("xml", _) => Format::Dmarc,
("tlsrpt", _) | (_, "json") => Format::Tls,
_ => {
if attachment_name.map_or(false, |n| n.contains(".xml")) {
Format::Dmarc
} else {
continue;
}
}
};
reports.push(ReportData {
compression,
format,
data: report.as_ref(),
});
}
_ => (),
}
}
for report in reports {
let data = match report.compression {
Compression::None => Cow::Borrowed(report.data),
Compression::Gzip => {
let mut file = GzDecoder::new(report.data);
let mut buf = Vec::new();
if let Err(err) = file.read_to_end(&mut buf) {
tracing::debug!(
context = "report",
from = from,
"Failed to decompress report: {}",
err
);
continue;
}
Cow::Owned(buf)
}
Compression::Zip => {
let mut archive = match zip::ZipArchive::new(Cursor::new(report.data)) {
Ok(archive) => archive,
Err(err) => {
tracing::debug!(
context = "report",
from = from,
"Failed to decompress report: {}",
err
);
continue;
}
};
let mut buf = Vec::with_capacity(0);
for i in 0..archive.len() {
match archive.by_index(i) {
Ok(mut file) => {
buf = Vec::with_capacity(file.compressed_size() as usize);
if let Err(err) = file.read_to_end(&mut buf) {
tracing::debug!(
context = "report",
from = from,
"Failed to decompress report: {}",
err
);
}
break;
}
Err(err) => {
tracing::debug!(
context = "report",
from = from,
"Failed to decompress report: {}",
err
);
}
}
}
Cow::Owned(buf)
}
};
match report.format {
Format::Dmarc => match Report::parse_xml(&data) {
Ok(report) => {
report.log();
}
Err(err) => {
tracing::debug!(
context = "report",
from = from,
"Failed to parse DMARC report: {}",
err
);
continue;
}
},
Format::Tls => match TlsReport::parse_json(&data) {
Ok(report) => {
report.log();
}
Err(err) => {
tracing::debug!(
context = "report",
from = from,
"Failed to parse TLS report: {:?}",
err
);
continue;
}
},
Format::Arf => match Feedback::parse_arf(&data) {
Some(report) => {
report.log();
}
None => {
tracing::debug!(
context = "report",
from = from,
"Failed to parse Auth Failure report"
);
continue;
}
},
}
// Save report
if let Some(report_path) = &core.report.config.analysis.store {
let (report_format, extension) = match report.format {
Format::Dmarc => ("dmarc", "xml"),
Format::Tls => ("tlsrpt", "json"),
Format::Arf => ("arf", "txt"),
};
let c_extension = match report.compression {
Compression::None => "",
Compression::Gzip => ".gz",
Compression::Zip => ".zip",
};
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs());
let id = core
.report
.config
.analysis
.report_id
.fetch_add(1, Ordering::Relaxed);
// Build path
let mut report_path = report_path.clone();
report_path.push(format!(
"{}_{}_{}.{}{}",
report_format, now, id, extension, c_extension
));
if let Err(err) = std::fs::write(&report_path, report.data) {
tracing::warn!(
context = "report",
event = "error",
from = from,
"Failed to write incoming report to {}: {}",
report_path.display(),
err
);
}
}
break;
}
});
}
}
trait LogReport {
fn log(&self);
}
impl LogReport for Report {
fn log(&self) {
let mut dmarc_pass = 0;
let mut dmarc_quarantine = 0;
let mut dmarc_reject = 0;
let mut dmarc_none = 0;
let mut dkim_pass = 0;
let mut dkim_fail = 0;
let mut dkim_none = 0;
let mut spf_pass = 0;
let mut spf_fail = 0;
let mut spf_none = 0;
for record in self.records() {
let count = std::cmp::min(record.count(), 1);
match record.action_disposition() {
ActionDisposition::Pass => {
dmarc_pass += count;
}
ActionDisposition::Quarantine => {
dmarc_quarantine += count;
}
ActionDisposition::Reject => {
dmarc_reject += count;
}
ActionDisposition::None | ActionDisposition::Unspecified => {
dmarc_none += count;
}
}
match record.dmarc_dkim_result() {
DmarcResult::Pass => {
dkim_pass += count;
}
DmarcResult::Fail => {
dkim_fail += count;
}
DmarcResult::Unspecified => {
dkim_none += count;
}
}
match record.dmarc_spf_result() {
DmarcResult::Pass => {
spf_pass += count;
}
DmarcResult::Fail => {
spf_fail += count;
}
DmarcResult::Unspecified => {
spf_none += count;
}
}
}
let range_from = DateTime::from_timestamp(self.date_range_begin() as i64).to_rfc3339();
let range_to = DateTime::from_timestamp(self.date_range_end() as i64).to_rfc3339();
if (dmarc_reject + dmarc_quarantine + dkim_fail + spf_fail) > 0 {
tracing::warn!(
context = "dmarc",
event = "analyze",
range_from = range_from,
range_to = range_to,
domain = self.domain(),
report_email = self.email(),
report_id = self.report_id(),
dmarc_pass = dmarc_pass,
dmarc_quarantine = dmarc_quarantine,
dmarc_reject = dmarc_reject,
dmarc_none = dmarc_none,
dkim_pass = dkim_pass,
dkim_fail = dkim_fail,
dkim_none = dkim_none,
spf_pass = spf_pass,
spf_fail = spf_fail,
spf_none = spf_none,
);
} else {
tracing::info!(
context = "dmarc",
event = "analyze",
range_from = range_from,
range_to = range_to,
domain = self.domain(),
report_email = self.email(),
report_id = self.report_id(),
dmarc_pass = dmarc_pass,
dmarc_quarantine = dmarc_quarantine,
dmarc_reject = dmarc_reject,
dmarc_none = dmarc_none,
dkim_pass = dkim_pass,
dkim_fail = dkim_fail,
dkim_none = dkim_none,
spf_pass = spf_pass,
spf_fail = spf_fail,
spf_none = spf_none,
);
}
}
}
impl LogReport for TlsReport {
fn log(&self) {
for policy in self.policies.iter().take(5) {
let mut details = AHashMap::with_capacity(policy.failure_details.len());
for failure in &policy.failure_details {
let num_failures = std::cmp::min(1, failure.failed_session_count);
match details.entry(failure.result_type) {
Entry::Occupied(mut e) => {
*e.get_mut() += num_failures;
}
Entry::Vacant(e) => {
e.insert(num_failures);
}
}
}
if policy.summary.total_failure > 0 {
tracing::warn!(
context = "tlsrpt",
event = "analyze",
range_from = self.date_range.start_datetime.to_rfc3339(),
range_to = self.date_range.end_datetime.to_rfc3339(),
domain = policy.policy.policy_domain,
report_contact = self.contact_info.as_deref().unwrap_or("unknown"),
report_id = self.report_id,
policy_type = ?policy.policy.policy_type,
total_success = policy.summary.total_success,
total_failures = policy.summary.total_failure,
details = ?details,
);
} else {
tracing::info!(
context = "tlsrpt",
event = "analyze",
range_from = self.date_range.start_datetime.to_rfc3339(),
range_to = self.date_range.end_datetime.to_rfc3339(),
domain = policy.policy.policy_domain,
report_contact = self.contact_info.as_deref().unwrap_or("unknown"),
report_id = self.report_id,
policy_type = ?policy.policy.policy_type,
total_success = policy.summary.total_success,
total_failures = policy.summary.total_failure,
details = ?details,
);
}
}
}
}
impl LogReport for Feedback<'_> {
fn log(&self) {
tracing::warn!(
context = "arf",
event = "analyze",
feedback_type = ?self.feedback_type(),
arrival_date = DateTime::from_timestamp(self.arrival_date().unwrap_or_else(|| {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs()) as i64
})).to_rfc3339(),
authentication_results = ?self.authentication_results(),
incidents = self.incidents(),
reported_domain = ?self.reported_domain(),
reported_uri = ?self.reported_uri(),
reporting_mta = self.reporting_mta().unwrap_or_default(),
source_ip = ?self.source_ip(),
user_agent = self.user_agent().unwrap_or_default(),
auth_failure = ?self.auth_failure(),
delivery_result = ?self.delivery_result(),
dkim_domain = self.dkim_domain().unwrap_or_default(),
dkim_identity = self.dkim_identity().unwrap_or_default(),
dkim_selector = self.dkim_selector().unwrap_or_default(),
identity_alignment = ?self.identity_alignment(),
);
}
}

View file

@ -10,8 +10,8 @@ use mail_auth::{
};
use serde::{Deserialize, Serialize};
use tokio::{
fs,
io::{AsyncRead, AsyncWrite},
runtime::Handle,
};
use crate::{
@ -22,7 +22,8 @@ use crate::{
use super::{
scheduler::{
json_append, json_read, json_write, ReportPath, ReportPolicy, ReportType, Scheduler, ToHash,
json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType,
Scheduler, ToHash,
},
DmarcEvent,
};
@ -274,7 +275,9 @@ pub trait GenerateDmarcReport {
impl GenerateDmarcReport for Arc<Core> {
fn generate_dmarc_report(&self, domain: ReportPolicy<String>, path: ReportPath<PathBuf>) {
let core = self.clone();
tokio::spawn(async move {
let handle = Handle::current();
self.worker_pool.spawn(move || {
let deliver_at = path.created + path.deliver_at.as_secs();
let span = tracing::info_span!(
"dmarc-report",
@ -285,19 +288,18 @@ impl GenerateDmarcReport for Arc<Core> {
);
// Deserialize report
let dmarc = if let Some(dmarc) = json_read::<DmarcFormat>(&path.path, &span).await {
let dmarc = if let Some(dmarc) = json_read_blocking::<DmarcFormat>(&path.path, &span) {
dmarc
} else {
return;
};
// Verify external reporting addresses
let rua = match core
.resolvers
.dns
.verify_dmarc_report_address(&domain.inner, &dmarc.rua)
.await
{
let rua = match handle.block_on(
core.resolvers
.dns
.verify_dmarc_report_address(&domain.inner, &dmarc.rua),
) {
Some(rcpts) => {
if !rcpts.is_empty() {
rcpts
@ -312,7 +314,7 @@ impl GenerateDmarcReport for Arc<Core> {
rua = ?dmarc.rua,
"Unauthorized external reporting addresses"
);
let _ = fs::remove_file(&path.path).await;
let _ = std::fs::remove_file(&path.path);
return;
}
}
@ -324,11 +326,13 @@ impl GenerateDmarcReport for Arc<Core> {
rua = ?dmarc.rua,
"Failed to validate external report addresses",
);
let _ = fs::remove_file(&path.path).await;
let _ = std::fs::remove_file(&path.path);
return;
}
};
let config = &core.report.config.dmarc_aggregate;
// Group duplicates
let mut record_map = AHashMap::with_capacity(dmarc.records.len());
for record in dmarc.records {
@ -343,32 +347,31 @@ impl GenerateDmarcReport for Arc<Core> {
}
// Create report
let config = &core.report.config.dmarc_aggregate;
let mut report = Report::new()
.with_policy_published(dmarc.policy)
.with_date_range_begin(path.created)
.with_date_range_end(deliver_at)
.with_report_id(format!("{}_{}", domain.policy, path.created))
.with_email(config.address.eval(&domain.inner.as_str()).await);
if let Some(org_name) = config.org_name.eval(&domain.inner.as_str()).await {
.with_email(handle.block_on(config.address.eval(&domain.inner.as_str())));
if let Some(org_name) = handle.block_on(config.org_name.eval(&domain.inner.as_str())) {
report = report.with_org_name(org_name);
}
if let Some(contact_info) = config.contact_info.eval(&domain.inner.as_str()).await {
if let Some(contact_info) =
handle.block_on(config.contact_info.eval(&domain.inner.as_str()))
{
report = report.with_extra_contact_info(contact_info);
}
for (record, count) in record_map {
report.add_record(record.with_count(count));
}
let from_addr = config.address.eval(&domain.inner.as_str()).await;
let from_addr = handle.block_on(config.address.eval(&domain.inner.as_str()));
let mut message = Vec::with_capacity(path.size);
let _ = report.write_rfc5322(
core.report
.config
.submitter
.eval(&domain.inner.as_str())
.await,
handle.block_on(core.report.config.submitter.eval(&domain.inner.as_str())),
(
config.name.eval(&domain.inner.as_str()).await.as_str(),
handle
.block_on(config.name.eval(&domain.inner.as_str()))
.as_str(),
from_addr.as_str(),
),
rua.iter().map(|a| a.as_str()),
@ -376,10 +379,9 @@ impl GenerateDmarcReport for Arc<Core> {
);
// Send report
core.send_report(from_addr, rua.iter(), message, &config.sign, &span)
.await;
handle.block_on(core.send_report(from_addr, rua.iter(), message, &config.sign, &span));
if let Err(err) = fs::remove_file(&path.path).await {
if let Err(err) = std::fs::remove_file(&path.path) {
tracing::warn!(
context = "report",
event = "error",

View file

@ -12,13 +12,14 @@ use mail_parser::DateTime;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{
config::{AggregateFrequency, DkimSigner, IfBlock},
config::{AddressMatch, AggregateFrequency, DkimSigner, IfBlock},
core::{Core, Session},
outbound::{dane::Tlsa, mta_sts::Policy},
queue::{DomainPart, Message},
USER_AGENT,
};
pub mod analysis;
pub mod dkim;
pub mod dmarc;
pub mod scheduler;
@ -71,6 +72,25 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> {
DeliveryResult::Unspecified
})
}
pub fn is_report(&self) -> bool {
for addr_match in &self.core.report.config.analysis.addresses {
for addr in &self.data.rcpt_to {
match addr_match {
AddressMatch::StartsWith(prefix) if addr.address_lcase.starts_with(prefix) => {
return true
}
AddressMatch::EndsWith(suffix) if addr.address_lcase.ends_with(suffix) => {
return true
}
AddressMatch::Equals(value) if addr.address_lcase.eq(value) => return true,
_ => (),
}
}
}
false
}
}
impl Core {
@ -97,10 +117,12 @@ impl Core {
}
// Sign message
let message_bytes = message.sign(sign_config, report, span).await;
let signature = message.sign(sign_config, &report, span).await;
// Queue message
self.queue.queue_message(message, message_bytes, span).await;
self.queue
.queue_message(message, signature.as_deref(), &report, span)
.await;
}
pub async fn schedule_report(&self, report: impl Into<Event>) {
@ -114,9 +136,9 @@ impl Message {
pub async fn sign(
&mut self,
config: &IfBlock<Vec<Arc<DkimSigner>>>,
bytes: Vec<u8>,
bytes: &[u8],
span: &tracing::Span,
) -> Vec<Vec<u8>> {
) -> Option<Vec<u8>> {
self.size = bytes.len();
self.size_headers = bytes.len();
@ -124,7 +146,7 @@ impl Message {
if !signers.is_empty() {
let mut headers = Vec::with_capacity(64);
for signer in signers.iter() {
match signer.sign(&bytes) {
match signer.sign(bytes) {
Ok(signature) => {
signature.write_header(&mut headers);
}
@ -140,10 +162,10 @@ impl Message {
self.size += headers.len();
self.size_headers += headers.len();
return vec![headers, bytes];
return Some(headers);
}
}
vec![bytes]
None
}
}

View file

@ -454,6 +454,39 @@ pub async fn json_read<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span
}
}
pub fn json_read_blocking<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span) -> Option<T> {
match std::fs::read_to_string(path) {
Ok(mut json) => {
json.push_str("]}");
match serde_json::from_str(&json) {
Ok(report) => Some(report),
Err(err) => {
tracing::error!(
parent: span,
context = "deserialize",
event = "error",
"Failed to deserialize report file {}: {}",
path.display(),
err
);
None
}
}
}
Err(err) => {
tracing::error!(
parent: span,
context = "io",
event = "error",
"Failed to read report file {}: {}",
path.display(),
err
);
None
}
}
}
impl Default for Scheduler {
fn default() -> Self {
Self {

View file

@ -13,7 +13,7 @@ use mail_parser::DateTime;
use reqwest::header::CONTENT_TYPE;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use tokio::fs;
use tokio::runtime::Handle;
use crate::{
config::AggregateFrequency,
@ -25,7 +25,8 @@ use crate::{
use super::{
scheduler::{
json_append, json_read, json_write, ReportPath, ReportPolicy, ReportType, Scheduler, ToHash,
json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType,
Scheduler, ToHash,
},
TlsEvent,
};
@ -50,7 +51,9 @@ pub trait GenerateTlsReport {
impl GenerateTlsReport for Arc<Core> {
fn generate_tls_report(&self, domain: String, path: ReportPath<Vec<ReportPolicy<PathBuf>>>) {
let core = self.clone();
tokio::spawn(async move {
let handle = Handle::current();
self.worker_pool.spawn(move || {
let deliver_at = path.created + path.deliver_at.as_secs();
let span = tracing::info_span!(
"tls-report",
@ -63,12 +66,16 @@ impl GenerateTlsReport for Arc<Core> {
// Deserialize report
let config = &core.report.config.tls;
let mut report = TlsReport {
organization_name: config.org_name.eval(&domain.as_str()).await.clone(),
organization_name: handle
.block_on(config.org_name.eval(&domain.as_str()))
.clone(),
date_range: DateRange {
start_datetime: DateTime::from_timestamp(path.created as i64),
end_datetime: DateTime::from_timestamp(deliver_at as i64),
},
contact_info: config.contact_info.eval(&domain.as_str()).await.clone(),
contact_info: handle
.block_on(config.contact_info.eval(&domain.as_str()))
.clone(),
report_id: format!(
"{}_{}",
path.created,
@ -78,7 +85,7 @@ impl GenerateTlsReport for Arc<Core> {
};
let mut rua = Vec::new();
for path in &path.path {
if let Some(tls) = json_read::<TlsFormat>(&path.inner, &span).await {
if let Some(tls) = json_read_blocking::<TlsFormat>(&path.inner, &span) {
// Group duplicates
let mut total_success = 0;
let mut total_failure = 0;
@ -149,7 +156,7 @@ impl GenerateTlsReport for Arc<Core> {
for uri in &rua {
match uri {
ReportUri::Http(uri) => {
if let Ok(client) = reqwest::Client::builder()
if let Ok(client) = reqwest::blocking::Client::builder()
.user_agent(USER_AGENT)
.timeout(Duration::from_secs(2 * 60))
.build()
@ -159,7 +166,6 @@ impl GenerateTlsReport for Arc<Core> {
.header(CONTENT_TYPE, "application/tlsrpt+gzip")
.body(json.to_vec())
.send()
.await
{
Ok(response) => {
if response.status().is_success() {
@ -193,13 +199,13 @@ impl GenerateTlsReport for Arc<Core> {
// Deliver report over SMTP
if !rcpts.is_empty() {
let from_addr = config.address.eval(&domain.as_str()).await;
let from_addr = handle.block_on(config.address.eval(&domain.as_str()));
let mut message = Vec::with_capacity(path.size);
let _ = report.write_rfc5322_from_bytes(
&domain,
core.report.config.submitter.eval(&domain.as_str()).await,
handle.block_on(core.report.config.submitter.eval(&domain.as_str())),
(
config.name.eval(&domain.as_str()).await.as_str(),
handle.block_on(config.name.eval(&domain.as_str())).as_str(),
from_addr.as_str(),
),
rcpts.iter().copied(),
@ -208,8 +214,13 @@ impl GenerateTlsReport for Arc<Core> {
);
// Send report
core.send_report(from_addr, rcpts.iter(), message, &config.sign, &span)
.await;
handle.block_on(core.send_report(
from_addr,
rcpts.iter(),
message,
&config.sign,
&span,
));
} else {
tracing::info!(
parent: &span,
@ -217,7 +228,7 @@ impl GenerateTlsReport for Arc<Core> {
"No valid recipients found to deliver report to."
);
}
path.cleanup().await;
path.cleanup_blocking();
});
}
}
@ -378,9 +389,9 @@ impl Scheduler {
}
impl ReportPath<Vec<ReportPolicy<PathBuf>>> {
async fn cleanup(&self) {
fn cleanup_blocking(&self) {
for path in &self.path {
if let Err(err) = fs::remove_file(&path.inner).await {
if let Err(err) = std::fs::remove_file(&path.inner) {
tracing::error!(
context = "report",
report = "tls",

View file

@ -0,0 +1,29 @@
use crate::{
core::{Core, Session},
tests::session::VerifyResponse,
};
#[tokio::test]
async fn basic_commands() {
let mut session = Session::test(Core::test());
// Test NOOP
session.ingest(b"NOOP\r\n").await.unwrap();
session.response().assert_code("250");
// Test RSET
session.ingest(b"RSET\r\n").await.unwrap();
session.response().assert_code("250");
// Test HELP
session.ingest(b"HELP QUIT\r\n").await.unwrap();
session.response().assert_code("250");
// Test LHLO on SMTP channel
session.ingest(b"LHLO domain.org\r\n").await.unwrap();
session.response().assert_code("502");
// Test QUIT
session.ingest(b"QUIT\r\n").await.unwrap_err();
session.response().assert_code("221");
}

80
src/tests/inbound/ehlo.rs Normal file
View file

@ -0,0 +1,80 @@
use std::time::{Duration, Instant};
use mail_auth::{common::parse::TxtRecordParser, spf::Spf, SpfResult};
use crate::{
config::ConfigContext,
core::{Core, Session},
tests::{session::VerifyResponse, ParseTestConfig},
};
#[tokio::test]
async fn ehlo() {
let mut core = Core::test();
core.resolvers.dns.txt_add(
"mx1.foobar.org",
Spf::parse(b"v=spf1 ip4:10.0.0.1 -all").unwrap(),
Instant::now() + Duration::from_secs(5),
);
core.resolvers.dns.txt_add(
"mx2.foobar.org",
Spf::parse(b"v=spf1 ip4:10.0.0.2 -all").unwrap(),
Instant::now() + Duration::from_secs(5),
);
let mut config = &mut core.session.config;
config.data.max_message_size = r"[{if = 'remote-ip', eq = '10.0.0.1', then = 1024},
{else = 2048}]"
.parse_if(&ConfigContext::default());
config.extensions.future_release = r"[{if = 'remote-ip', eq = '10.0.0.1', then = '1h'},
{else = false}]"
.parse_if(&ConfigContext::default());
config.extensions.mt_priority = r"[{if = 'remote-ip', eq = '10.0.0.1', then = 'nsep'},
{else = false}]"
.parse_if(&ConfigContext::default());
core.mail_auth.spf.verify_ehlo = r"[{if = 'remote-ip', eq = '10.0.0.2', then = 'strict'},
{else = 'relaxed'}]"
.parse_if(&ConfigContext::default());
// EHLO capabilities evaluation
let mut session = Session::test(core);
session.data.remote_ip = "10.0.0.1".parse().unwrap();
session.stream.tls = false;
session.eval_session_params().await;
session.ingest(b"EHLO mx1.foobar.org\r\n").await.unwrap();
session
.response()
.assert_code("250")
.assert_contains("SIZE 1024")
.assert_contains("MT-PRIORITY NSEP")
.assert_contains("FUTURERELEASE 3600")
.assert_contains("STARTTLS");
// SPF should be a Pass for 10.0.0.1
assert_eq!(
session.data.spf_ehlo.as_ref().unwrap().result(),
SpfResult::Pass
);
// Test SPF strict mode
session.data.helo_domain = String::new();
session.data.remote_ip = "10.0.0.2".parse().unwrap();
session.stream.tls = true;
session.eval_session_params().await;
session.ingest(b"EHLO mx1.foobar.org\r\n").await.unwrap();
session.response().assert_code("550 5.7.23");
// EHLO capabilities evaluation
session.ingest(b"EHLO mx2.foobar.org\r\n").await.unwrap();
assert_eq!(
session.data.spf_ehlo.as_ref().unwrap().result(),
SpfResult::Pass
);
session
.response()
.assert_code("250")
.assert_contains("SIZE 2048")
.assert_not_contains("MT-PRIORITY")
.assert_not_contains("FUTURERELEASE")
.assert_not_contains("STARTTLS");
}

View file

@ -0,0 +1,61 @@
use std::time::{Duration, Instant};
use tokio::sync::watch;
use crate::{
config::ConfigContext,
core::{Core, Session},
tests::{session::VerifyResponse, ParseTestConfig},
};
#[tokio::test]
async fn limits() {
let mut core = Core::test();
let mut config = &mut core.session.config;
config.transfer_limit = r"[{if = 'remote-ip', eq = '10.0.0.1', then = 10},
{else = 1024}]"
.parse_if(&ConfigContext::default());
config.timeout = r"[{if = 'remote-ip', eq = '10.0.0.2', then = '500ms'},
{else = '30m'}]"
.parse_if(&ConfigContext::default());
config.duration = r"[{if = 'remote-ip', eq = '10.0.0.3', then = '500ms'},
{else = '60m'}]"
.parse_if(&ConfigContext::default());
let (_tx, rx) = watch::channel(true);
// Exceed max line length
let mut session = Session::test(core);
session.data.remote_ip = "10.0.0.1".parse().unwrap();
let mut buf = vec![b'A'; 2049];
session.ingest(&buf).await.unwrap();
session.ingest(b"\r\n").await.unwrap();
session.response().assert_code("554 5.3.4");
// Invalid command
buf.extend_from_slice(b"\r\n");
session.ingest(&buf).await.unwrap();
session.response().assert_code("500 5.5.1");
// Exceed transfer quota
session.eval_session_params().await;
session.write_rx("MAIL FROM:<this_is_a_long@command_over_10_chars.com>\r\n");
session.handle_conn_(rx.clone()).await;
session.response().assert_code("451 4.7.28");
// Loitering
session.data.remote_ip = "10.0.0.3".parse().unwrap();
session.data.valid_until = Instant::now();
session.eval_session_params().await;
tokio::time::sleep(Duration::from_millis(600)).await;
session.write_rx("MAIL FROM:<this_is_a_long@command_over_10_chars.com>\r\n");
session.handle_conn_(rx.clone()).await;
session.response().assert_code("453 4.3.2");
// Timeout
session.data.remote_ip = "10.0.0.2".parse().unwrap();
session.data.valid_until = Instant::now();
session.eval_session_params().await;
session.write_rx("MAIL FROM:<this_is_a_long@command_over_10_chars.com>\r\n");
session.handle_conn_(rx.clone()).await;
session.response().assert_code("221 2.0.0");
}

143
src/tests/inbound/mail.rs Normal file
View file

@ -0,0 +1,143 @@
use std::{
sync::Arc,
time::{Duration, Instant},
};
use mail_auth::{common::parse::TxtRecordParser, spf::Spf, IprevResult, SpfResult};
use crate::{
config::{ConfigContext, IfBlock, VerifyStrategy},
core::{Core, Session},
tests::{session::VerifyResponse, ParseTestConfig},
};
#[tokio::test]
async fn mail() {
let mut core = Core::test();
core.resolvers.dns.txt_add(
"foobar.org",
Spf::parse(b"v=spf1 ip4:10.0.0.1 -all").unwrap(),
Instant::now() + Duration::from_secs(5),
);
core.resolvers.dns.txt_add(
"mx1.foobar.org",
Spf::parse(b"v=spf1 ip4:10.0.0.1 -all").unwrap(),
Instant::now() + Duration::from_secs(5),
);
core.resolvers.dns.ptr_add(
"10.0.0.1".parse().unwrap(),
vec!["mx1.foobar.org.".to_string()],
Instant::now() + Duration::from_secs(5),
);
core.resolvers.dns.ipv4_add(
"mx1.foobar.org.",
vec!["10.0.0.1".parse().unwrap()],
Instant::now() + Duration::from_secs(5),
);
core.resolvers.dns.ptr_add(
"10.0.0.2".parse().unwrap(),
vec!["mx2.foobar.org.".to_string()],
Instant::now() + Duration::from_secs(5),
);
let mut config = &mut core.session.config;
config.ehlo.require = IfBlock::new(true);
core.mail_auth.spf.verify_ehlo = IfBlock::new(VerifyStrategy::Relaxed);
core.mail_auth.spf.verify_mail_from = r"[{if = 'remote-ip', eq = '10.0.0.2', then = 'strict'},
{else = 'relaxed'}]"
.parse_if(&ConfigContext::default());
core.mail_auth.iprev.verify = r"[{if = 'remote-ip', eq = '10.0.0.2', then = 'strict'},
{else = 'relaxed'}]"
.parse_if(&ConfigContext::default());
config.throttle.mail_from = r"[[throttle]]
match = {if = 'remote-ip', eq = '10.0.0.1'}
key = 'sender'
rate = '2/1s'
"
.parse_throttle(&ConfigContext::default());
// Be rude and do not say EHLO
let core = Arc::new(core);
let mut session = Session::test(core.clone());
session.data.remote_ip = "10.0.0.1".parse().unwrap();
session.eval_session_params().await;
session
.ingest(b"MAIL FROM:<bill@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("503 5.5.1");
// Both IPREV and SPF should pass
session.ingest(b"EHLO mx1.foobar.org\r\n").await.unwrap();
session.response().assert_code("250");
session
.ingest(b"MAIL FROM:<bill@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("250");
assert_eq!(
session.data.spf_ehlo.as_ref().unwrap().result(),
SpfResult::Pass
);
assert_eq!(
session.data.spf_mail_from.as_ref().unwrap().result(),
SpfResult::Pass
);
assert_eq!(
session.data.iprev.as_ref().unwrap().result(),
&IprevResult::Pass
);
// Multiple MAIL FROMs should not be allowed
session
.ingest(b"MAIL FROM:<bill@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("503 5.5.1");
// Test rate limit
for n in 0..2 {
session.ingest(b"RSET\r\n").await.unwrap();
session.response().assert_code("250");
session
.ingest(b"MAIL FROM:<bill@foobar.org>\r\n")
.await
.unwrap();
session
.response()
.assert_code(if n == 0 { "250" } else { "451 4.4.5" });
}
// Test strict IPREV
session.data.remote_ip = "10.0.0.2".parse().unwrap();
session.data.iprev = None;
session.eval_session_params().await;
session
.ingest(b"MAIL FROM:<jane@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("550 5.7.25");
session.data.iprev = None;
core.resolvers.dns.ipv4_add(
"mx2.foobar.org.",
vec!["10.0.0.2".parse().unwrap()],
Instant::now() + Duration::from_secs(5),
);
// Test strict SPF
session
.ingest(b"MAIL FROM:<jane@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("550 5.7.23");
core.resolvers.dns.txt_add(
"foobar.org",
Spf::parse(b"v=spf1 ip4:10.0.0.1 ip4:10.0.0.2 -all").unwrap(),
Instant::now() + Duration::from_secs(5),
);
session
.ingest(b"MAIL FROM:<jane@foobar.org>\r\n")
.await
.unwrap();
session.response().assert_code("250");
}

5
src/tests/inbound/mod.rs Normal file
View file

@ -0,0 +1,5 @@
pub mod basic;
pub mod ehlo;
pub mod limits;
pub mod mail;
pub mod throttle;

View file

@ -0,0 +1,90 @@
use std::time::Duration;
use crate::{
config::ConfigContext,
core::{Core, Session, SessionAddress},
tests::ParseTestConfig,
};
#[tokio::test]
async fn throttle() {
let mut core = Core::test();
let mut config = &mut core.session.config;
config.throttle.connect = r"[[throttle]]
match = {if = 'remote-ip', eq = '10.0.0.1'}
key = 'remote-ip'
concurrency = 2
rate = '3/1s'
"
.parse_throttle(&ConfigContext::default());
config.throttle.mail_from = r"[[throttle]]
key = 'sender'
rate = '2/1s'
"
.parse_throttle(&ConfigContext::default());
config.throttle.rcpt_to = r"[[throttle]]
key = ['remote-ip', 'rcpt']
rate = '2/1s'
"
.parse_throttle(&ConfigContext::default());
// Test connection concurrency limit
let mut session = Session::test(core);
session.data.remote_ip = "10.0.0.1".parse().unwrap();
assert!(
session.is_allowed().await,
"Concurrency limiter too strict."
);
assert!(
session.is_allowed().await,
"Concurrency limiter too strict."
);
assert!(!session.is_allowed().await, "Concurrency limiter failed.");
// Test connection rate limit
session.in_flight.clear(); // Manually reset concurrency limiter
assert!(session.is_allowed().await, "Rate limiter too strict.");
assert!(!session.is_allowed().await, "Rate limiter failed.");
session.in_flight.clear();
tokio::time::sleep(Duration::from_millis(1100)).await;
assert!(
session.is_allowed().await,
"Rate limiter did not restore quota."
);
// Test mail from rate limit
session.data.mail_from = SessionAddress {
address: "sender@test.org".to_string(),
address_lcase: "sender@test.org".to_string(),
domain: "test.org".to_string(),
flags: 0,
dsn_info: None,
}
.into();
assert!(session.is_allowed().await, "Rate limiter too strict.");
assert!(session.is_allowed().await, "Rate limiter too strict.");
assert!(!session.is_allowed().await, "Rate limiter failed.");
session.data.mail_from = SessionAddress {
address: "other-sender@test.org".to_string(),
address_lcase: "other-sender@test.org".to_string(),
domain: "test.org".to_string(),
flags: 0,
dsn_info: None,
}
.into();
assert!(session.is_allowed().await, "Rate limiter failed.");
// Test recipient rate limit
session.data.rcpt_to.push(SessionAddress {
address: "recipient@example.org".to_string(),
address_lcase: "recipient@example.org".to_string(),
domain: "example.org".to_string(),
flags: 0,
dsn_info: None,
});
assert!(session.is_allowed().await, "Rate limiter too strict.");
assert!(session.is_allowed().await, "Rate limiter too strict.");
assert!(!session.is_allowed().await, "Rate limiter failed.");
session.data.remote_ip = "10.0.0.2".parse().unwrap();
assert!(session.is_allowed().await, "Rate limiter too strict.");
}

346
src/tests/mod.rs Normal file
View file

@ -0,0 +1,346 @@
use std::time::Duration;
use dashmap::DashMap;
use mail_auth::{
common::lru::{DnsCache, LruCache},
trust_dns_resolver::config::{ResolverConfig, ResolverOpts},
Resolver,
};
use mail_send::smtp::tls::build_tls_connector;
use smtp_proto::{AUTH_LOGIN, AUTH_PLAIN};
use tokio::sync::mpsc;
use crate::{
config::{
utils::ParseValues, AggregateReport, ArcAuthConfig, Auth, Config, ConfigContext, Connect,
Data, DkimAuthConfig, DmarcAuthConfig, Dsn, Ehlo, EnvelopeKey, Extensions, IfBlock,
IpRevAuthConfig, Mail, MailAuthConfig, QueueConfig, QueueOutboundSourceIp,
QueueOutboundTimeout, QueueOutboundTls, QueueQuotas, QueueThrottle, Rcpt, Report,
ReportAnalysis, ReportConfig, SessionConfig, SessionThrottle, SpfAuthConfig, Throttle,
VerifyStrategy,
},
core::{
throttle::{ConcurrencyLimiter, ThrottleKeyHasherBuilder},
Core, QueueCore, ReportCore, Resolvers, SessionCore, TlsConnectors,
},
outbound::dane::DnssecResolver,
};
pub mod inbound;
pub mod session;
pub trait ParseTestConfig {
fn parse_if<T: Default + ParseValues>(&self, ctx: &ConfigContext) -> IfBlock<T>;
fn parse_throttle(&self, ctx: &ConfigContext) -> Vec<Throttle>;
}
impl ParseTestConfig for &str {
fn parse_if<T: Default + ParseValues>(&self, ctx: &ConfigContext) -> IfBlock<T> {
Config::parse(&format!("test = {}\n", self))
.unwrap()
.parse_if_block(
"test",
ctx,
&[
EnvelopeKey::Recipient,
EnvelopeKey::RecipientDomain,
EnvelopeKey::Sender,
EnvelopeKey::SenderDomain,
EnvelopeKey::Mx,
EnvelopeKey::HeloDomain,
EnvelopeKey::AuthenticatedAs,
EnvelopeKey::Listener,
EnvelopeKey::RemoteIp,
EnvelopeKey::LocalIp,
EnvelopeKey::Priority,
],
)
.unwrap()
.unwrap()
}
fn parse_throttle(&self, ctx: &ConfigContext) -> Vec<Throttle> {
Config::parse(self)
.unwrap()
.parse_throttle(
"throttle",
ctx,
&[
EnvelopeKey::Recipient,
EnvelopeKey::RecipientDomain,
EnvelopeKey::Sender,
EnvelopeKey::SenderDomain,
EnvelopeKey::Mx,
EnvelopeKey::HeloDomain,
EnvelopeKey::AuthenticatedAs,
EnvelopeKey::Listener,
EnvelopeKey::RemoteIp,
EnvelopeKey::LocalIp,
EnvelopeKey::Priority,
],
u16::MAX,
)
.unwrap()
}
}
impl Core {
pub fn test() -> Self {
Core {
worker_pool: rayon::ThreadPoolBuilder::new()
.num_threads(num_cpus::get())
.build()
.unwrap(),
session: SessionCore::test(),
queue: QueueCore::test(),
resolvers: Resolvers {
dns: Resolver::new_system_conf().unwrap(),
dnssec: DnssecResolver::with_capacity(
ResolverConfig::cloudflare(),
ResolverOpts::default(),
)
.unwrap(),
cache: crate::core::DnsCache {
tlsa: LruCache::with_capacity(100),
mta_sts: LruCache::with_capacity(100),
},
},
mail_auth: MailAuthConfig::test(),
report: ReportCore::test(),
}
}
}
impl SessionCore {
pub fn test() -> Self {
SessionCore {
config: SessionConfig::test(),
concurrency: ConcurrencyLimiter::new(100),
throttle: DashMap::with_capacity_and_hasher_and_shard_amount(
10,
ThrottleKeyHasherBuilder::default(),
16,
),
}
}
}
impl SessionConfig {
pub fn test() -> Self {
Self {
timeout: IfBlock::new(Duration::from_secs(10)),
duration: IfBlock::new(Duration::from_secs(10)),
transfer_limit: IfBlock::new(1024 * 1024),
throttle: SessionThrottle {
connect: vec![],
mail_from: vec![],
rcpt_to: vec![],
},
connect: Connect {
script: IfBlock::new(None),
},
ehlo: Ehlo {
script: IfBlock::new(None),
require: IfBlock::new(true),
},
extensions: Extensions {
pipelining: IfBlock::new(true),
chunking: IfBlock::new(true),
requiretls: IfBlock::new(true),
no_soliciting: IfBlock::new("domain.org".to_string().into()),
future_release: IfBlock::new(None),
deliver_by: IfBlock::new(None),
mt_priority: IfBlock::new(None),
},
auth: Auth {
script: IfBlock::new(None),
lookup: IfBlock::new(None),
mechanisms: IfBlock::new(AUTH_PLAIN | AUTH_LOGIN),
errors_max: IfBlock::new(10),
errors_wait: IfBlock::new(Duration::from_secs(1)),
},
mail: Mail {
script: IfBlock::new(None),
},
rcpt: Rcpt {
script: IfBlock::new(None),
relay: IfBlock::new(false),
lookup_domains: IfBlock::new(None),
lookup_addresses: IfBlock::new(None),
lookup_expn: IfBlock::new(None),
lookup_vrfy: IfBlock::new(None),
errors_max: IfBlock::new(3),
errors_wait: IfBlock::new(Duration::from_secs(1)),
max_recipients: IfBlock::new(3),
},
data: Data {
script: IfBlock::new(None),
max_messages: IfBlock::new(10),
max_message_size: IfBlock::new(1024 * 1024),
max_received_headers: IfBlock::new(10),
add_received: IfBlock::new(true),
add_received_spf: IfBlock::new(true),
add_return_path: IfBlock::new(true),
add_auth_results: IfBlock::new(true),
add_message_id: IfBlock::new(true),
add_date: IfBlock::new(true),
},
}
}
}
impl QueueCore {
pub fn test() -> Self {
Self {
config: QueueConfig::test(),
throttle: DashMap::with_capacity_and_hasher_and_shard_amount(
10,
ThrottleKeyHasherBuilder::default(),
16,
),
quota: DashMap::with_capacity_and_hasher_and_shard_amount(
10,
ThrottleKeyHasherBuilder::default(),
16,
),
tx: mpsc::channel(1024).0,
id_seq: 0.into(),
connectors: TlsConnectors {
pki_verify: build_tls_connector(false),
dummy_verify: build_tls_connector(true),
},
}
}
}
impl QueueConfig {
pub fn test() -> Self {
Self {
path: Default::default(),
hash: IfBlock::new(10),
retry: IfBlock::new(vec![Duration::from_secs(10)]),
notify: IfBlock::new(vec![Duration::from_secs(20)]),
expire: IfBlock::new(Duration::from_secs(10)),
hostname: IfBlock::new("mx.example.org".to_string()),
next_hop: Default::default(),
max_mx: IfBlock::new(5),
max_multihomed: IfBlock::new(5),
source_ip: QueueOutboundSourceIp {
ipv4: IfBlock::new(vec![]),
ipv6: IfBlock::new(vec![]),
},
tls: QueueOutboundTls {
dane: IfBlock::new(crate::config::RequireOptional::Optional),
mta_sts: IfBlock::new(crate::config::RequireOptional::Optional),
start: IfBlock::new(crate::config::RequireOptional::Optional),
},
dsn: Dsn {
name: IfBlock::new("Mail Delivery Subsystem".to_string()),
address: IfBlock::new("MAILER-DAEMON@example.org".to_string()),
sign: IfBlock::default(),
},
timeout: QueueOutboundTimeout {
connect: IfBlock::new(Duration::from_secs(1)),
greeting: IfBlock::new(Duration::from_secs(1)),
tls: IfBlock::new(Duration::from_secs(1)),
ehlo: IfBlock::new(Duration::from_secs(1)),
mail: IfBlock::new(Duration::from_secs(1)),
rcpt: IfBlock::new(Duration::from_secs(1)),
data: IfBlock::new(Duration::from_secs(1)),
mta_sts: IfBlock::new(Duration::from_secs(1)),
},
throttle: QueueThrottle {
sender: vec![],
rcpt: vec![],
host: vec![],
},
quota: QueueQuotas {
sender: vec![],
rcpt: vec![],
rcpt_domain: vec![],
},
}
}
}
impl MailAuthConfig {
pub fn test() -> Self {
Self {
dkim: DkimAuthConfig {
verify: IfBlock::new(VerifyStrategy::Relaxed),
sign: IfBlock::default(),
},
arc: ArcAuthConfig {
verify: IfBlock::new(VerifyStrategy::Relaxed),
seal: IfBlock::default(),
},
spf: SpfAuthConfig {
verify_ehlo: IfBlock::new(VerifyStrategy::Relaxed),
verify_mail_from: IfBlock::new(VerifyStrategy::Relaxed),
},
dmarc: DmarcAuthConfig {
verify: IfBlock::new(VerifyStrategy::Relaxed),
},
iprev: IpRevAuthConfig {
verify: IfBlock::new(VerifyStrategy::Relaxed),
},
}
}
}
impl ReportCore {
pub fn test() -> Self {
Self {
config: ReportConfig::test(),
tx: mpsc::channel(1024).0,
}
}
}
impl ReportConfig {
pub fn test() -> Self {
Self {
path: Default::default(),
hash: IfBlock::new(10),
submitter: IfBlock::new("example.org".to_string()),
analysis: ReportAnalysis {
addresses: vec![],
forward: true,
store: None,
report_id: 0.into(),
},
dkim: Report::test(),
spf: Report::test(),
dmarc: Report::test(),
dmarc_aggregate: AggregateReport::test(),
tls: AggregateReport::test(),
}
}
}
impl Report {
pub fn test() -> Self {
Self {
name: IfBlock::default(),
address: IfBlock::default(),
subject: IfBlock::default(),
sign: IfBlock::default(),
send: IfBlock::default(),
}
}
}
impl AggregateReport {
pub fn test() -> Self {
Self {
name: IfBlock::default(),
address: IfBlock::default(),
subject: IfBlock::default(),
org_name: IfBlock::default(),
contact_info: IfBlock::default(),
send: IfBlock::default(),
sign: IfBlock::default(),
max_size: IfBlock::default(),
}
}
}

153
src/tests/session.rs Normal file
View file

@ -0,0 +1,153 @@
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{
config::ServerProtocol,
core::{Core, ServerInstance, Session, SessionData, SessionParameters, State},
inbound::IsTls,
};
pub struct DummyIo {
pub tx_buf: Vec<u8>,
pub rx_buf: Vec<u8>,
pub tls: bool,
}
impl AsyncRead for DummyIo {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
if !self.rx_buf.is_empty() {
buf.put_slice(&self.rx_buf);
self.rx_buf.clear();
std::task::Poll::Ready(Ok(()))
} else {
std::task::Poll::Pending
}
}
}
impl AsyncWrite for DummyIo {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
self.tx_buf.extend_from_slice(buf);
std::task::Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
impl IsTls for DummyIo {
fn is_tls(&self) -> bool {
self.tls
}
fn write_tls_header(&self, _headers: &mut Vec<u8>) {}
}
impl Unpin for DummyIo {}
impl Session<DummyIo> {
pub fn test(core: impl Into<Arc<Core>>) -> Self {
Self {
state: State::default(),
instance: Arc::new(ServerInstance::test()),
core: core.into(),
span: tracing::info_span!("test"),
stream: DummyIo {
rx_buf: vec![],
tx_buf: vec![],
tls: false,
},
data: SessionData::new("127.0.0.1".parse().unwrap(), "127.0.0.1".parse().unwrap()),
params: SessionParameters::default(),
in_flight: vec![],
}
}
pub fn response(&mut self) -> Vec<String> {
if !self.stream.tx_buf.is_empty() {
let response = std::str::from_utf8(&self.stream.tx_buf)
.unwrap()
.split("\r\n")
.filter_map(|r| {
if !r.is_empty() {
r.to_string().into()
} else {
None
}
})
.collect::<Vec<_>>();
self.stream.tx_buf.clear();
response
} else {
panic!("There was no response.");
}
}
pub fn write_rx(&mut self, data: &str) {
self.stream.rx_buf.extend_from_slice(data.as_bytes());
}
}
pub trait VerifyResponse {
fn assert_code(self, expected_code: &str) -> Self;
fn assert_contains(self, expected_text: &str) -> Self;
fn assert_not_contains(self, expected_text: &str) -> Self;
}
impl VerifyResponse for Vec<String> {
fn assert_code(self, expected_code: &str) -> Self {
if self.last().expect("response").starts_with(expected_code) {
self
} else {
panic!("Expected {:?} but got {:?}.", expected_code, self);
}
}
fn assert_contains(self, expected_text: &str) -> Self {
if self.iter().any(|line| line.contains(expected_text)) {
self
} else {
panic!("Expected {:?} but got {:?}.", expected_text, self);
}
}
fn assert_not_contains(self, expected_text: &str) -> Self {
if !self.iter().any(|line| line.contains(expected_text)) {
self
} else {
panic!("Not expecting {:?} but got it {:?}.", expected_text, self);
}
}
}
impl ServerInstance {
pub fn test() -> Self {
Self {
id: "smtp".to_string(),
listener_id: 1,
protocol: ServerProtocol::Smtp,
hostname: "mx.example.org".to_string(),
greeting: b"220 mx.example.org at your service.\r\n".to_vec(),
}
}
}