1use anyhow::Result;
13use std::process::Command;
14use tracing::{debug, info, warn};
15
/// Interface-name prefixes preferred by [`detect_nccl_interface`] when picking among the
/// "up" interfaces reported by `ip`.
const DEFAULT_INTERFACES: &[&str] = &["eth0", "ib0", "enp", "ens"];

/// Lower bound of the NCCL port range targeted by the traffic-priority rules.
const NCCL_PORT_START: u16 = 29500;
/// Upper bound of the NCCL port range targeted by the traffic-priority rules.
const NCCL_PORT_END: u16 = 30000;

/// DSCP code point applied to NCCL traffic (46 = Expedited Forwarding, "EF").
const NCCL_DSCP: u8 = 46;

/// Best guess at the network interface NCCL traffic flows over, suitable for `tc ... dev`.
///
/// Resolution order:
/// 1. `NCCL_SOCKET_IFNAME`, when it names an interface (see `interface_from_nccl_env`).
/// 2. The network device backing an InfiniBand device under `/sys/class/infiniband`.
///    The IB device name itself (e.g. `mlx5_0`) is not a netdev, so the name listed under
///    `<dev>/device/net/` is used instead.
/// 3. The first "up" interface from `ip -o link show up` whose name starts with one of
///    [`DEFAULT_INTERFACES`]. Failing that, the first non-loopback "up" interface.
///
/// Returns `None` when nothing suitable is found or `ip` cannot be run successfully.
pub fn detect_nccl_interface() -> Option<String> {
    if let Some(iface) = std::env::var("NCCL_SOCKET_IFNAME")
        .ok()
        .and_then(|spec| interface_from_nccl_env(&spec))
    {
        return Some(iface);
    }

    if let Some(iface) = infiniband_netdev() {
        return Some(iface);
    }

    let output = Command::new("ip")
        .args(["-o", "link", "show", "up"])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    pick_interface(&parse_ip_link(&String::from_utf8_lossy(&output.stdout)))
}

/// Extracts a concrete interface name from an `NCCL_SOCKET_IFNAME` value.
///
/// NCCL accepts a comma-separated prefix list. A leading `=` requests exact matching and a
/// leading `^` turns the list into exclusions. Only the first entry is used, with any leading
/// `=` removed. Empty values and exclusion lists yield `None` because they name nothing to
/// configure. NOTE(review): a bare prefix such as `eth` is returned unchanged and only works
/// with `tc` if it is also a full interface name.
fn interface_from_nccl_env(spec: &str) -> Option<String> {
    let spec = spec.trim();
    if spec.starts_with('^') {
        return None;
    }
    let spec = spec.strip_prefix('=').unwrap_or(spec);
    let first = spec.split(',').next()?.trim();
    (!first.is_empty()).then(|| first.to_string())
}

/// Returns the first netdev, in sorted order, attached to any InfiniBand device.
///
/// The netdev is looked up via `/sys/class/infiniband/<dev>/device/net/<netdev>`.
fn infiniband_netdev() -> Option<String> {
    let root = std::path::Path::new("/sys/class/infiniband");
    sorted_dir_names(root).into_iter().find_map(|dev| {
        sorted_dir_names(&root.join(dev).join("device").join("net"))
            .into_iter()
            .next()
    })
}

/// Names of the entries in `dir`, sorted for deterministic selection. Empty if unreadable.
fn sorted_dir_names(dir: &std::path::Path) -> Vec<String> {
    let mut names: Vec<String> = std::fs::read_dir(dir)
        .map(|entries| {
            entries
                .flatten()
                .map(|entry| entry.file_name().to_string_lossy().into_owned())
                .collect()
        })
        .unwrap_or_default();
    names.sort_unstable();
    names
}

