1use crate::{Error, Result};
3use governor::{
4 clock::DefaultClock, state::direct::NotKeyed, state::InMemoryState, Quota, RateLimiter,
5};
6use std::num::NonZeroU32;
7use std::path::PathBuf;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use tokio::fs;
11use tokio::sync::Semaphore;
12
/// Base URL that chunk requests are built on; the downloader appends
/// `/<first two hash characters>/<hash>/<guid>.chunk` to it.
const EPIC_CDN_BASE: &str = "https://epicgames-download1.akamaized.net/Builds";
/// Number of permits in the downloader's semaphore, i.e. how many chunk
/// download tasks may hold a permit at the same time.
const MAX_CONCURRENT_DOWNLOADS: usize = 4;
15
16pub struct CdnDownloader {
17 client: reqwest::Client,
18 cache_dir: PathBuf,
19 semaphore: Arc<Semaphore>,
20 rate_limiter: Option<Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>,
21}
22
/// Cumulative progress snapshot handed to the progress callback of a
/// parallel chunk download.
#[derive(Clone, Debug)]
pub struct DownloadStats {
    /// Total payload bytes of all chunks completed so far.
    pub bytes_downloaded: u64,
    /// Average throughput in bytes per second, measured from the start of
    /// the parallel download (`0.0` when no time has elapsed yet).
    pub download_speed: f64,
    /// Number of chunks that have completed successfully.
    pub chunks_completed: usize,
    /// Number of chunks requested in total.
    pub chunks_total: usize,
}
30
31impl CdnDownloader {
32 pub fn new(cache_dir: PathBuf) -> Result<Self> {
33 Self::with_bandwidth_limit(cache_dir, None)
34 }
35
36 pub fn with_bandwidth_limit(
37 cache_dir: PathBuf,
38 bandwidth_limit_mbps: Option<u32>,
39 ) -> Result<Self> {
40 let client = reqwest::Client::builder()
41 .user_agent("epik/0.1.0")
42 .timeout(std::time::Duration::from_secs(30))
43 .build()
44 .map_err(|e| Error::Other(format!("HTTP client error: {}", e)))?;
45
46 let rate_limiter = bandwidth_limit_mbps.and_then(|limit_mbps| {
48 if limit_mbps > 0 {
49 let bytes_per_second = (limit_mbps as u64) * 1024 * 1024;
51 let bytes_per_100ms = bytes_per_second / 10;
53
54 NonZeroU32::new(bytes_per_100ms as u32).map(|quota_value| {
55 let quota = Quota::per_second(quota_value);
56 Arc::new(RateLimiter::direct(quota))
57 })
58 } else {
59 None
60 }
61 });
62
63 if let Some(ref limiter) = rate_limiter {
64 let _ = limiter;
65 log::info!(
66 "Bandwidth limiting enabled: {} MB/s",
67 bandwidth_limit_mbps.unwrap()
68 );
69 }
70
71 Ok(Self {
72 client,
73 cache_dir,
74 semaphore: Arc::new(Semaphore::new(MAX_CONCURRENT_DOWNLOADS)),
75 rate_limiter,
76 })
77 }
78
79 pub async fn download_chunks_parallel(
80 &self,
81 chunks: Vec<(String, String)>,
82 progress_callback: impl Fn(DownloadStats) + Send + Sync + 'static,
83 cancel_flag: Option<Arc<AtomicBool>>,
84 ) -> Result<Vec<Vec<u8>>> {
85 let total_chunks = chunks.len();
86 let mut handles = Vec::new();
87 let progress_callback = Arc::new(progress_callback);
88
89 let start_time = std::time::Instant::now();
90 let bytes_downloaded = Arc::new(tokio::sync::Mutex::new(0u64));
91 let chunks_completed = Arc::new(tokio::sync::Mutex::new(0usize));
92
93 for (guid, hash) in chunks {
94 if let Some(flag) = cancel_flag.as_ref() {
95 if flag.load(Ordering::Relaxed) {
96 return Err(Error::Other("Download cancelled".to_string()));
97 }
98 }
99
100 let permit = self.semaphore.clone().acquire_owned().await.unwrap();
101 let downloader = self.clone();
102 let progress_cb = progress_callback.clone();
103 let bytes_dl = bytes_downloaded.clone();
104 let chunks_done = chunks_completed.clone();
105 let cancel_clone = cancel_flag.clone();
106
107 let handle = tokio::spawn(async move {
108 if let Some(flag) = cancel_clone.as_ref() {
109 if flag.load(Ordering::Relaxed) {
110 drop(permit);
111 return Err(Error::Other("Download cancelled".to_string()));
112 }
113 }
114
115 let result = downloader
116 .download_chunk(&guid, &hash, cancel_clone.clone())
117 .await;
118
119 if let Ok(ref data) = result {
120 let mut bytes = bytes_dl.lock().await;
121 *bytes += data.len() as u64;
122
123 let mut completed = chunks_done.lock().await;
124 *completed += 1;
125
126 let elapsed = start_time.elapsed().as_secs_f64();
127 let speed = if elapsed > 0.0 {
128 *bytes as f64 / elapsed
129 } else {
130 0.0
131 };
132
133 progress_cb(DownloadStats {
134 bytes_downloaded: *bytes,
135 download_speed: speed,
136 chunks_completed: *completed,
137 chunks_total: total_chunks,
138 });
139 }
140
141 drop(permit);
142 result
143 });
144
145 handles.push(handle);
146 }
147
148 let mut results = Vec::new();
149 let mut handle_iter = handles.into_iter();
150
151 while let Some(handle) = handle_iter.next() {
152 let data = match handle.await {
153 Ok(result) => match result {
154 Ok(bytes) => bytes,
155 Err(e) => {
156 for h in handle_iter {
157 h.abort();
158 }
159 return Err(e);
160 }
161 },
162 Err(e) => {
163 for h in handle_iter {
164 h.abort();
165 }
166 return Err(Error::Other(format!("Join error: {}", e)));
167 }
168 };
169
170 results.push(data);
171 }
172
173 Ok(results)
174 }
175
176 async fn download_chunk(
177 &self,
178 chunk_guid: &str,
179 chunk_hash: &str,
180 cancel_flag: Option<Arc<AtomicBool>>,
181 ) -> Result<Vec<u8>> {
182 let cache_path = self.cache_dir.join(format!("{}.chunk", chunk_guid));
183 if cache_path.exists() {
184 return Ok(fs::read(&cache_path).await?);
185 }
186
187 let url = format!(
188 "{}/{}/{}/{}.chunk",
189 EPIC_CDN_BASE,
190 &chunk_hash[0..2],
191 chunk_hash,
192 chunk_guid
193 );
194
195 let data = self.download_with_retry(&url, 3, cancel_flag).await?;
196 self.verify_chunk_hash(&data, chunk_hash)?;
197
198 if let Some(parent) = cache_path.parent() {
199 fs::create_dir_all(parent).await?;
200 }
201 fs::write(&cache_path, &data).await?;
202
203 Ok(data)
204 }
205
206 async fn download_with_retry(
207 &self,
208 url: &str,
209 max_retries: usize,
210 cancel_flag: Option<Arc<AtomicBool>>,
211 ) -> Result<Vec<u8>> {
212 'retry: for attempt in 0..max_retries {
213 if let Some(flag) = cancel_flag.as_ref() {
214 if flag.load(Ordering::Relaxed) {
215 return Err(Error::Other("Download cancelled".to_string()));
216 }
217 }
218
219 if attempt > 0 {
220 tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt as u32))).await;
221 }
222
223 match self.client.get(url).send().await {
224 Ok(response) => {
225 let mut response = match response.error_for_status() {
226 Ok(resp) => resp,
227 Err(_) => continue,
228 };
229
230 let mut data = Vec::new();
231 loop {
232 let chunk = match response.chunk().await {
233 Ok(Some(bytes)) => bytes,
234 Ok(None) => break,
235 Err(_) => continue 'retry,
236 };
237
238 if let Some(flag) = cancel_flag.as_ref() {
239 if flag.load(Ordering::Relaxed) {
240 return Err(Error::Other("Download cancelled".to_string()));
241 }
242 }
243
244 if let Some(ref limiter) = self.rate_limiter {
245 if let Some(permits) = NonZeroU32::new(chunk.len() as u32) {
246 let _ = limiter.until_n_ready(permits).await;
247 }
248 }
249
250 data.extend_from_slice(&chunk);
251 }
252
253 return Ok(data);
254 }
255 Err(_) => continue,
256 }
257 }
258
259 Err(Error::Other("Download failed".to_string()))
260 }
261
262 fn verify_chunk_hash(&self, data: &[u8], expected_hash: &str) -> Result<()> {
263 use sha2::{Digest, Sha256};
264
265 let hash_hex = hex::encode(Sha256::digest(data));
266 if hash_hex == expected_hash {
267 Ok(())
268 } else {
269 Err(Error::Other("Hash mismatch".to_string()))
270 }
271 }
272}
273
274impl Clone for CdnDownloader {
275 fn clone(&self) -> Self {
276 Self {
277 client: self.client.clone(),
278 cache_dir: self.cache_dir.clone(),
279 semaphore: self.semaphore.clone(),
280 rate_limiter: self.rate_limiter.clone(),
281 }
282 }
283}