zernel/commands/
data.rs

1// Copyright (C) 2026 Dyber, Inc. — Proprietary
2
3//! zernel data — Dataset management
4
5use anyhow::{Context, Result};
6use clap::Subcommand;
7use std::path::Path;
8use std::process::Command;
9
// Subcommands of `zernel data`. Each variant maps to one match arm in `run`.
// The `///` doc comments below double as clap help text shown to the user,
// so they are user-facing strings — edit with care.
#[derive(Subcommand)]
pub enum DataCommands {
    /// Profile a dataset (stats, types, size)
    Profile {
        /// Path to dataset file or directory
        path: String,
    },
    // NOTE(review): clap performs no cross-field validation here; `run`
    // receives `train`/`val` exactly as parsed, and the test share is
    // derived as 1 - train - val downstream.
    /// Split dataset into train/val/test
    Split {
        /// Path to dataset directory
        path: String,
        /// Training fraction
        #[arg(long, default_value = "0.8")]
        train: f64,
        /// Validation fraction
        #[arg(long, default_value = "0.1")]
        val: f64,
        /// Random seed
        #[arg(long, default_value = "42")]
        seed: u64,
    },
    // Delegates to `rsync`, so `source`/`to` follow rsync path semantics
    // (trailing-slash behavior included).
    /// Cache dataset to fast storage
    Cache {
        /// Source path
        source: String,
        /// Destination (fast storage)
        #[arg(long)]
        to: String,
    },
    /// Shard dataset for distributed training
    Shard {
        /// Path to dataset
        path: String,
        /// Number of shards
        #[arg(long, default_value = "8")]
        shards: u32,
    },
    // Benchmarks use a synthetic in-memory tensor dataset; `path` is
    // currently only echoed in the report header.
    /// Benchmark DataLoader throughput
    Benchmark {
        /// Path to dataset
        #[arg(default_value = ".")]
        path: String,
        /// Number of workers
        #[arg(long, default_value = "4")]
        workers: u32,
    },
    /// Serve dataset over network for multi-node training
    Serve {
        /// Path to dataset
        path: String,
        /// Port to serve on
        #[arg(long, default_value = "8888")]
        port: u16,
    },
}
65
66pub async fn run(cmd: DataCommands) -> Result<()> {
67    match cmd {
68        DataCommands::Profile { path } => {
69            let p = Path::new(&path);
70            if !p.exists() {
71                anyhow::bail!("path not found: {path}");
72            }
73
74            println!("Dataset Profile: {path}");
75            println!("{}", "=".repeat(60));
76
77            if p.is_file() {
78                let size = p.metadata()?.len();
79                let ext = p.extension().and_then(|e| e.to_str()).unwrap_or("");
80                println!("  Type:     file ({ext})");
81                println!("  Size:     {}", format_size(size));
82
83                match ext {
84                    "parquet" => {
85                        let out = Command::new("python3")
86                            .args([
87                                "-c",
88                                &format!(
89                                    "import pyarrow.parquet as pq; \
90                                 f=pq.read_metadata('{path}'); \
91                                 print(f'  Rows:     {{f.num_rows}}'); \
92                                 print(f'  Columns:  {{f.num_columns}}'); \
93                                 print(f'  Row Groups: {{f.num_row_groups}}'); \
94                                 s=pq.read_schema('{path}'); \
95                                 for i in range(min(10,len(s))): print(f'    {{s.field(i)}}')"
96                                ),
97                            ])
98                            .output();
99                        if let Ok(o) = out {
100                            print!("{}", String::from_utf8_lossy(&o.stdout));
101                        }
102                    }
103                    "csv" | "tsv" => {
104                        let lines = std::io::BufRead::lines(std::io::BufReader::new(
105                            std::fs::File::open(p)?,
106                        ))
107                        .count();
108                        println!("  Rows:     {lines}");
109                    }
110                    "json" | "jsonl" => {
111                        let lines = std::io::BufRead::lines(std::io::BufReader::new(
112                            std::fs::File::open(p)?,
113                        ))
114                        .count();
115                        println!("  Lines:    {lines}");
116                    }
117                    _ => {
118                        println!("  (use .parquet/.csv/.json for detailed stats)");
119                    }
120                }
121            } else {
122                // Directory — count files and total size
123                let mut total_size = 0u64;
124                let mut file_count = 0u64;
125                let mut ext_counts: std::collections::HashMap<String, u64> =
126                    std::collections::HashMap::new();
127
128                fn walk(
129                    dir: &Path,
130                    size: &mut u64,
131                    count: &mut u64,
132                    exts: &mut std::collections::HashMap<String, u64>,
133                ) {
134                    if let Ok(entries) = std::fs::read_dir(dir) {
135                        for entry in entries.flatten() {
136                            let Some(ft) = entry.file_type().ok() else {
137                                continue;
138                            };
139                            if ft.is_file() {
140                                *size += entry.metadata().map(|m| m.len()).unwrap_or(0);
141                                *count += 1;
142                                let ext = entry
143                                    .path()
144                                    .extension()
145                                    .map(|e| e.to_string_lossy().to_string())
146                                    .unwrap_or_else(|| "other".into());
147                                *exts.entry(ext).or_default() += 1;
148                            } else if ft.is_dir() {
149                                walk(&entry.path(), size, count, exts);
150                            }
151                        }
152                    }
153                }
154
155                walk(p, &mut total_size, &mut file_count, &mut ext_counts);
156
157                println!("  Type:     directory");
158                println!("  Files:    {file_count}");
159                println!("  Size:     {}", format_size(total_size));
160                println!("  Extensions:");
161                let mut sorted: Vec<_> = ext_counts.iter().collect();
162                sorted.sort_by(|a, b| b.1.cmp(a.1));
163                for (ext, count) in sorted.iter().take(10) {
164                    println!("    .{ext}: {count} files");
165                }
166            }
167        }
168
169        DataCommands::Split {
170            path,
171            train,
172            val,
173            seed,
174        } => {
175            let test = 1.0 - train - val;
176            println!("Splitting {path}: train={train} val={val} test={test} seed={seed}");
177
178            let code = format!(
179                "import os, random, shutil; random.seed({seed}); \
180                 files=[f for f in os.listdir('{path}') if os.path.isfile(os.path.join('{path}',f))]; \
181                 random.shuffle(files); n=len(files); \
182                 nt=int(n*{train}); nv=int(n*{val}); \
183                 for split,fs in [('train',files[:nt]),('val',files[nt:nt+nv]),('test',files[nt+nv:])]: \
184                     d=os.path.join('{path}',split); os.makedirs(d,exist_ok=True); \
185                     for f in fs: shutil.move(os.path.join('{path}',f),os.path.join(d,f)); \
186                     print(f'  {{split}}: {{len(fs)}} files')"
187            );
188            let output = Command::new("python3").args(["-c", &code]).output()?;
189            print!("{}", String::from_utf8_lossy(&output.stdout));
190            if !output.status.success() {
191                print!("{}", String::from_utf8_lossy(&output.stderr));
192            }
193        }
194
195        DataCommands::Cache { source, to } => {
196            println!("Caching {source} → {to}");
197            let status = Command::new("rsync")
198                .args(["-avh", "--progress", &source, &to])
199                .status()
200                .with_context(|| "rsync not found")?;
201            if status.success() {
202                println!("Cache complete.");
203            }
204        }
205
206        DataCommands::Shard { path, shards } => {
207            println!("Sharding {path} into {shards} shards...");
208            let code = format!(
209                "import os, shutil; \
210                 files=sorted([f for f in os.listdir('{path}') if os.path.isfile(os.path.join('{path}',f))]); \
211                 for i in range({shards}): \
212                     d=os.path.join('{path}',f'shard-{{i:04d}}'); os.makedirs(d,exist_ok=True); \
213                 for i,f in enumerate(files): \
214                     shard=i%{shards}; \
215                     shutil.move(os.path.join('{path}',f),os.path.join('{path}',f'shard-{{shard:04d}}',f)); \
216                 for i in range({shards}): \
217                     d=os.path.join('{path}',f'shard-{{i:04d}}'); \
218                     n=len(os.listdir(d)); print(f'  shard-{{i:04d}}: {{n}} files')"
219            );
220            let output = Command::new("python3").args(["-c", &code]).output()?;
221            print!("{}", String::from_utf8_lossy(&output.stdout));
222        }
223
224        DataCommands::Benchmark { path, workers } => {
225            println!("DataLoader Benchmark (path: {path}, workers: {workers})");
226            let code = format!(
227                "import torch,time; from torch.utils.data import DataLoader,TensorDataset; \
228                 ds=TensorDataset(torch.randn(10000,3,224,224),torch.randint(0,1000,(10000,))); \
229                 for w in [0,1,2,4,{workers}]: \
230                     dl=DataLoader(ds,batch_size=64,num_workers=w,pin_memory=True); \
231                     t0=time.time(); [None for _ in dl]; t1=time.time(); \
232                     print(f'  workers={{w}}: {{10000/(t1-t0):.0f}} samples/s')"
233            );
234            let output = Command::new("python3").args(["-c", &code]).output()?;
235            print!("{}", String::from_utf8_lossy(&output.stdout));
236            if !output.status.success() {
237                print!("{}", String::from_utf8_lossy(&output.stderr));
238            }
239        }
240
241        DataCommands::Serve { path, port } => {
242            println!("Serving dataset at {path} on port {port}...");
243            println!("URL: http://0.0.0.0:{port}");
244            let status = Command::new("python3")
245                .args(["-m", "http.server", &port.to_string(), "--directory", &path])
246                .status()?;
247            let _ = status;
248        }
249    }
250    Ok(())
251}
252
/// Render a byte count as a human-readable string using base-1024 units:
/// whole bytes below 1 KB, one decimal for KB/MB, two decimals for GB.
fn format_size(bytes: u64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = KB * 1024.0;
    const GB: f64 = MB * 1024.0;

    let b = bytes as f64;
    match b {
        _ if b < KB => format!("{bytes} B"),
        _ if b < MB => format!("{:.1} KB", b / KB),
        _ if b < GB => format!("{:.1} MB", b / MB),
        _ => format!("{:.2} GB", b / GB),
    }
}