Compare commits

...

5 Commits

Author SHA1 Message Date
0cb2685161 Update gateway.go 2025-01-18 04:19:04 +01:00
de13869fd5 Update gateway.go 2025-01-18 04:18:23 +01:00
8520eba449 feat(infisical-gateway): CLI support 2025-01-18 03:37:02 +01:00
efb8e5d070 Update main.go 2025-01-17 22:08:19 +01:00
a22743e7e2 feat: infisical-gateway 2025-01-17 22:08:19 +01:00
10 changed files with 825 additions and 0 deletions

View File

@ -9,6 +9,7 @@ require (
github.com/denisbrodbeck/machineid v1.0.1
github.com/fatih/semgroup v1.2.0
github.com/gitleaks/go-gitdiff v0.8.0
github.com/google/uuid v1.6.0
github.com/h2non/filetype v1.1.3
github.com/infisical/go-sdk v0.4.7
github.com/mattn/go-isatty v0.0.20

View File

@ -544,3 +544,81 @@ func CallUpdateRawSecretsV3(httpClient *resty.Client, request UpdateRawSecretByN
return nil
}
func CallListGatewaysV1(httpClient *resty.Client) (ListGatewaysV1Response, error) {
var listGatewaysResponse ListGatewaysV1Response
response, err := httpClient.
R().
SetResult(&listGatewaysResponse).
SetHeader("User-Agent", USER_AGENT).
Get(fmt.Sprintf("%v/v1/gateways", config.INFISICAL_URL))
if err != nil {
return ListGatewaysV1Response{}, fmt.Errorf("CallListGatewaysV1: Unable to complete api request [err=%w]", err)
}
if response.IsError() {
return ListGatewaysV1Response{}, fmt.Errorf("CallListGatewaysV1: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
}
return listGatewaysResponse, nil
}
func CallGetGatewayV1(httpClient *resty.Client, request GetGatewayV1Request) (GetGatewayV1Response, error) {
var getGatewayResponse GetGatewayV1Response
response, err := httpClient.
R().
SetResult(&getGatewayResponse).
SetHeader("User-Agent", USER_AGENT).
Get(fmt.Sprintf("%v/v1/gateways/%v", config.INFISICAL_URL, request.ID))
if err != nil {
return GetGatewayV1Response{}, fmt.Errorf("CallGetGatewayV1: Unable to complete api request [err=%w]", err)
}
if response.IsError() {
return GetGatewayV1Response{}, fmt.Errorf("CallGetGatewayV1: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
}
return getGatewayResponse, nil
}
func CallCreateGatewayV1(httpClient *resty.Client, request CreateGatewayV1Request) (CreateGatewayV1Response, error) {
var createGatewayResponse CreateGatewayV1Response
response, err := httpClient.
R().
SetResult(&createGatewayResponse).
SetHeader("User-Agent", USER_AGENT).
SetBody(request).
Post(fmt.Sprintf("%v/v1/gateways", config.INFISICAL_URL))
if err != nil {
return CreateGatewayV1Response{}, fmt.Errorf("CallCreateGatewayV1: Unable to complete api request [err=%w]", err)
}
if response.IsError() {
return CreateGatewayV1Response{}, fmt.Errorf("CallCreateGatewayV1: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
}
return createGatewayResponse, nil
}
func CallUpdateGatewayV1(httpClient *resty.Client, request UpdateGatewayV1Request) (UpdateGatewayV1Response, error) {
var updateGatewayResponse UpdateGatewayV1Response
response, err := httpClient.
R().
SetResult(&updateGatewayResponse).
SetHeader("User-Agent", USER_AGENT).
SetBody(request).
Patch(fmt.Sprintf("%v/v1/gateways/%v", config.INFISICAL_URL, request.ID))
if err != nil {
return UpdateGatewayV1Response{}, fmt.Errorf("CallUpdateGatewayV1: Unable to complete api request [err=%w]", err)
}
if response.IsError() {
return UpdateGatewayV1Response{}, fmt.Errorf("CallUpdateGatewayV1: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
}
return updateGatewayResponse, nil
}

View File

@ -629,3 +629,42 @@ type GetRawSecretV3ByNameResponse struct {
} `json:"secret"`
ETag string
}
type Gateway struct {
Name string `json:"name"`
ID string `json:"id"`
Hostname string `json:"host"`
LastPingAt time.Time `json:"lastPingAt"`
OrgId string `json:"orgId"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type ListGatewaysV1Response struct {
Gateways []Gateway `json:"gateways"`
}
type GetGatewayV1Response struct {
Gateway Gateway `json:"gateway"`
}
type GetGatewayV1Request struct {
ID string
}
type CreateGatewayV1Response struct {
Gateway Gateway `json:"gateway"`
}
type CreateGatewayV1Request struct {
Name string `json:"name"`
}
type UpdateGatewayV1Request struct {
ID string
Name string `json:"name"`
}
type UpdateGatewayV1Response struct {
Gateway Gateway `json:"gateway"`
}

460
cli/packages/cmd/gateway.go Normal file
View File

@ -0,0 +1,460 @@
/*
Copyright (c) 2023 Infisical Inc.
*/
package cmd
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/Infisical/infisical-merge/packages/api"
"github.com/Infisical/infisical-merge/packages/util"
"github.com/go-resty/resty/v2"
"github.com/google/uuid"
infisicalSdk "github.com/infisical/go-sdk"
"github.com/spf13/cobra"
)
func getRealIP(r *http.Request) string {
// Order of headers to check for real IP
headersOrder := []string{
"cf-connecting-ip", // Cloudflare
"Cf-Pseudo-IPv4", // Cloudflare
"x-client-ip", // Most common
"x-envoy-external-address", // for envoy
"x-forwarded-for", // Mostly used by proxies
"fastly-client-ip",
"true-client-ip", // Akamai and Cloudflare
"x-real-ip", // Nginx
"x-cluser-client-ip", // Rackspace LB
"forwarded-for",
"x-forwarded",
"forwarded",
"x-appengine-user-ip", // GCP App Engine
}
// Check each header in order
for _, header := range headersOrder {
if ip := r.Header.Get(header); ip != "" {
// If IP contains comma, take the first IP (client IP)
if strings.Contains(ip, ",") {
return strings.TrimSpace(strings.Split(ip, ",")[0])
}
return ip
}
}
// If no headers found, get IP from RemoteAddr
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
// If RemoteAddr doesn't have a port, return as is, split host and port
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}
return ip
}
type TunnelRequest struct {
Protocol string `json:"protocol"`
Target string `json:"target"`
}
type TunnelResponse struct {
TunnelID string `json:"tunnelId"`
TunnelPort int `json:"tunnelPort"`
}
type Tunnel struct {
allowedIp string
ID string
Protocol string
Target string
LocalPort int
Created time.Time
listener net.Listener
}
type TunnelManager struct {
infisicalClient infisicalSdk.InfisicalClientInterface
tunnels map[string]*Tunnel
portRange portRange
mu sync.RWMutex
logger *log.Logger
}
type portRange struct {
start int
end int
}
func NewTunnelManager(startPort, endPort int, infisicalClient infisicalSdk.InfisicalClientInterface) *TunnelManager {
return &TunnelManager{
infisicalClient: infisicalClient,
tunnels: make(map[string]*Tunnel),
portRange: portRange{
start: startPort,
end: endPort,
},
logger: log.New(log.Writer(), "[TUNNEL] ", log.LstdFlags),
}
}
func (tm *TunnelManager) findAvailablePort() (int, error) {
for port := tm.portRange.start; port <= tm.portRange.end; port++ {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err == nil {
listener.Close()
return port, nil
}
}
return 0, fmt.Errorf("no available ports in range %d-%d", tm.portRange.start, tm.portRange.end)
}
func sanitizeHost(host string) string {
// If host contains @, take everything after the last @
if idx := strings.LastIndex(host, "@"); idx != -1 {
return host[idx+1:]
}
return host
}
func (tm *TunnelManager) createTunnel(req TunnelRequest, creatorIpAddress string) (*TunnelResponse, error) {
tm.mu.Lock()
defer tm.mu.Unlock()
port, err := tm.findAvailablePort()
if err != nil {
return nil, err
}
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, fmt.Errorf("failed to create listener: %v", err)
}
tunnelID := uuid.New().String()
tunnel := &Tunnel{
allowedIp: creatorIpAddress,
ID: tunnelID,
Protocol: req.Protocol,
Target: req.Target,
LocalPort: port,
Created: time.Now(),
listener: listener,
}
tm.tunnels[tunnelID] = tunnel
go tm.handleTunnelConnections(tunnel)
return &TunnelResponse{
TunnelID: tunnelID,
TunnelPort: port,
}, nil
}
func (tm *TunnelManager) handleTunnelConnections(tunnel *Tunnel) {
defer func() {
tm.mu.Lock()
delete(tm.tunnels, tunnel.ID)
tm.mu.Unlock()
tunnel.listener.Close()
}()
for {
clientConn, err := tunnel.listener.Accept()
if err != nil {
tm.logger.Printf("Error accepting connection: %v", err)
return
}
go tm.handleConnection(tunnel, clientConn)
}
}
func (tm *TunnelManager) handleConnection(tunnel *Tunnel, clientConn net.Conn) {
defer clientConn.Close()
clientIP, _, err := net.SplitHostPort(clientConn.RemoteAddr().String())
if err != nil {
tm.logger.Printf("Failed to get client IP: %v", err)
return
}
if clientIP != tunnel.allowedIp {
tm.logger.Printf("Unauthorized connection from %s", clientIP)
return
}
targetHost := sanitizeHost(tunnel.Target)
targetConn, err := net.Dial("tcp", targetHost)
if err != nil {
tm.logger.Printf("Failed to connect to target %s: %v", targetHost, err)
return
}
defer targetConn.Close()
tm.logger.Printf("New connection on tunnel %s: %s -> %s",
tunnel.ID, clientConn.RemoteAddr(), targetHost)
// Bidirectional copy, target -> client and client -> target
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(targetConn, clientConn)
errCh <- err
}()
go func() {
_, err := io.Copy(clientConn, targetConn)
errCh <- err
}()
// Wait for either end to close
err = <-errCh
if err != nil && err != io.EOF {
tm.logger.Printf("Connection error: %v", err)
}
}
func (tm *TunnelManager) handleTunnelCreate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req TunnelRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Target == "" {
http.Error(w, "Missing target host or port", http.StatusBadRequest)
return
}
resp, err := tm.createTunnel(req, getRealIP(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (tm *TunnelManager) handleTunnelList(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
tm.mu.RLock()
tunnels := make([]Tunnel, 0, len(tm.tunnels))
for _, t := range tm.tunnels {
tunnels = append(tunnels, *t)
}
tm.mu.RUnlock()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(tunnels)
}
var gatewayCmd = &cobra.Command{
Use: "gateway",
Short: "Used to manage your Infisical Gateway",
DisableFlagsInUseLine: true,
Example: "infisical gateway",
Args: cobra.NoArgs,
}
func startGateway(cmd *cobra.Command, args []string) {
infisicalConfig, err := util.GetConfigFile()
if err != nil {
util.HandleError(fmt.Errorf("startGateway: unable to get config file because [err=%s]", err))
}
loginMethod, err := cmd.Flags().GetString("method")
if err != nil {
util.HandleError(err)
}
gatewayName, err := cmd.Flags().GetString("name")
if err != nil {
util.HandleError(err)
}
if gatewayName == "" {
util.PrintErrorMessageAndExit("Gateway name is required to start the gateway. Use the --name flag to specify the gateway name.")
}
domain, err := cmd.Flags().GetString("domain")
if err != nil {
util.HandleError(err)
}
authMethodValid, strategy := util.IsAuthMethodValid(loginMethod, false)
if !authMethodValid {
util.PrintErrorMessageAndExit(fmt.Sprintf("Invalid login method: %s", loginMethod))
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel() // Cancel the context when the client is no longer needed
infisicalClient := infisicalSdk.NewInfisicalClient(ctx, infisicalSdk.Config{
SiteUrl: domain,
})
authStrategies := map[util.AuthStrategyType]func(cmd *cobra.Command, infisicalClient infisicalSdk.InfisicalClientInterface) (credential infisicalSdk.MachineIdentityCredential, e error){
util.AuthStrategy.UNIVERSAL_AUTH: handleUniversalAuthLogin,
util.AuthStrategy.KUBERNETES_AUTH: handleKubernetesAuthLogin,
util.AuthStrategy.AZURE_AUTH: handleAzureAuthLogin,
util.AuthStrategy.GCP_ID_TOKEN_AUTH: handleGcpIdTokenAuthLogin,
util.AuthStrategy.GCP_IAM_AUTH: handleGcpIamAuthLogin,
util.AuthStrategy.AWS_IAM_AUTH: handleAwsIamAuthLogin,
util.AuthStrategy.OIDC_AUTH: handleOidcAuthLogin,
}
_, err = authStrategies[strategy](cmd, infisicalClient)
if err != nil {
util.HandleError(err)
}
tm := NewTunnelManager(10000, 20000, infisicalClient)
setupComplete := false
var wg sync.WaitGroup
authMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !setupComplete {
http.Error(w, "Unauthorized: Gateway setup not complete", http.StatusUnauthorized)
return
}
token := r.Header.Get("Authorization")
if token == "" {
http.Error(w, "Unauthorized: No token provided", http.StatusUnauthorized)
return
}
token = strings.TrimPrefix(token, "Bearer ")
httpClient := resty.New()
httpClient.SetAuthScheme("Bearer")
httpClient.SetAuthToken(token)
_, err := api.CallListGatewaysV1(httpClient)
if err != nil {
http.Error(w, "Unauthorized: Invalid token", http.StatusUnauthorized)
return
}
next(w, r)
}
}
http.HandleFunc("/tunnel", authMiddleware(tm.handleTunnelCreate))
http.HandleFunc("/tunnels", authMiddleware(tm.handleTunnelList))
http.Handle("/health", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
wg.Add(1)
go func() {
defer wg.Done()
serverPort := 8022
tm.logger.Printf("Starting gateway server on port %d", serverPort)
if err := http.ListenAndServe(fmt.Sprintf(":%d", serverPort), nil); err != nil {
tm.logger.Fatalf("Failed to start server: %v", err)
os.Exit(1)
}
}()
accessToken := infisicalClient.Auth().GetAccessToken()
httpClient := resty.New().SetAuthToken(accessToken)
if infisicalConfig.Gateway.ID == "" {
createdGateway, err := api.CallCreateGatewayV1(httpClient, api.CreateGatewayV1Request{
Name: gatewayName,
})
if err != nil {
util.HandleError(err)
}
infisicalConfig.Gateway.ID = createdGateway.Gateway.ID
err = util.WriteConfigFile(&infisicalConfig)
if err != nil {
util.HandleError(err)
}
} else {
res, err := api.CallGetGatewayV1(httpClient, api.GetGatewayV1Request{
ID: infisicalConfig.Gateway.ID,
})
if err != nil {
util.HandleError(err)
}
if res.Gateway.Name != gatewayName {
tm.logger.Printf("Gateway name has been changed from %s to %s\nUpdating..\n\n", res.Gateway.Name, gatewayName)
_, err := api.CallUpdateGatewayV1(httpClient, api.UpdateGatewayV1Request{
ID: infisicalConfig.Gateway.ID,
Name: gatewayName,
})
if err != nil {
util.HandleError(err)
}
}
}
tm.logger.Printf("Gateway started successfully on port %d", 8022)
setupComplete = true
wg.Wait()
}
var gatewayStartCmd = &cobra.Command{
Example: `gateway start`,
Short: "Starts the Infisical Gateway",
Use: "start",
DisableFlagsInUseLine: true,
Args: cobra.NoArgs,
Run: startGateway,
}
func init() {
rootCmd.AddCommand(gatewayCmd)
gatewayCmd.AddCommand(gatewayStartCmd)
gatewayStartCmd.Flags().String("name", "", "name of the gateway")
gatewayStartCmd.Flags().String("method", "", "login method")
gatewayStartCmd.Flags().String("client-id", "", "client id for universal auth")
gatewayStartCmd.Flags().String("client-secret", "", "client secret for universal auth")
gatewayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods")
gatewayStartCmd.Flags().String("service-account-token-path", "", "service account token path for kubernetes auth")
gatewayStartCmd.Flags().String("service-account-key-file-path", "", "service account key file path for GCP IAM auth")
gatewayStartCmd.Flags().String("oidc-jwt", "", "JWT for OIDC authentication")
}

View File

@ -17,6 +17,9 @@ type ConfigFile struct {
VaultBackendType string `json:"vaultBackendType,omitempty"`
VaultBackendPassphrase string `json:"vaultBackendPassphrase,omitempty"`
Domains []string `json:"domains,omitempty"`
Gateway struct {
ID string `json:"id"`
} `json:"gateway"`
}
type LoggedInUser struct {

View File

@ -198,6 +198,7 @@ func GetWorkspaceConfigByPath(path string) (workspaceConfig models.WorkspaceConf
// Get the infisical config file and if it doesn't exist, return empty config model, otherwise raise error
func GetConfigFile() (models.ConfigFile, error) {
fullConfigFilePath, _, err := GetFullConfigFilePath()
if err != nil {
return models.ConfigFile{}, err
}

View File

@ -6,6 +6,7 @@ const (
INFISICAL_DEFAULT_US_URL = "https://app.infisical.com"
INFISICAL_DEFAULT_EU_URL = "https://eu.infisical.com"
INFISICAL_WORKSPACE_CONFIG_FILE_NAME = ".infisical.json"
INFISICAL_GATEWAY_CONFIG_FILE_NAME = ".infisical-gateway.json"
INFISICAL_TOKEN_NAME = "INFISICAL_TOKEN"
INFISICAL_UNIVERSAL_AUTH_ACCESS_TOKEN_NAME = "INFISICAL_UNIVERSAL_AUTH_ACCESS_TOKEN"
INFISICAL_VAULT_FILE_PASSPHRASE_ENV_NAME = "INFISICAL_VAULT_FILE_PASSPHRASE" // This works because we've forked the keyring package and added support for this env variable. This explains why you won't find any occurrences of it in the CLI codebase.

5
gateway/go.mod Normal file
View File

@ -0,0 +1,5 @@
module github.com/infisical/gateway
go 1.23.3
require github.com/google/uuid v1.6.0

2
gateway/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=

235
gateway/main.go Normal file
View File

@ -0,0 +1,235 @@
package main
import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
// TunnelRequest represents the handshake request
type TunnelRequest struct {
GatewayToken string `json:"gatewayToken"`
Protocol string `json:"protocol"`
TargetHost string `json:"targetHost"`
TargetPort int `json:"targetPort"`
}
// TunnelResponse is returned after successful handshake
type TunnelResponse struct {
TunnelID string `json:"tunnelId"`
TunnelPort int `json:"tunnelPort"`
}
// Tunnel represents an active tunnel
type Tunnel struct {
ID string
Protocol string
TargetHost string
TargetPort int
LocalPort int
Created time.Time
listener net.Listener
}
// TunnelManager handles tunnel lifecycle
type TunnelManager struct {
tunnels map[string]*Tunnel
portRange portRange
mu sync.RWMutex
logger *log.Logger
}
type portRange struct {
start int
end int
}
func NewTunnelManager(startPort, endPort int) *TunnelManager {
return &TunnelManager{
tunnels: make(map[string]*Tunnel),
portRange: portRange{
start: startPort,
end: endPort,
},
logger: log.New(log.Writer(), "[TUNNEL] ", log.LstdFlags),
}
}
func (tm *TunnelManager) findAvailablePort() (int, error) {
for port := tm.portRange.start; port <= tm.portRange.end; port++ {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err == nil {
listener.Close()
return port, nil
}
}
return 0, fmt.Errorf("no available ports in range %d-%d", tm.portRange.start, tm.portRange.end)
}
func sanitizeHost(host string) string {
// If host contains @, take everything after the last @
if idx := strings.LastIndex(host, "@"); idx != -1 {
return host[idx+1:]
}
return host
}
func (tm *TunnelManager) createTunnel(req TunnelRequest) (*TunnelResponse, error) {
tm.mu.Lock()
defer tm.mu.Unlock()
// Find available port
port, err := tm.findAvailablePort()
if err != nil {
return nil, err
}
// Create listener for tunnel
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, fmt.Errorf("failed to create listener: %v", err)
}
tunnelID := uuid.New().String()
tunnel := &Tunnel{
ID: tunnelID,
Protocol: req.Protocol,
TargetHost: req.TargetHost,
TargetPort: req.TargetPort,
LocalPort: port,
Created: time.Now(),
listener: listener,
}
tm.tunnels[tunnelID] = tunnel
// Start handling connections
go tm.handleTunnelConnections(tunnel)
return &TunnelResponse{
TunnelID: tunnelID,
TunnelPort: port,
}, nil
}
func (tm *TunnelManager) handleTunnelConnections(tunnel *Tunnel) {
defer func() {
tm.mu.Lock()
delete(tm.tunnels, tunnel.ID)
tm.mu.Unlock()
tunnel.listener.Close()
}()
for {
clientConn, err := tunnel.listener.Accept()
if err != nil {
tm.logger.Printf("Error accepting connection: %v", err)
return
}
go tm.handleConnection(tunnel, clientConn)
}
}
func (tm *TunnelManager) handleConnection(tunnel *Tunnel, clientConn net.Conn) {
defer clientConn.Close()
// Sanitize the host before dialing
targetHost := sanitizeHost(tunnel.TargetHost)
targetConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", targetHost, tunnel.TargetPort))
if err != nil {
tm.logger.Printf("Failed to connect to target %s:%d: %v", targetHost, tunnel.TargetPort, err)
return
}
defer targetConn.Close()
tm.logger.Printf("New connection on tunnel %s: %s -> %s:%d",
tunnel.ID, clientConn.RemoteAddr(), targetHost, tunnel.TargetPort)
// Bidirectional copy
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(targetConn, clientConn)
errCh <- err
}()
go func() {
_, err := io.Copy(clientConn, targetConn)
errCh <- err
}()
// Wait for either end to close
err = <-errCh
if err != nil && err != io.EOF {
tm.logger.Printf("Connection error: %v", err)
}
}
// HTTP handlers
func (tm *TunnelManager) handleTunnelCreate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req TunnelRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
// Validate request
if req.TargetHost == "" || req.TargetPort == 0 {
http.Error(w, "Missing target host or port", http.StatusBadRequest)
return
}
resp, err := tm.createTunnel(req)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func (tm *TunnelManager) handleTunnelList(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
tm.mu.RLock()
tunnels := make([]Tunnel, 0, len(tm.tunnels))
for _, t := range tm.tunnels {
tunnels = append(tunnels, *t)
}
tm.mu.RUnlock()
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(tunnels)
}
func main() {
tm := NewTunnelManager(10000, 20000) // Port range for tunnels
// HTTP endpoints
http.HandleFunc("/tunnel", tm.handleTunnelCreate)
http.HandleFunc("/tunnels", tm.handleTunnelList)
// Start HTTP server
serverPort := 8022
tm.logger.Printf("Starting gateway server on port %d", serverPort)
if err := http.ListenAndServe(fmt.Sprintf(":%d", serverPort), nil); err != nil {
tm.logger.Fatalf("Failed to start server: %v", err)
}
}