Commit
[mod] refactor scrape function
asciimoo committed Oct 12, 2017
1 parent b014512 commit 5e31f9b
Showing 1 changed file with 46 additions and 43 deletions.
89 changes: 46 additions & 43 deletions colly.go
@@ -158,55 +158,21 @@ func (c *Collector) PostRaw(URL string, requestData []byte) error {
 func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, ctx *Context) error {
 	c.wg.Add(1)
 	defer c.wg.Done()
-	if u == "" {
-		return errors.New("Missing URL")
-	}
-	if c.MaxDepth > 0 && c.MaxDepth < depth {
-		return errors.New("Max depth limit reached")
-	}
-	if !c.AllowURLRevisit {
-		visited := false
-		for _, u2 := range c.visitedURLs {
-			if u2 == u {
-				visited = true
-				break
-			}
-		}
-		if visited {
-			return errors.New("URL already visited")
-		}
+	if err := c.requestCheck(u, depth); err != nil {
+		return err
 	}
 	parsedURL, err := url.Parse(u)
 	if err != nil {
 		return err
 	}
-	allowed := false
-	if c.AllowedDomains == nil || len(c.AllowedDomains) == 0 {
-		allowed = true
-	} else {
-		for _, d := range c.AllowedDomains {
-			if d == parsedURL.Host {
-				allowed = true
-				break
-			}
-		}
-	}
-	if !allowed {
+	if !c.isDomainAllowed(parsedURL.Host) {
 		return errors.New("Forbidden domain")
 	}
-	if !c.AllowURLRevisit {
-		c.lock.Lock()
-		c.visitedURLs = append(c.visitedURLs, u)
-		c.lock.Unlock()
-	}
 	req, err := http.NewRequest(method, u, requestData)
 	if err != nil {
 		return err
 	}
 	req.Header.Set("User-Agent", c.UserAgent)
-	if method == "POST" {
-		req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
-	}
 	if ctx == nil {
 		ctx = NewContext()
 	}
@@ -217,8 +183,11 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, ctx *Context) error {
 		Depth:     depth,
 		collector: c,
 	}
-	if len(c.requestCallbacks) > 0 {
-		c.handleOnRequest(request)
+
+	c.handleOnRequest(request)
+
+	if method == "POST" && req.Header.Get("Content-Type") == "" {
+		req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
 	}
 	response, err := c.backend.Cache(req, c.MaxBodySize, c.CacheDir)
 	// TODO add OnError callback to handle these cases
@@ -228,15 +197,46 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, ctx *Context) error {
 	response.Ctx = ctx
 	response.Request = request
 	response.fixCharset()
-	if len(c.responseCallbacks) > 0 {
-		c.handleOnResponse(response)
+
+	c.handleOnResponse(response)
+
+	c.handleOnHTML(request, response)
+
+	return nil
+}
+
+func (c *Collector) requestCheck(u string, depth int) error {
+	if u == "" {
+		return errors.New("Missing URL")
 	}
-	if strings.Index(strings.ToLower(response.Headers.Get("Content-Type")), "html") > -1 {
-		c.handleOnHTML(request, response)
+	if c.MaxDepth > 0 && c.MaxDepth < depth {
+		return errors.New("Max depth limit reached")
 	}
+	if !c.AllowURLRevisit {
+		for _, u2 := range c.visitedURLs {
+			if u2 == u {
+				return errors.New("URL already visited")
+			}
+		}
+		c.lock.Lock()
+		c.visitedURLs = append(c.visitedURLs, u)
+		c.lock.Unlock()
+	}
 	return nil
 }
 
+func (c *Collector) isDomainAllowed(domain string) bool {
+	if c.AllowedDomains == nil || len(c.AllowedDomains) == 0 {
+		return true
+	}
+	for _, d2 := range c.AllowedDomains {
+		if d2 == domain {
+			return true
+		}
+	}
+	return false
+}
+
 // Wait returns when the collector jobs are finished
 func (c *Collector) Wait() {
 	c.wg.Wait()
@@ -294,6 +294,9 @@ func (c *Collector) handleOnResponse(r *Response) {
 }
 
 func (c *Collector) handleOnHTML(req *Request, resp *Response) {
+	if strings.Index(strings.ToLower(resp.Headers.Get("Content-Type")), "html") == -1 {
+		return
+	}
 	doc, err := goquery.NewDocumentFromReader(bytes.NewBuffer(resp.Body))
 	if err != nil {
 		return
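For orientation, a minimal usage sketch of how the refactored path is exercised end to end: Visit drives scrape, requestCheck enforces the URL/depth/revisit rules, isDomainAllowed filters by AllowedDomains, and the OnRequest/OnResponse/OnHTML callbacks now run through the unconditional handleOn* calls. The import path and the exact public API shape (NewCollector, OnRequest, OnHTML, Visit) are assumptions about the library at the time of this commit, and example.com is a placeholder URL.

package main

import (
	"fmt"
	"log"

	"github.com/gocolly/colly" // import path assumed; adjust to the repository actually in use
)

func main() {
	c := colly.NewCollector()
	// Fields consulted by the new requestCheck / isDomainAllowed helpers.
	c.AllowedDomains = []string{"example.com"} // any other host is rejected with "Forbidden domain"
	c.MaxDepth = 2                             // deeper requests fail requestCheck with "Max depth limit reached"

	// Callbacks are now always dispatched by scrape via handleOnRequest / handleOnResponse / handleOnHTML.
	c.OnRequest(func(r *colly.Request) {
		fmt.Println("visiting", r.URL)
	})
	c.OnHTML("a[href]", func(e *colly.HTMLElement) {
		// handleOnHTML now checks the Content-Type itself, so this only fires for HTML responses.
		fmt.Println("link:", e.Attr("href"))
	})

	// Visit issues a GET through scrape; revisiting the same URL is refused
	// by requestCheck unless AllowURLRevisit is enabled.
	if err := c.Visit("http://example.com/"); err != nil {
		log.Println(err)
	}
	c.Wait()
}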
