fix: Split host and port before storing IP (#2594)

The IP was always nil prior, and this fixes the test to
check for that as well!
This commit is contained in:
Kyle Carberry
2022-06-26 16:22:03 -05:00
committed by GitHub
parent 545a9f3435
commit 4851d932c4
3 changed files with 8 additions and 5 deletions

View File

@ -167,7 +167,8 @@ func ExtractAPIKey(db database.Store, oauth *OAuth2Configs) func(http.Handler) h
// Only update LastUsed once an hour to prevent database spam. // Only update LastUsed once an hour to prevent database spam.
if now.Sub(key.LastUsed) > time.Hour { if now.Sub(key.LastUsed) > time.Hour {
key.LastUsed = now key.LastUsed = now
remoteIP := net.ParseIP(r.RemoteAddr) host, _, _ := net.SplitHostPort(r.RemoteAddr)
remoteIP := net.ParseIP(host)
if remoteIP == nil { if remoteIP == nil {
remoteIP = net.IPv4(0, 0, 0, 0) remoteIP = net.IPv4(0, 0, 0, 0)
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -413,13 +414,13 @@ func TestAPIKey(t *testing.T) {
rw = httptest.NewRecorder() rw = httptest.NewRecorder()
user = createUser(r.Context(), t, db) user = createUser(r.Context(), t, db)
) )
r.RemoteAddr = "1.1.1.1" r.RemoteAddr = "1.1.1.1:3555"
r.AddCookie(&http.Cookie{ r.AddCookie(&http.Cookie{
Name: httpmw.SessionTokenKey, Name: httpmw.SessionTokenKey,
Value: fmt.Sprintf("%s-%s", id, secret), Value: fmt.Sprintf("%s-%s", id, secret),
}) })
sentAPIKey, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
ID: id, ID: id,
HashedSecret: hashed[:], HashedSecret: hashed[:],
LastUsed: database.Now().AddDate(0, 0, -1), LastUsed: database.Now().AddDate(0, 0, -1),
@ -435,7 +436,7 @@ func TestAPIKey(t *testing.T) {
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id) gotAPIKey, err := db.GetAPIKeyByID(r.Context(), id)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, sentAPIKey.IPAddress, gotAPIKey.IPAddress) require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP)
}) })
} }

View File

@ -782,7 +782,8 @@ func (api *API) createAPIKey(rw http.ResponseWriter, r *http.Request, params dat
} }
} }
ip := net.ParseIP(r.RemoteAddr) host, _, _ := net.SplitHostPort(r.RemoteAddr)
ip := net.ParseIP(host)
if ip == nil { if ip == nil {
ip = net.IPv4(0, 0, 0, 0) ip = net.IPv4(0, 0, 0, 0)
} }