diff --git a/server/linkpreview.go b/server/linkpreview.go index dced553d..cf87cfd9 100644 --- a/server/linkpreview.go +++ b/server/linkpreview.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "golang.org/x/net/html" "io" "net" @@ -18,39 +19,42 @@ type linkPreview struct { } var client = &http.Client{ - Transport: &http.Transport{}, - Timeout: time.Second * 2, + Timeout: time.Second * 2, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if err := validateURL(req.URL.String()); err != nil { + return err + } + return nil + }, } // previewLink handles the HTTP request, fetches the URL, and returns the link preview. func previewLink(w http.ResponseWriter, r *http.Request) { + // check authorization + uid, challenge, err := authHttpRequest(r) + if err != nil { + http.Error(w, "invalid auth secret", http.StatusBadRequest) + return + } + if challenge != nil { + http.Error(w, "login challenge not done", http.StatusMultipleChoices) + return + } + if uid.IsZero() { + http.Error(w, "user not authenticated", http.StatusUnauthorized) + return + } + u := r.URL.Query().Get("url") if u == "" { http.Error(w, "Missing 'url' query parameter", http.StatusBadRequest) return } - parsedURL, err := url.Parse(u) - if err != nil { + if err := validateURL(u); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - http.Error(w, "invalid schema", http.StatusBadRequest) - return - } - - ips, err := net.LookupIP(parsedURL.Hostname()) - if err != nil { - http.Error(w, "invalid host", http.StatusBadRequest) - return - } - for _, ip := range ips { - if ip.IsLoopback() || ip.IsPrivate() { - http.Error(w, "non routable IP address", http.StatusBadRequest) - return - } - } req, err := http.NewRequest(http.MethodGet, u, nil) if err != nil { @@ -70,79 +74,92 @@ func previewLink(w http.ResponseWriter, r *http.Request) { return } - body := io.LimitReader(resp.Body, 2*1024) // 2KB 
limit - doc, err := html.Parse(body) - if err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) - return - } - - linkPreview := extractMetadata(doc) + body := http.MaxBytesReader(nil, resp.Body, 2*1024) // 2KB limit w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(linkPreview); err != nil { + if err := json.NewEncoder(w).Encode(extractMetadata(body)); err != nil { http.Error(w, "Failed to encode response", http.StatusInternalServerError) } } -func extractMetadata(n *html.Node) linkPreview { +func extractMetadata(body io.Reader) linkPreview { var preview linkPreview + var gotTitle, gotDesc, gotImg, inTitleTag bool - var traverse func(*html.Node) - traverse = func(n *html.Node) { - if n.Type == html.ElementNode && strings.ToLower(n.Data) == "meta" { - var name, property, content string - for _, attr := range n.Attr { - switch attr.Key { - case "name": - name = attr.Val - case "property": - property = attr.Val - case "content": - content = attr.Val + tokenizer := html.NewTokenizer(body) + for { + switch tokenizer.Next() { + case html.ErrorToken: + return preview + + case html.StartTagToken, html.SelfClosingTagToken: + token := tokenizer.Token() + data := strings.ToLower(token.Data) + if data == "meta" { + var name, property, content string + for _, attr := range token.Attr { + switch strings.ToLower(attr.Key) { + case "name": + name = attr.Val + case "property": + property = attr.Val + case "content": + content = attr.Val + } } - } - if strings.HasPrefix(property, "og:") && content != "" { - if property == "og:title" { - preview.Title = content - } else if property == "og:description" { + if strings.HasPrefix(property, "og:") && content != "" { + switch property { + case "og:title": + preview.Title = content + gotTitle = true + case "og:description": + preview.Description = content + gotDesc = true + case "og:image": + preview.ImageURL = content + gotImg = true + 
// validateURL reports whether u is acceptable as a link preview target.
//
// It returns nil only when u parses as an http or https URL with a non-empty
// host and every IP address that host resolves to is publicly routable.
// Loopback, private, link-local (which includes the 169.254.169.254 cloud
// metadata endpoint), unspecified and multicast addresses are rejected so the
// server cannot be used to probe internal services.
//
// A parse failure is returned as reported by url.Parse; every other failure is
// a *url.Error with Op "validate".
//
// NOTE(review): the host is resolved here independently of whatever connection
// is made afterwards, so a DNS answer that changes between the two lookups is
// not caught by this check — confirm whether dial-time filtering is needed.
func validateURL(u string) error {
	fail := func(msg string) error {
		return &url.Error{Op: "validate", URL: u, Err: errors.New(msg)}
	}

	parsedURL, err := url.Parse(u)
	if err != nil {
		return err
	}
	if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
		return fail("invalid scheme")
	}

	host := parsedURL.Hostname()
	if host == "" {
		return fail("invalid host")
	}
	ips, err := net.LookupIP(host)
	if err != nil {
		return fail("invalid host")
	}
	for _, ip := range ips {
		// A single non-routable answer is enough to reject the host.
		if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() ||
			ip.IsLinkLocalUnicast() || ip.IsMulticast() {
			return fail("non routable IP address")
		}
	}

	return nil
}
websocket and long polling clients. Either a @@ -296,17 +292,17 @@ type configType struct { DefaultCountryCode string `json:"default_country_code"` // Configs for subsystems - Cluster json.RawMessage `json:"cluster_config"` - Plugin json.RawMessage `json:"plugins"` - Store json.RawMessage `json:"store_config"` - Push json.RawMessage `json:"push"` - TLS json.RawMessage `json:"tls"` - Auth map[string]json.RawMessage `json:"auth_config"` - Validator map[string]*validatorConfig `json:"acc_validation"` - AccountGC *accountGcConfig `json:"acc_gc_config"` - Media *mediaConfig `json:"media"` - WebRTC json.RawMessage `json:"webrtc"` - LinkPreview *LinkPreviewConfig `json:"link_preview"` + Cluster json.RawMessage `json:"cluster_config"` + Plugin json.RawMessage `json:"plugins"` + Store json.RawMessage `json:"store_config"` + Push json.RawMessage `json:"push"` + TLS json.RawMessage `json:"tls"` + Auth map[string]json.RawMessage `json:"auth_config"` + Validator map[string]*validatorConfig `json:"acc_validation"` + AccountGC *accountGcConfig `json:"acc_gc_config"` + Media *mediaConfig `json:"media"` + WebRTC json.RawMessage `json:"webrtc"` + LinkPreviewEnabled bool `json:"link_preview_enabled"` } func main() { @@ -739,7 +735,7 @@ func main() { mux.HandleFunc("/", serve404) } - if config.LinkPreview != nil && config.LinkPreview.Enabled { + if config.LinkPreviewEnabled { mux.HandleFunc(config.ApiPath+"v0/preview-link", previewLink) } diff --git a/server/tinode.conf b/server/tinode.conf index 36e84e8d..d799b15d 100644 --- a/server/tinode.conf +++ b/server/tinode.conf @@ -679,7 +679,5 @@ "service_addr": "tcp://localhost:40051" } ], - "link_preview": { - "enabled":true - } + "link_preview_enabled":true }