Files
coder/cli/cliutil/awscheck.go

115 lines
2.7 KiB
Go

package cliutil
import (
"context"
"encoding/json"
"io"
"net/http"
"net/netip"
"time"
"golang.org/x/xerrors"
)
// AWSIPRangesURL is the HTTPS location of the ip-ranges.json document hosted
// on ip-ranges.amazonaws.com.
// NOTE(review): presumably meant to be passed as the url argument of
// FetchAWSIPRanges; nothing in this file references it directly — confirm
// against callers.
const AWSIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
// awsIPv4Prefix is the JSON shape of a single element of the "prefixes" array
// in the IP ranges document (see awsIPRangesResponse).
//
// Prefix carries the CIDR string from the "ip_prefix" key and is the only
// field FetchAWSIPRanges reads; the other fields are decoded from their
// same-named JSON keys but are not used by the code in this file.
type awsIPv4Prefix struct {
	Prefix             string `json:"ip_prefix"`
	Region             string `json:"region"`
	Service            string `json:"service"`
	NetworkBorderGroup string `json:"network_border_group"`
}
// awsIPv6Prefix is the JSON shape of a single element of the "ipv6_prefixes"
// array in the IP ranges document (see awsIPRangesResponse).
//
// Prefix carries the CIDR string from the "ipv6_prefix" key and is the only
// field FetchAWSIPRanges reads; the other fields are decoded from their
// same-named JSON keys but are not used by the code in this file.
type awsIPv6Prefix struct {
	Prefix             string `json:"ipv6_prefix"`
	Region             string `json:"region"`
	Service            string `json:"service"`
	NetworkBorderGroup string `json:"network_border_group"`
}
// AWSIPRanges is the parsed form of the IP ranges document, with prefixes
// separated by address family. FetchAWSIPRanges builds it and CheckIP
// searches it.
type AWSIPRanges struct {
	// V4 holds the IPv4 prefixes and V6 the IPv6 prefixes.
	V4, V6 []netip.Prefix
}
// awsIPRangesResponse is the top-level JSON object that FetchAWSIPRanges
// decodes from the response body.
//
// IPV4Prefixes and IPV6Prefixes come from the "prefixes" and "ipv6_prefixes"
// arrays respectively. SyncToken and CreateDate are decoded as plain strings
// and are not otherwise used by the code in this file.
type awsIPRangesResponse struct {
	SyncToken    string          `json:"syncToken"`
	CreateDate   string          `json:"createDate"`
	IPV4Prefixes []awsIPv4Prefix `json:"prefixes"`
	IPV6Prefixes []awsIPv6Prefix `json:"ipv6_prefixes"`
}
// FetchAWSIPRanges issues a GET request to url, decodes the body as an AWS IP
// ranges document, and returns its prefixes parsed and split by address
// family.
//
// The request is bound to ctx with an additional 5 second timeout. An error is
// returned when the request cannot be constructed (for example, url is
// malformed) or sent, when the response status is not 200 OK, when the body
// does not decode as JSON, or when any prefix fails to parse or appears in the
// list for the wrong address family. On error the returned *AWSIPRanges is
// nil.
func FetchAWSIPRanges(ctx context.Context, url string) (*AWSIPRanges, error) {
	client := &http.Client{}
	reqCtx, reqCancel := context.WithTimeout(ctx, 5*time.Second)
	defer reqCancel()

	// NewRequestWithContext returns a nil request on failure (e.g. an
	// unparsable url); handing that to client.Do would panic, so the error
	// has to be reported here rather than discarded.
	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
	if err != nil {
		return nil, xerrors.Errorf("create request: %w", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: include whatever body was readable in the error text.
		b, _ := io.ReadAll(resp.Body)
		return nil, xerrors.Errorf("unexpected status code %d: %s", resp.StatusCode, b)
	}

	var body awsIPRangesResponse
	err = json.NewDecoder(resp.Body).Decode(&body)
	if err != nil {
		return nil, xerrors.Errorf("json decode: %w", err)
	}

	out := &AWSIPRanges{
		V4: make([]netip.Prefix, 0, len(body.IPV4Prefixes)),
		V6: make([]netip.Prefix, 0, len(body.IPV6Prefixes)),
	}

	for _, p := range body.IPV4Prefixes {
		prefix, err := netip.ParsePrefix(p.Prefix)
		if err != nil {
			return nil, xerrors.Errorf("parse ip prefix: %w", err)
		}
		// Reject entries filed under the wrong family so V4 stays IPv4-only.
		if prefix.Addr().Is6() {
			return nil, xerrors.Errorf("ipv4 prefix contains ipv6 address: %s", p.Prefix)
		}
		out.V4 = append(out.V4, prefix)
	}

	for _, p := range body.IPV6Prefixes {
		prefix, err := netip.ParsePrefix(p.Prefix)
		if err != nil {
			return nil, xerrors.Errorf("parse ip prefix: %w", err)
		}
		// Reject entries filed under the wrong family so V6 stays IPv6-only.
		if prefix.Addr().Is4() {
			return nil, xerrors.Errorf("ipv6 prefix contains ipv4 address: %s", p.Prefix)
		}
		out.V6 = append(out.V6, prefix)
	}

	return out, nil
}
// CheckIP reports whether ip lies inside any of the AWS prefixes held by r.
//
// Loopback, link-local (unicast or multicast) and private addresses always
// yield false. An IPv4-mapped IPv6 address (::ffff:a.b.c.d) is unwrapped to
// its IPv4 form first, because netip.Prefix.Contains never matches a mapped
// address against an IPv4 prefix; without the unwrap such an address would
// only be compared with the IPv6 prefixes and could never match.
func (r *AWSIPRanges) CheckIP(ip netip.Addr) bool {
	ip = ip.Unmap()

	if ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() || ip.IsPrivate() {
		return false
	}

	// Only prefixes of the address's own family can contain it.
	prefixes := r.V6
	if ip.Is4() {
		prefixes = r.V4
	}
	for _, p := range prefixes {
		if p.Contains(ip) {
			return true
		}
	}
	return false
}