1use anyhow::{Context, Result};
6use clap::Subcommand;
7use std::path::Path;
8use std::process::Command;
9
// Subcommands of the `data` CLI group. Fields map 1:1 to clap-parsed
// arguments: fields without `#[arg]` are positional, the rest are `--long`
// flags with the shown string defaults (parsed by clap into the field type).
// NOTE: plain `//` comments are used deliberately — `///` doc comments would
// become clap help text and change the CLI's `--help` output.
#[derive(Subcommand)]
pub enum DataCommands {
    // Summarize a dataset file or directory (size, rows, extension counts).
    Profile {
        path: String,
    },
    // Shuffle files under `path` and move them into train/val/test subdirs.
    Split {
        path: String,
        // Fraction of files assigned to the training split.
        #[arg(long, default_value = "0.8")]
        train: f64,
        // Fraction assigned to validation; test gets the remainder.
        #[arg(long, default_value = "0.1")]
        val: f64,
        // RNG seed so the shuffle (and thus the split) is reproducible.
        #[arg(long, default_value = "42")]
        seed: u64,
    },
    // Copy `source` to `to` via rsync (shows transfer progress).
    Cache {
        source: String,
        #[arg(long)]
        to: String,
    },
    // Distribute files under `path` round-robin into `shard-NNNN/` subdirs.
    Shard {
        path: String,
        #[arg(long, default_value = "8")]
        shards: u32,
    },
    // Benchmark PyTorch DataLoader throughput on a synthetic dataset.
    Benchmark {
        #[arg(default_value = ".")]
        path: String,
        // Upper worker count to try, in addition to the fixed 0/1/2/4 sweep.
        #[arg(long, default_value = "4")]
        workers: u32,
    },
    // Serve `path` over HTTP using `python3 -m http.server` (blocks).
    Serve {
        path: String,
        #[arg(long, default_value = "8888")]
        port: u16,
    },
}
65
66pub async fn run(cmd: DataCommands) -> Result<()> {
67 match cmd {
68 DataCommands::Profile { path } => {
69 let p = Path::new(&path);
70 if !p.exists() {
71 anyhow::bail!("path not found: {path}");
72 }
73
74 println!("Dataset Profile: {path}");
75 println!("{}", "=".repeat(60));
76
77 if p.is_file() {
78 let size = p.metadata()?.len();
79 let ext = p.extension().and_then(|e| e.to_str()).unwrap_or("");
80 println!(" Type: file ({ext})");
81 println!(" Size: {}", format_size(size));
82
83 match ext {
84 "parquet" => {
85 let out = Command::new("python3")
86 .args([
87 "-c",
88 &format!(
89 "import pyarrow.parquet as pq; \
90 f=pq.read_metadata('{path}'); \
91 print(f' Rows: {{f.num_rows}}'); \
92 print(f' Columns: {{f.num_columns}}'); \
93 print(f' Row Groups: {{f.num_row_groups}}'); \
94 s=pq.read_schema('{path}'); \
95 for i in range(min(10,len(s))): print(f' {{s.field(i)}}')"
96 ),
97 ])
98 .output();
99 if let Ok(o) = out {
100 print!("{}", String::from_utf8_lossy(&o.stdout));
101 }
102 }
103 "csv" | "tsv" => {
104 let lines = std::io::BufRead::lines(std::io::BufReader::new(
105 std::fs::File::open(p)?,
106 ))
107 .count();
108 println!(" Rows: {lines}");
109 }
110 "json" | "jsonl" => {
111 let lines = std::io::BufRead::lines(std::io::BufReader::new(
112 std::fs::File::open(p)?,
113 ))
114 .count();
115 println!(" Lines: {lines}");
116 }
117 _ => {
118 println!(" (use .parquet/.csv/.json for detailed stats)");
119 }
120 }
121 } else {
122 let mut total_size = 0u64;
124 let mut file_count = 0u64;
125 let mut ext_counts: std::collections::HashMap<String, u64> =
126 std::collections::HashMap::new();
127
128 fn walk(
129 dir: &Path,
130 size: &mut u64,
131 count: &mut u64,
132 exts: &mut std::collections::HashMap<String, u64>,
133 ) {
134 if let Ok(entries) = std::fs::read_dir(dir) {
135 for entry in entries.flatten() {
136 let Some(ft) = entry.file_type().ok() else {
137 continue;
138 };
139 if ft.is_file() {
140 *size += entry.metadata().map(|m| m.len()).unwrap_or(0);
141 *count += 1;
142 let ext = entry
143 .path()
144 .extension()
145 .map(|e| e.to_string_lossy().to_string())
146 .unwrap_or_else(|| "other".into());
147 *exts.entry(ext).or_default() += 1;
148 } else if ft.is_dir() {
149 walk(&entry.path(), size, count, exts);
150 }
151 }
152 }
153 }
154
155 walk(p, &mut total_size, &mut file_count, &mut ext_counts);
156
157 println!(" Type: directory");
158 println!(" Files: {file_count}");
159 println!(" Size: {}", format_size(total_size));
160 println!(" Extensions:");
161 let mut sorted: Vec<_> = ext_counts.iter().collect();
162 sorted.sort_by(|a, b| b.1.cmp(a.1));
163 for (ext, count) in sorted.iter().take(10) {
164 println!(" .{ext}: {count} files");
165 }
166 }
167 }
168
169 DataCommands::Split {
170 path,
171 train,
172 val,
173 seed,
174 } => {
175 let test = 1.0 - train - val;
176 println!("Splitting {path}: train={train} val={val} test={test} seed={seed}");
177
178 let code = format!(
179 "import os, random, shutil; random.seed({seed}); \
180 files=[f for f in os.listdir('{path}') if os.path.isfile(os.path.join('{path}',f))]; \
181 random.shuffle(files); n=len(files); \
182 nt=int(n*{train}); nv=int(n*{val}); \
183 for split,fs in [('train',files[:nt]),('val',files[nt:nt+nv]),('test',files[nt+nv:])]: \
184 d=os.path.join('{path}',split); os.makedirs(d,exist_ok=True); \
185 for f in fs: shutil.move(os.path.join('{path}',f),os.path.join(d,f)); \
186 print(f' {{split}}: {{len(fs)}} files')"
187 );
188 let output = Command::new("python3").args(["-c", &code]).output()?;
189 print!("{}", String::from_utf8_lossy(&output.stdout));
190 if !output.status.success() {
191 print!("{}", String::from_utf8_lossy(&output.stderr));
192 }
193 }
194
195 DataCommands::Cache { source, to } => {
196 println!("Caching {source} → {to}");
197 let status = Command::new("rsync")
198 .args(["-avh", "--progress", &source, &to])
199 .status()
200 .with_context(|| "rsync not found")?;
201 if status.success() {
202 println!("Cache complete.");
203 }
204 }
205
206 DataCommands::Shard { path, shards } => {
207 println!("Sharding {path} into {shards} shards...");
208 let code = format!(
209 "import os, shutil; \
210 files=sorted([f for f in os.listdir('{path}') if os.path.isfile(os.path.join('{path}',f))]); \
211 for i in range({shards}): \
212 d=os.path.join('{path}',f'shard-{{i:04d}}'); os.makedirs(d,exist_ok=True); \
213 for i,f in enumerate(files): \
214 shard=i%{shards}; \
215 shutil.move(os.path.join('{path}',f),os.path.join('{path}',f'shard-{{shard:04d}}',f)); \
216 for i in range({shards}): \
217 d=os.path.join('{path}',f'shard-{{i:04d}}'); \
218 n=len(os.listdir(d)); print(f' shard-{{i:04d}}: {{n}} files')"
219 );
220 let output = Command::new("python3").args(["-c", &code]).output()?;
221 print!("{}", String::from_utf8_lossy(&output.stdout));
222 }
223
224 DataCommands::Benchmark { path, workers } => {
225 println!("DataLoader Benchmark (path: {path}, workers: {workers})");
226 let code = format!(
227 "import torch,time; from torch.utils.data import DataLoader,TensorDataset; \
228 ds=TensorDataset(torch.randn(10000,3,224,224),torch.randint(0,1000,(10000,))); \
229 for w in [0,1,2,4,{workers}]: \
230 dl=DataLoader(ds,batch_size=64,num_workers=w,pin_memory=True); \
231 t0=time.time(); [None for _ in dl]; t1=time.time(); \
232 print(f' workers={{w}}: {{10000/(t1-t0):.0f}} samples/s')"
233 );
234 let output = Command::new("python3").args(["-c", &code]).output()?;
235 print!("{}", String::from_utf8_lossy(&output.stdout));
236 if !output.status.success() {
237 print!("{}", String::from_utf8_lossy(&output.stderr));
238 }
239 }
240
241 DataCommands::Serve { path, port } => {
242 println!("Serving dataset at {path} on port {port}...");
243 println!("URL: http://0.0.0.0:{port}");
244 let status = Command::new("python3")
245 .args(["-m", "http.server", &port.to_string(), "--directory", &path])
246 .status()?;
247 let _ = status;
248 }
249 }
250 Ok(())
251}
252
/// Format a byte count as a human-readable size string.
///
/// Uses 1024-based units: whole bytes below 1 KiB, then one decimal for
/// KB/MB and two decimals for GB (labels are the SI-style "KB"/"MB"/"GB"
/// the rest of this module prints).
fn format_size(bytes: u64) -> String {
    const K: f64 = 1024.0;
    let b = bytes as f64;
    match bytes {
        0..=1023 => format!("{bytes} B"),
        _ if b < K * K => format!("{:.1} KB", b / K),
        _ if b < K * K * K => format!("{:.1} MB", b / (K * K)),
        _ => format!("{:.2} GB", b / (K * K * K)),
    }
}