-
Notifications
You must be signed in to change notification settings - Fork 0
/
scraper.py
95 lines (81 loc) · 2.72 KB
/
scraper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import time
import urllib
import os
import math
import threading
import shutil
import random
import requests
import selenium
from selenium import webdriver
from PIL import Image
def scraper(searchterm, foldername):
    """Scrape Google Images for *searchterm* and save the hits under *foldername*.

    Reads the module-level globals ``maximum`` (cap on images saved per class)
    and ``minimum`` (threshold below which the whole class folder is discarded).
    Designed to run as a thread target; prints a summary when done.
    """
    driver = webdriver.Chrome("../chromedriver_win32/chromedriver.exe")
    driver.get(
        "https://www.google.co.in/search?q="
        + searchterm.replace(' ', '%20')
        + "&source=lnms&tbm=isch"
    )
    # Scroll repeatedly so the results page lazy-loads more thumbnails.
    # Every 10th pass, poke the "Show more results" / "See more" controls
    # that Google injects once the visible grid is exhausted.
    for step in range(500):
        driver.execute_script("window.scrollBy(0,10000)")
        if step % 10 == 0:
            try:
                # Locate once and reuse; the original looked the element up
                # twice, which is slower and race-prone on a live page.
                more_btn = driver.find_element_by_xpath("//input[@value='Show more results']")
                more_btn.click()
                time.sleep(2)
            except Exception:
                pass  # button not present yet -- keep scrolling
            try:
                see_more = driver.find_element_by_xpath("//span[contains(text(), 'See more')]")
                see_more.click()
            except Exception:
                pass  # same: absence of the link is the normal case
    time.sleep(3)
    # Harvest thumbnail URLs from the results grid (.rg_i elements); the
    # URL lives in data-src or data-iurl depending on load state.
    driver.execute_script("console.log(urls=Array.from(document.querySelectorAll('.rg_i')))")
    urls = driver.execute_script(
        "return[urls.map(el=> el.hasAttribute('data-src')?el.getAttribute('data-src'):el.getAttribute('data-iurl'))]"
    )
    # quit() shuts down the chromedriver process; close() would only close
    # the window and leak the driver process on every call.
    driver.quit()
    if not os.path.exists(foldername):
        os.mkdir(foldername)
    count = 1
    for url in urls[0]:
        if url and count <= maximum:
            try:
                # Image.open must be inside the try: a truncated or invalid
                # download raises here and previously killed the whole thread.
                img = Image.open(requests.get(str(url), stream=True).raw)
                img.save(f"{foldername}/{foldername}_{count}.jpg")
                count += 1
            except Exception:
                pass  # best-effort scrape: skip images that fail to fetch/convert
    if count - 1 >= minimum:
        print(f"Found {count - 1} images for class {foldername}...")
    else:
        print(f"Found only {count - 1} images for class {foldername}, discarding class...")
        shutil.rmtree(foldername)
# ---- interactive setup: launch one scraper thread per requested class ------
n_classes = int(input("Enter number of classes: "))
# maximum/minimum are module-level globals read by scraper().
maximum = int(input("Enter maximum number of images for each class: "))
minimum = int(input("Enter minimum number of images for each class: "))
split = float(input("Enter train-test split: "))

threads = []
for _ in range(n_classes):
    f = input("Enter class name: ")
    s = input("Enter Google Search term: ")
    threads.append(threading.Thread(target=scraper, args=(s, f)))

# All scrapers write relative paths, so run them from inside data/.
# exist_ok=True avoids the FileExistsError the original raised whenever a
# previous run had already created data/.
os.makedirs("data", exist_ok=True)
os.chdir("data")
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()
os.chdir("..")
# ---- partition each surviving class folder into train/ and test/ -----------
# os.listdir already returns a fresh list; the original wrapped it in a
# redundant identity comprehension.
found_classes = os.listdir("data")
os.mkdir("data/train")
os.mkdir("data/test")
for c in found_classes:
    os.mkdir(f"data/train/{c}")
    os.mkdir(f"data/test/{c}")
    all_images = os.listdir(f"data/{c}")
    # Shuffle so the split is random rather than filename-ordered.
    random.shuffle(all_images)
    # First `split` fraction goes to train, the remainder to test.
    split_at = math.floor(len(all_images) * split)
    for image in all_images[:split_at]:
        shutil.move(f"data/{c}/{image}", f"data/train/{c}")
    for image in all_images[split_at:]:
        shutil.move(f"data/{c}/{image}", f"data/test/{c}")
    # The per-class staging folder is now empty; remove it.
    shutil.rmtree(f"data/{c}")