vault_audit_tools/utils/
reader.rs1use anyhow::{Context, Result};
29use flate2::read::GzDecoder;
30use std::fs::File;
31use std::io::Read;
32use std::path::Path;
33
34pub fn open_file(path: impl AsRef<Path>) -> Result<Box<dyn Read + Send>> {
60 let path = path.as_ref();
61 let file =
62 File::open(path).with_context(|| format!("Failed to open file: {}", path.display()))?;
63
64 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
65
66 match extension {
67 "gz" => {
68 let decoder = GzDecoder::new(file);
69 Ok(Box::new(decoder))
70 }
71 "zst" => {
72 let decoder = zstd::Decoder::new(file).with_context(|| {
73 format!("Failed to create zstd decoder for: {}", path.display())
74 })?;
75 Ok(Box::new(decoder))
76 }
77 _ => Ok(Box::new(file)),
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{BufRead, BufReader, Write};
    use tempfile::TempDir;

    /// Opens `path` via `open_file` and returns every line, panicking on any error.
    fn read_all_lines(path: &std::path::Path) -> Vec<String> {
        let opened = open_file(path).unwrap();
        BufReader::new(opened)
            .lines()
            .collect::<std::io::Result<Vec<String>>>()
            .unwrap()
    }

    #[test]
    fn test_plain_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("test.txt");

        // Write uncompressed text, then close the handle before reading back.
        let mut out = std::fs::File::create(&target).unwrap();
        writeln!(out, "test line 1").unwrap();
        writeln!(out, "test line 2").unwrap();
        drop(out);

        assert_eq!(read_all_lines(&target), ["test line 1", "test line 2"]);
    }

    #[test]
    fn test_gzip_file() {
        use flate2::write::GzEncoder;
        use flate2::Compression;

        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("test.gz");

        // `finish` flushes the gzip trailer; the returned file is dropped immediately.
        let mut gz = GzEncoder::new(
            std::fs::File::create(&target).unwrap(),
            Compression::default(),
        );
        writeln!(gz, "compressed line 1").unwrap();
        writeln!(gz, "compressed line 2").unwrap();
        gz.finish().unwrap();

        assert_eq!(
            read_all_lines(&target),
            ["compressed line 1", "compressed line 2"]
        );
    }

    #[test]
    fn test_zstd_file() {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("test.zst");

        // Compression level 3; `finish` writes the end of the zstd frame.
        let mut zst = zstd::Encoder::new(std::fs::File::create(&target).unwrap(), 3).unwrap();
        writeln!(zst, "zstd line 1").unwrap();
        writeln!(zst, "zstd line 2").unwrap();
        zst.finish().unwrap();

        assert_eq!(read_all_lines(&target), ["zstd line 1", "zstd line 2"]);
    }
}