/// Parses `ip -o link` output into interface names, in the order listed.
///
/// Each line looks like `2: eth0: <BROADCAST,...> ...`. The name is the second field with its
/// trailing `:` removed. Any `@peer` suffix (`eth0@if5`, `eth0.100@eth0`) is also dropped,
/// because it is not part of the device name.
fn parse_ip_link(stdout: &str) -> Vec<String> {
    stdout
        .lines()
        .filter_map(|line| line.split_whitespace().nth(1))
        .map(|field| {
            let name = field.trim_end_matches(':');
            name.split('@').next().unwrap_or(name).to_string()
        })
        .filter(|name| !name.is_empty())
        .collect()
}

/// Chooses the first non-loopback interface matching a [`DEFAULT_INTERFACES`] prefix.
///
/// Falls back to the first non-loopback interface of any name.
fn pick_interface(names: &[String]) -> Option<String> {
    let candidates = || names.iter().filter(|name| name.as_str() != "lo");
    candidates()
        .find(|name| DEFAULT_INTERFACES.iter().any(|prefix| name.starts_with(*prefix)))
        .or_else(|| candidates().next())
        .cloned()
}
62
63pub fn enable_nccl_priority(interface: &str) -> Result<()> {
66 info!(interface, "enabling NCCL traffic priority");
67
68 let _ = Command::new("tc")
70 .args(["qdisc", "del", "dev", interface, "root"])
71 .output(); let status = Command::new("tc")
74 .args([
75 "qdisc", "add", "dev", interface, "root", "handle", "1:", "prio", "bands", "3",
76 "priomap", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1",
77 "1",
78 ])
79 .output();
80
81 match status {
82 Ok(o) if o.status.success() => {
83 debug!("prio qdisc added");
84 }
85 Ok(o) => {
86 let err = String::from_utf8_lossy(&o.stderr);
87 warn!(error = %err, "failed to add prio qdisc (requires root)");
88 return Ok(());
89 }
90 Err(e) => {
91 warn!(error = %e, "tc not found");
92 return Ok(());
93 }
94 }
95
96 for port in [NCCL_PORT_START, 29501, 29502, 29503, 29504] {
98 let _ = Command::new("tc")
99 .args([
100 "filter",
101 "add",
102 "dev",
103 interface,
104 "parent",
105 "1:0",
106 "protocol",
107 "ip",
108 "prio",
109 "1",
110 "u32",
111 "match",
112 "ip",
113 "dport",
114 &port.to_string(),
115 "0xffff",
116 "flowid",
117 "1:1",
118 ])
119 .output();
120 }
121
122 let _ = Command::new("tc")
124 .args([
125 "filter",
126 "add",
127 "dev",
128 interface,
129 "parent",
130 "1:0",
131 "protocol",
132 "ip",
133 "prio",
134 "1",
135 "u32",
136 "match",
137 "ip",
138 "tos",
139 &format!("0x{:02x}", NCCL_DSCP << 2),
140 "0xfc",
141 "flowid",
142 "1:1",
143 ])
144 .output();
145
146 info!(
147 interface,
148 ports = format!("{NCCL_PORT_START}-{NCCL_PORT_END}"),
149 "NCCL traffic priority enabled"
150 );
151
152 Ok(())
153}
154
155pub fn disable_nccl_priority(interface: &str) -> Result<()> {
157 let _ = Command::new("tc")
158 .args(["qdisc", "del", "dev", interface, "root"])
159 .output();
160 info!(interface, "NCCL traffic priority disabled");
161 Ok(())
162}
163
164pub fn mark_nccl_dscp() -> Result<()> {
166 for port in NCCL_PORT_START..NCCL_PORT_END {
168 let _ = Command::new("iptables")
169 .args([
170 "-t",
171 "mangle",
172 "-A",
173 "OUTPUT",
174 "-p",
175 "tcp",
176 "--dport",
177 &port.to_string(),
178 "-j",
179 "DSCP",
180 "--set-dscp",
181 &NCCL_DSCP.to_string(),
182 ])
183 .output();
184 }
185 info!("NCCL packets marked with DSCP EF ({})", NCCL_DSCP);
186 Ok(())
187}