diff --git a/article_scraper/src/constants.rs b/article_scraper/src/constants.rs
index 31bc213..79823f8 100644
--- a/article_scraper/src/constants.rs
+++ b/article_scraper/src/constants.rs
@@ -3,6 +3,7 @@ use std::collections::HashSet;
 use once_cell::sync::Lazy;
 use regex::{Regex, RegexBuilder};
 
+pub const UNKNOWN_CONTENT_SIZE_LIMIT: usize = 5 * 1024 * 1024;
 pub const DEFAULT_CHAR_THRESHOLD: usize = 500;
 pub static IS_IMAGE: Lazy = Lazy::new(|| {
     RegexBuilder::new(r#"\.(jpg|jpeg|png|webp)"#)
diff --git a/article_scraper/src/images/mod.rs b/article_scraper/src/images/mod.rs
index 56554c6..5a2d133 100644
--- a/article_scraper/src/images/mod.rs
+++ b/article_scraper/src/images/mod.rs
@@ -2,6 +2,7 @@ pub use self::error::ImageDownloadError;
 use self::image_data::ImageDataBase64;
 use self::pair::Pair;
 use self::request::ImageRequest;
+use crate::constants;
 use crate::util::Util;
 use base64::Engine;
 use futures::StreamExt;
@@ -36,13 +37,19 @@ impl ImageDownloader {
     ) -> Result, ImageDownloadError> {
         let response = client.get(url).send().await?;
 
-        let content_type = Util::get_content_type(&response)?;
-        let content_length = Util::get_content_length(&response).unwrap_or(0);
+        let content_type = Util::get_content_type(&response);
+        let content_length = Util::get_content_length(&response);
 
-        if !content_type.contains("image") {
+        if let (Err(_), Ok(content_length)) = (&content_type, &content_length) {
+            if *content_length > constants::UNKNOWN_CONTENT_SIZE_LIMIT {
+                return Err(ImageDownloadError::ContentType);
+            }
+        } else if !content_type?.contains("image") {
             return Err(ImageDownloadError::ContentType);
         }
 
+        let content_length = content_length.unwrap_or(0);
+
         let mut stream = response.bytes_stream();
         let mut downloaded_bytes = 0;
 