diff --git a/devices/src/virtio/snd/cras_backend/async_funcs.rs b/devices/src/virtio/snd/cras_backend/async_funcs.rs index db9c26baac..a5fd6b7573 100644 --- a/devices/src/virtio/snd/cras_backend/async_funcs.rs +++ b/devices/src/virtio/snd/cras_backend/async_funcs.rs @@ -11,6 +11,7 @@ use cros_async::{sync::Condvar, sync::Mutex as AsyncMutex, EventAsync, Executor} use data_model::{DataInit, Le32}; use vm_memory::GuestMemory; +use crate::virtio::cras_backend::Parameters; use crate::virtio::snd::common::*; use crate::virtio::snd::constants::*; use crate::virtio::snd::layout::*; @@ -27,6 +28,7 @@ async fn process_pcm_ctrl( rx_queue: &Rc>, interrupt: &Rc, streams: &Rc>>>>, + params: &Parameters, cmd_code: u32, writer: &mut Writer, stream_id: usize, @@ -49,7 +51,7 @@ async fn process_pcm_ctrl( let result = match cmd_code { VIRTIO_SND_R_PCM_PREPARE => { stream - .prepare(ex, mem.clone(), tx_queue, rx_queue, interrupt) + .prepare(ex, mem.clone(), tx_queue, rx_queue, interrupt, params) .await } VIRTIO_SND_R_PCM_START => stream.start().await, @@ -329,6 +331,7 @@ pub async fn handle_ctrl_queue( interrupt: &Rc, tx_queue: &Rc>, rx_queue: &Rc>, + params: &Parameters, ) -> Result<(), Error> { loop { let desc_chain = queue @@ -562,6 +565,7 @@ pub async fn handle_ctrl_queue( rx_queue, interrupt, streams, + params, code, &mut writer, stream_id, diff --git a/devices/src/virtio/snd/cras_backend/mod.rs b/devices/src/virtio/snd/cras_backend/mod.rs index 10e824031c..9d772d454b 100644 --- a/devices/src/virtio/snd/cras_backend/mod.rs +++ b/devices/src/virtio/snd/cras_backend/mod.rs @@ -6,6 +6,7 @@ use std::io; use std::rc::Rc; +use std::str::{FromStr, ParseBoolError}; use std::thread; use audio_streams::{SampleFormat, StreamSource}; @@ -15,7 +16,7 @@ use cros_async::{select4, AsyncError, EventAsync, Executor, SelectResult}; use data_model::DataInit; use futures::channel::mpsc; use futures::{pin_mut, Future, TryFutureExt}; -use libcras::{BoxError, CrasClient, CrasClientType}; +use libcras::{BoxError, CrasClient, CrasClientType, CrasSocketType}; use thiserror::Error as ThisError; use vm_memory::GuestMemory; @@ -83,6 +84,67 @@ pub enum Error { /// Writing to a buffer in the guest failed. #[error("failed to write to buffer: {0}")] WriteBuffer(io::Error), + /// Failed to parse parameters. + #[error("Invalid cras snd parameter: {0}")] + UnknownParameter(String), + /// Unknown cras snd parameter value. + #[error("Invalid cras snd parameter value ({0}): {1}")] + InvalidParameterValue(String, String), + /// Failed to parse bool value. + #[error("Invalid bool value: {0}")] + InvalidBoolValue(ParseBoolError), +} + +/// Holds the parameters for a cras sound device +#[derive(Debug, Clone)] +pub struct Parameters { + pub capture: bool, + pub client_type: CrasClientType, + pub socket_type: CrasSocketType, +} + +impl Default for Parameters { + fn default() -> Self { + Parameters { + capture: true, + client_type: CrasClientType::CRAS_CLIENT_TYPE_CROSVM, + socket_type: CrasSocketType::Unified, + } + } +} + +impl FromStr for Parameters { + type Err = Error; + fn from_str(s: &str) -> std::result::Result { + let mut params: Parameters = Default::default(); + let opts = s + .split(',') + .map(|frag| frag.split('=')) + .map(|mut kv| (kv.next().unwrap_or(""), kv.next().unwrap_or(""))); + + for (k, v) in opts { + match k { + "capture" => { + params.capture = v.parse::().map_err(Error::InvalidBoolValue)?; + } + "client_type" => { + params.client_type = v.parse().map_err(|e: libcras::CrasSysError| { + Error::InvalidParameterValue(v.to_string(), e.to_string()) + })?; + } + "socket_type" => { + params.socket_type = v.parse().map_err(|e: libcras::Error| { + Error::InvalidParameterValue(v.to_string(), e.to_string()) + })?; + } + _ => { + return Err(Error::UnknownParameter(k.to_string())); + } + } + } + + Ok(params) + } } pub enum DirectionalStream { @@ -158,6 +220,7 @@ impl<'a> StreamInfo<'a> { tx_queue: &Rc>, rx_queue: &Rc>, interrupt: &Rc, + params: &Parameters, ) -> Result<(), Error> { if self.state != VIRTIO_SND_R_PCM_SET_PARAMS && self.state != VIRTIO_SND_R_PCM_PREPARE @@ -176,13 +239,11 @@ impl<'a> StreamInfo<'a> { return Err(Error::OperationNotSupported); } if self.client.is_none() { - // TODO(woodychow): once we're running in vm_concierge, we need an --enable-capture - // option for - // false: CrasClient::new() - // true: CrasClient::with_type(CrasSocketType::Unified) - // to use different socket. - let mut client = CrasClient::new().map_err(Error::Libcras).unwrap(); - client.set_client_type(CrasClientType::CRAS_CLIENT_TYPE_CROSVM); + let mut client = CrasClient::with_type(params.socket_type).map_err(Error::Libcras)?; + if params.capture { + client.enable_cras_capture(); + } + client.set_client_type(params.client_type); self.client = Some(client); } // (*) @@ -211,7 +272,6 @@ impl<'a> StreamInfo<'a> { tx_queue.clone(), ), VIRTIO_SND_D_INPUT => { - self.client.as_mut().unwrap().enable_cras_capture(); ( DirectionalStream::Input( self.client @@ -322,10 +382,11 @@ pub struct VirtioSndCras { queue_sizes: Box<[u16]>, worker_threads: Vec>, kill_evt: Option, + params: Parameters, } impl VirtioSndCras { - pub fn new(base_features: u64) -> Result { + pub fn new(base_features: u64, params: Parameters) -> Result { let cfg = virtio_snd_config { jacks: 0.into(), streams: 2.into(), @@ -341,6 +402,7 @@ impl VirtioSndCras { queue_sizes: vec![QUEUE_SIZE; NUM_QUEUES].into_boxed_slice(), worker_threads: Vec::new(), kill_evt: None, + params, }) } } @@ -477,6 +539,8 @@ impl VirtioDevice for VirtioSndCras { }); // } + let params = self.params.clone(); + let worker_result = thread::Builder::new() .name("virtio_snd w".to_string()) .spawn(move || { @@ -492,7 +556,7 @@ impl VirtioDevice for VirtioSndCras { }; if let Err(err_string) = run_worker( - interrupt, queues, guest_mem, streams, snd_data, queue_evts, kill_evt, + interrupt, queues, guest_mem, streams, snd_data, queue_evts, kill_evt, params, ) { error!("{}", err_string); } @@ -531,6 +595,7 @@ fn run_worker( snd_data: SndData, queue_evts: Vec, kill_evt: Event, + params: Parameters, ) -> Result<(), String> { let ex = Executor::new().expect("Failed to create an executor"); @@ -561,6 +626,7 @@ fn run_worker( &interrupt, &tx_queue, &rx_queue, + ¶ms, ); pin_mut!(f_ctrl); @@ -604,3 +670,57 @@ fn run_worker( Ok(()) } +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn parameters_fromstr() { + fn check_success( + s: &str, + capture: bool, + client_type: CrasClientType, + socket_type: CrasSocketType, + ) { + let params = s.parse::().expect("parse should have succeded"); + assert_eq!(params.capture, capture); + assert_eq!(params.client_type, client_type); + assert_eq!(params.socket_type, socket_type); + } + fn check_failure(s: &str) { + s.parse::() + .expect_err("parse should have failed"); + } + + check_success( + "capture=false", + false, + CrasClientType::CRAS_CLIENT_TYPE_CROSVM, + CrasSocketType::Unified, + ); + check_success( + "capture=true,client_type=crosvm", + true, + CrasClientType::CRAS_CLIENT_TYPE_CROSVM, + CrasSocketType::Unified, + ); + check_success( + "capture=true,client_type=arcvm", + true, + CrasClientType::CRAS_CLIENT_TYPE_ARCVM, + CrasSocketType::Unified, + ); + check_failure("capture=true,client_type=none"); + check_success( + "socket_type=legacy", + true, + CrasClientType::CRAS_CLIENT_TYPE_CROSVM, + CrasSocketType::Legacy, + ); + check_success( + "socket_type=unified", + true, + CrasClientType::CRAS_CLIENT_TYPE_CROSVM, + CrasSocketType::Unified, + ); + } +} diff --git a/src/crosvm.rs b/src/crosvm.rs index 08fb81d514..10639dac54 100644 --- a/src/crosvm.rs +++ b/src/crosvm.rs @@ -22,6 +22,8 @@ use std::str::FromStr; use arch::{Pstore, VcpuAffinity}; use devices::serial_device::{SerialHardware, SerialParameters}; +#[cfg(feature = "audio_cras")] +use devices::virtio::cras_backend::Parameters as CrasSndParameters; use devices::virtio::fs::passthrough; #[cfg(feature = "gpu")] use devices::virtio::gpu::GpuParameters; @@ -210,7 +212,7 @@ pub struct Config { pub cpu_clusters: Vec>, pub cpu_capacity: BTreeMap, // CPU index -> capacity #[cfg(feature = "audio_cras")] - pub cras_snd: bool, + pub cras_snd: Option, pub delay_rt: bool, pub no_smt: bool, pub memory: Option, @@ -297,7 +299,7 @@ impl Default for Config { cpu_clusters: Vec::new(), cpu_capacity: BTreeMap::new(), #[cfg(feature = "audio_cras")] - cras_snd: false, + cras_snd: None, delay_rt: false, no_smt: false, memory: None, diff --git a/src/linux.rs b/src/linux.rs index cd10673c53..edc6d2e5f5 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -33,6 +33,8 @@ use devices::serial_device::{SerialHardware, SerialParameters}; use devices::vfio::{VfioCommonSetup, VfioCommonTrait}; #[cfg(feature = "gpu")] use devices::virtio::gpu::{DEFAULT_DISPLAY_HEIGHT, DEFAULT_DISPLAY_WIDTH}; +#[cfg(feature = "audio_cras")] +use devices::virtio::snd::cras_backend::Parameters as CrasSndParameters; use devices::virtio::vhost::user::vmm::{ Block as VhostUserBlock, Console as VhostUserConsole, Fs as VhostUserFs, Mac80211Hwsim as VhostUserMac80211Hwsim, Net as VhostUserNet, Wl as VhostUserWl, @@ -332,10 +334,12 @@ fn create_rng_device(cfg: &Config) -> DeviceResult { } #[cfg(feature = "audio_cras")] -fn create_cras_snd_device(cfg: &Config) -> DeviceResult { - let dev = - virtio::snd::cras_backend::VirtioSndCras::new(virtio::base_features(cfg.protected_vm)) - .map_err(Error::CrasSoundDeviceNew)?; +fn create_cras_snd_device(cfg: &Config, cras_snd: CrasSndParameters) -> DeviceResult { + let dev = virtio::snd::cras_backend::VirtioSndCras::new( + virtio::base_features(cfg.protected_vm), + cras_snd, + ) + .map_err(Error::CrasSoundDeviceNew)?; let jail = match simple_jail(&cfg, "cras_snd_device")? { Some(mut jail) => { @@ -1268,8 +1272,8 @@ fn create_virtio_devices( #[cfg(feature = "audio_cras")] { - if cfg.cras_snd { - devs.push(create_cras_snd_device(cfg)?); + if let Some(cras_snd) = &cfg.cras_snd { + devs.push(create_cras_snd_device(cfg, cras_snd.clone())?); } } diff --git a/src/main.rs b/src/main.rs index 5c71afc6a6..05fac8f0f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,8 @@ use crosvm::{ VhostUserFsOption, VhostUserOption, VhostUserWlOption, DISK_ID_LEN, }; use devices::serial_device::{SerialHardware, SerialParameters, SerialType}; +#[cfg(feature = "audio_cras")] +use devices::virtio::snd::cras_backend::Error as CrasSndError; use devices::virtio::vhost::user::device::{ run_block_device, run_console_device, run_net_device, run_wl_device, }; @@ -978,7 +980,12 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument:: } #[cfg(feature = "audio_cras")] "cras-snd" => { - cfg.cras_snd = true; + cfg.cras_snd = Some( + value + .unwrap() + .parse() + .map_err(|e: CrasSndError| argument::Error::Syntax(e.to_string()))?, + ); } "no-smt" => { cfg.no_smt = true; @@ -2060,7 +2067,13 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { Argument::value("cpu-cluster", "CPUSET", "Group the given CPUs into a cluster (default: no clusters)"), Argument::value("cpu-capacity", "CPU=CAP[,CPU=CAP[,...]]", "Set the relative capacity of the given CPU (default: no capacity)"), #[cfg(feature = "audio_cras")] - Argument::flag("cras-snd", "Enable virtio-snd device with CRAS backend"), + Argument::value("cras-snd", + "[capture=true,client=crosvm,socket=unified]", + "Comma separated key=value pairs for setting up cras snd devices. + Possible key values: + capture - Enable audio capture. + client_type - Set specific client type for cras backend. + socket_type - Set specific socket type for cras backend (legacy/unified)"), Argument::flag("no-smt", "Don't use SMT in the guest"), Argument::value("rt-cpus", "CPUSET", "Comma-separated list of CPUs or CPU ranges to run VCPUs on. (e.g. 0,1-3,5) (default: none)"), Argument::flag("delay-rt", "Don't set VCPUs real-time until make-rt command is run"),