fix: add authorization check and use html tokenizer instead of parser

This commit is contained in:
yinebebt
2024-11-29 15:08:40 +03:00
parent 3c1489a037
commit 5156f90c44
3 changed files with 104 additions and 93 deletions

View File

@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"golang.org/x/net/html"
"io"
"net"
@ -18,39 +19,42 @@ type linkPreview struct {
}
var client = &http.Client{
Transport: &http.Transport{},
Timeout: time.Second * 2,
Timeout: time.Second * 2,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if err := validateURL(req.URL.String()); err != nil {
return err
}
return nil
},
}
// previewLink handles the HTTP request, fetches the URL, and returns the link preview.
func previewLink(w http.ResponseWriter, r *http.Request) {
// check authorization
uid, challenge, err := authHttpRequest(r)
if err != nil {
http.Error(w, "invalid auth secret", http.StatusBadRequest)
return
}
if challenge != nil {
http.Error(w, "login challenge not done", http.StatusMultipleChoices)
return
}
if uid.IsZero() {
http.Error(w, "user not authenticated", http.StatusUnauthorized)
return
}
u := r.URL.Query().Get("url")
if u == "" {
http.Error(w, "Missing 'url' query parameter", http.StatusBadRequest)
return
}
parsedURL, err := url.Parse(u)
if err != nil {
if err := validateURL(u); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
http.Error(w, "invalid schema", http.StatusBadRequest)
return
}
ips, err := net.LookupIP(parsedURL.Hostname())
if err != nil {
http.Error(w, "invalid host", http.StatusBadRequest)
return
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() {
http.Error(w, "non routable IP address", http.StatusBadRequest)
return
}
}
req, err := http.NewRequest(http.MethodGet, u, nil)
if err != nil {
@ -70,79 +74,92 @@ func previewLink(w http.ResponseWriter, r *http.Request) {
return
}
body := io.LimitReader(resp.Body, 2*1024) // 2KB limit
doc, err := html.Parse(body)
if err != nil {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("{}"))
return
}
linkPreview := extractMetadata(doc)
body := http.MaxBytesReader(nil, resp.Body, 2*1024) // 2KB limit
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(linkPreview); err != nil {
if err := json.NewEncoder(w).Encode(extractMetadata(body)); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
}
func extractMetadata(n *html.Node) linkPreview {
func extractMetadata(body io.Reader) linkPreview {
var preview linkPreview
var gotTitle, gotDesc, gotImg, inTitleTag bool
var traverse func(*html.Node)
traverse = func(n *html.Node) {
if n.Type == html.ElementNode && strings.ToLower(n.Data) == "meta" {
var name, property, content string
for _, attr := range n.Attr {
switch attr.Key {
case "name":
name = attr.Val
case "property":
property = attr.Val
case "content":
content = attr.Val
tokenizer := html.NewTokenizer(body)
for {
switch tokenizer.Next() {
case html.ErrorToken:
return preview
case html.StartTagToken, html.SelfClosingTagToken:
token := tokenizer.Token()
data := strings.ToLower(token.Data)
if data == "meta" {
var name, property, content string
for _, attr := range token.Attr {
switch strings.ToLower(attr.Key) {
case "name":
name = attr.Val
case "property":
property = attr.Val
case "content":
content = attr.Val
}
}
}
if strings.HasPrefix(property, "og:") && content != "" {
if property == "og:title" {
preview.Title = content
} else if property == "og:description" {
if strings.HasPrefix(property, "og:") && content != "" {
switch property {
case "og:title":
preview.Title = content
gotTitle = true
case "og:description":
preview.Description = content
gotDesc = true
case "og:image":
preview.ImageURL = content
gotImg = true
}
} else if name == "description" && preview.Description == "" {
preview.Description = content
} else if property == "og:image" {
preview.ImageURL = content
gotDesc = true
}
} else if name == "description" && preview.Description == "" {
preview.Description = content
} else if data == "title" {
inTitleTag = true
}
case html.TextToken:
if !gotTitle && inTitleTag {
preview.Title = strings.TrimSpace(tokenizer.Token().Data)
gotTitle = true
inTitleTag = false
}
}
for child := n.FirstChild; child != nil; child = child.NextSibling {
traverse(child)
if gotTitle && gotDesc && gotImg {
break
}
}
traverse(n)
if preview.Title == "" {
preview.Title = extractTitle(n)
}
return preview
}
func extractTitle(n *html.Node) string {
if n.Type == html.ElementNode && n.Data == "title" {
if n.FirstChild != nil {
return n.FirstChild.Data
func validateURL(u string) error {
parsedURL, err := url.Parse(u)
if err != nil {
return err
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return &url.Error{Op: "validate", Err: errors.New("invalid scheme")}
}
ips, err := net.LookupIP(parsedURL.Hostname())
if err != nil {
return &url.Error{Op: "validate", Err: errors.New("invalid host")}
}
for _, ip := range ips {
if ip.IsLoopback() || ip.IsPrivate() {
return &url.Error{Op: "validate", Err: errors.New("non routable IP address")}
}
}
for child := n.FirstChild; child != nil; child = child.NextSibling {
title := extractTitle(child)
if title != "" {
return title
}
}
return ""
return nil
}

View File

@ -240,10 +240,6 @@ type mediaConfig struct {
Handlers map[string]json.RawMessage `json:"handlers"`
}
type LinkPreviewConfig struct {
Enabled bool `json:"enabled"`
}
// Contentx of the configuration file
type configType struct {
// HTTP(S) address:port to listen on for websocket and long polling clients. Either a
@ -296,17 +292,17 @@ type configType struct {
DefaultCountryCode string `json:"default_country_code"`
// Configs for subsystems
Cluster json.RawMessage `json:"cluster_config"`
Plugin json.RawMessage `json:"plugins"`
Store json.RawMessage `json:"store_config"`
Push json.RawMessage `json:"push"`
TLS json.RawMessage `json:"tls"`
Auth map[string]json.RawMessage `json:"auth_config"`
Validator map[string]*validatorConfig `json:"acc_validation"`
AccountGC *accountGcConfig `json:"acc_gc_config"`
Media *mediaConfig `json:"media"`
WebRTC json.RawMessage `json:"webrtc"`
LinkPreview *LinkPreviewConfig `json:"link_preview"`
Cluster json.RawMessage `json:"cluster_config"`
Plugin json.RawMessage `json:"plugins"`
Store json.RawMessage `json:"store_config"`
Push json.RawMessage `json:"push"`
TLS json.RawMessage `json:"tls"`
Auth map[string]json.RawMessage `json:"auth_config"`
Validator map[string]*validatorConfig `json:"acc_validation"`
AccountGC *accountGcConfig `json:"acc_gc_config"`
Media *mediaConfig `json:"media"`
WebRTC json.RawMessage `json:"webrtc"`
LinkPreviewEnabled bool `json:"link_preview_enabled"`
}
func main() {
@ -739,7 +735,7 @@ func main() {
mux.HandleFunc("/", serve404)
}
if config.LinkPreview != nil && config.LinkPreview.Enabled {
if config.LinkPreviewEnabled {
mux.HandleFunc(config.ApiPath+"v0/preview-link", previewLink)
}

View File

@ -679,7 +679,5 @@
"service_addr": "tcp://localhost:40051"
}
],
"link_preview": {
"enabled":true
}
"link_preview_enabled":true
}