mirror of
https://github.com/coder/coder.git
synced 2025-07-03 16:13:58 +00:00
chore: add support for one-way websockets to backend (#16853)
Closes https://github.com/coder/coder/issues/16775 ## Changes made - Added `OneWayWebSocket` function that establishes WebSocket connections that don't allow client-to-server communication - Added tests for the new function - Updated API endpoints to make new WS-based endpoints, and mark previous SSE-based endpoints as deprecated - Updated existing SSE handlers to use the same core logic as the new WS handlers ## Notes - Frontend changes handled via #16855
This commit is contained in:
@ -1,14 +1,18 @@
|
||||
package httpapi_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -16,6 +20,7 @@ import (
|
||||
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestInternalServerError(t *testing.T) {
|
||||
@ -155,3 +160,436 @@ func TestWebsocketCloseMsg(t *testing.T) {
|
||||
assert.Equal(t, len(trunc), 123)
|
||||
})
|
||||
}
|
||||
|
||||
// Our WebSocket library accepts any arbitrary ResponseWriter at the type level,
|
||||
// but the writer must also implement http.Hijacker for long-lived connections.
|
||||
type mockOneWaySocketWriter struct {
|
||||
serverRecorder *httptest.ResponseRecorder
|
||||
serverConn net.Conn
|
||||
clientConn net.Conn
|
||||
serverReadWriter *bufio.ReadWriter
|
||||
testContext *testing.T
|
||||
}
|
||||
|
||||
func (m mockOneWaySocketWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return m.serverConn, m.serverReadWriter, nil
|
||||
}
|
||||
|
||||
func (m mockOneWaySocketWriter) Flush() {
|
||||
err := m.serverReadWriter.Flush()
|
||||
require.NoError(m.testContext, err)
|
||||
}
|
||||
|
||||
func (m mockOneWaySocketWriter) Header() http.Header {
|
||||
return m.serverRecorder.Header()
|
||||
}
|
||||
|
||||
func (m mockOneWaySocketWriter) Write(b []byte) (int, error) {
|
||||
return m.serverReadWriter.Write(b)
|
||||
}
|
||||
|
||||
func (m mockOneWaySocketWriter) WriteHeader(code int) {
|
||||
m.serverRecorder.WriteHeader(code)
|
||||
}
|
||||
|
||||
type mockEventSenderWrite func(b []byte) (int, error)
|
||||
|
||||
func (w mockEventSenderWrite) Write(b []byte) (int, error) {
|
||||
return w(b)
|
||||
}
|
||||
|
||||
func TestOneWayWebSocketEventSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newBaseRequest := func(ctx context.Context) *http.Request {
|
||||
url := "ws://www.fake-website.com/logs"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
h := req.Header
|
||||
h.Add("Connection", "Upgrade")
|
||||
h.Add("Upgrade", "websocket")
|
||||
h.Add("Sec-WebSocket-Version", "13")
|
||||
h.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") // Just need any string
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
newOneWayWriter := func(t *testing.T) mockOneWaySocketWriter {
|
||||
mockServer, mockClient := net.Pipe()
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
var write mockEventSenderWrite = func(b []byte) (int, error) {
|
||||
serverCount, err := mockServer.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
recorderCount, err := recorder.Write(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return min(serverCount, recorderCount), nil
|
||||
}
|
||||
|
||||
return mockOneWaySocketWriter{
|
||||
testContext: t,
|
||||
serverConn: mockServer,
|
||||
clientConn: mockClient,
|
||||
serverRecorder: recorder,
|
||||
serverReadWriter: bufio.NewReadWriter(
|
||||
bufio.NewReader(mockServer),
|
||||
bufio.NewWriter(write),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Produces error if the socket connection could not be established", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
incorrectProtocols := []struct {
|
||||
major int
|
||||
minor int
|
||||
proto string
|
||||
}{
|
||||
{0, 9, "HTTP/0.9"},
|
||||
{1, 0, "HTTP/1.0"},
|
||||
}
|
||||
for _, p := range incorrectProtocols {
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := newBaseRequest(ctx)
|
||||
req.ProtoMajor = p.major
|
||||
req.ProtoMinor = p.minor
|
||||
req.Proto = p.proto
|
||||
|
||||
writer := newOneWayWriter(t)
|
||||
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.ErrorContains(t, err, p.proto)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Returned callback can publish new event to WebSocket connection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newOneWayWriter(t)
|
||||
send, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverPayload := codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: "Blah",
|
||||
}
|
||||
err = send(serverPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The client connection will receive a little bit of additional data on
|
||||
// top of the main payload. Have to make sure check has tolerance for
|
||||
// extra data being present
|
||||
serverBytes, err := json.Marshal(serverPayload)
|
||||
require.NoError(t, err)
|
||||
clientBytes, err := io.ReadAll(writer.clientConn)
|
||||
require.NoError(t, err)
|
||||
require.True(t, bytes.Contains(clientBytes, serverBytes))
|
||||
})
|
||||
|
||||
t.Run("Signals to outside consumer when socket has been closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newOneWayWriter(t)
|
||||
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
successC := make(chan bool)
|
||||
ticker := time.NewTicker(testutil.WaitShort)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
successC <- true
|
||||
case <-ticker.C:
|
||||
successC <- false
|
||||
}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
require.True(t, <-successC)
|
||||
})
|
||||
|
||||
t.Run("Socket will immediately close if client sends any message", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newOneWayWriter(t)
|
||||
_, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
successC := make(chan bool)
|
||||
ticker := time.NewTicker(testutil.WaitShort)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
successC <- true
|
||||
case <-ticker.C:
|
||||
successC <- false
|
||||
}
|
||||
}()
|
||||
|
||||
type JunkClientEvent struct {
|
||||
Value string
|
||||
}
|
||||
b, err := json.Marshal(JunkClientEvent{"Hi :)"})
|
||||
require.NoError(t, err)
|
||||
_, err = writer.clientConn.Write(b)
|
||||
require.NoError(t, err)
|
||||
require.True(t, <-successC)
|
||||
})
|
||||
|
||||
t.Run("Renders the socket inert if the request context cancels", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newOneWayWriter(t)
|
||||
send, done, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
successC := make(chan bool)
|
||||
ticker := time.NewTicker(testutil.WaitShort)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
successC <- true
|
||||
case <-ticker.C:
|
||||
successC <- false
|
||||
}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
require.True(t, <-successC)
|
||||
err = send(codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: "Didn't realize you were closed - sorry! I'll try coming back tomorrow.",
|
||||
})
|
||||
require.Equal(t, err, ctx.Err())
|
||||
_, open := <-done
|
||||
require.False(t, open)
|
||||
_, err = writer.serverConn.Write([]byte{})
|
||||
require.Equal(t, err, io.ErrClosedPipe)
|
||||
_, err = writer.clientConn.Read([]byte{})
|
||||
require.Equal(t, err, io.EOF)
|
||||
})
|
||||
|
||||
t.Run("Sends a heartbeat to the socket on a fixed internal of time to keep connections alive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Need add at least three heartbeats for something to be reliably
|
||||
// counted as an interval, but also need some wiggle room
|
||||
heartbeatCount := 3
|
||||
hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval
|
||||
timeout := hbDuration + (5 * time.Second)
|
||||
|
||||
ctx := testutil.Context(t, timeout)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newOneWayWriter(t)
|
||||
_, _, err := httpapi.OneWayWebSocketEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
type Result struct {
|
||||
Err error
|
||||
Success bool
|
||||
}
|
||||
resultC := make(chan Result)
|
||||
go func() {
|
||||
err := writer.
|
||||
clientConn.
|
||||
SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
resultC <- Result{err, false}
|
||||
return
|
||||
}
|
||||
for range heartbeatCount {
|
||||
pingBuffer := make([]byte, 1)
|
||||
pingSize, err := writer.clientConn.Read(pingBuffer)
|
||||
if err != nil || pingSize != 1 {
|
||||
resultC <- Result{err, false}
|
||||
return
|
||||
}
|
||||
}
|
||||
resultC <- Result{nil, true}
|
||||
}()
|
||||
|
||||
result := <-resultC
|
||||
require.NoError(t, result.Err)
|
||||
require.True(t, result.Success)
|
||||
})
|
||||
}
|
||||
|
||||
// ServerSentEventSender accepts any arbitrary ResponseWriter at the type level,
|
||||
// but the writer must also implement http.Flusher for long-lived connections
|
||||
type mockServerSentWriter struct {
|
||||
serverRecorder *httptest.ResponseRecorder
|
||||
serverConn net.Conn
|
||||
clientConn net.Conn
|
||||
buffer *bytes.Buffer
|
||||
testContext *testing.T
|
||||
}
|
||||
|
||||
func (m mockServerSentWriter) Flush() {
|
||||
b := m.buffer.Bytes()
|
||||
_, err := m.serverConn.Write(b)
|
||||
require.NoError(m.testContext, err)
|
||||
m.buffer.Reset()
|
||||
|
||||
// Must close server connection to indicate EOF for any reads from the
|
||||
// client connection; otherwise reads block forever. This is a testing
|
||||
// limitation compared to the one-way websockets, since we have no way to
|
||||
// frame the data and auto-indicate EOF for each message
|
||||
err = m.serverConn.Close()
|
||||
require.NoError(m.testContext, err)
|
||||
}
|
||||
|
||||
func (m mockServerSentWriter) Header() http.Header {
|
||||
return m.serverRecorder.Header()
|
||||
}
|
||||
|
||||
func (m mockServerSentWriter) Write(b []byte) (int, error) {
|
||||
return m.buffer.Write(b)
|
||||
}
|
||||
|
||||
func (m mockServerSentWriter) WriteHeader(code int) {
|
||||
m.serverRecorder.WriteHeader(code)
|
||||
}
|
||||
|
||||
func TestServerSentEventSender(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newBaseRequest := func(ctx context.Context) *http.Request {
|
||||
url := "ws://www.fake-website.com/logs"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
require.NoError(t, err)
|
||||
return req
|
||||
}
|
||||
|
||||
newServerSentWriter := func(t *testing.T) mockServerSentWriter {
|
||||
mockServer, mockClient := net.Pipe()
|
||||
return mockServerSentWriter{
|
||||
testContext: t,
|
||||
serverRecorder: httptest.NewRecorder(),
|
||||
clientConn: mockClient,
|
||||
serverConn: mockServer,
|
||||
buffer: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Mutates response headers to support SSE connections", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newServerSentWriter(t)
|
||||
_, _, err := httpapi.ServerSentEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
h := writer.Header()
|
||||
require.Equal(t, h.Get("Content-Type"), "text/event-stream")
|
||||
require.Equal(t, h.Get("Cache-Control"), "no-cache")
|
||||
require.Equal(t, h.Get("Connection"), "keep-alive")
|
||||
require.Equal(t, h.Get("X-Accel-Buffering"), "no")
|
||||
})
|
||||
|
||||
t.Run("Returned callback can publish new event to SSE connection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newServerSentWriter(t)
|
||||
send, _, err := httpapi.ServerSentEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverPayload := codersdk.ServerSentEvent{
|
||||
Type: codersdk.ServerSentEventTypeData,
|
||||
Data: "Blah",
|
||||
}
|
||||
err = send(serverPayload)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientBytes, err := io.ReadAll(writer.clientConn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(
|
||||
t,
|
||||
string(clientBytes),
|
||||
"event: data\ndata: \"Blah\"\n\n",
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("Signals to outside consumer when connection has been closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newServerSentWriter(t)
|
||||
_, done, err := httpapi.ServerSentEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
successC := make(chan bool)
|
||||
ticker := time.NewTicker(testutil.WaitShort)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
successC <- true
|
||||
case <-ticker.C:
|
||||
successC <- false
|
||||
}
|
||||
}()
|
||||
|
||||
cancel()
|
||||
require.True(t, <-successC)
|
||||
})
|
||||
|
||||
t.Run("Sends a heartbeat to the client on a fixed internal of time to keep connections alive", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Need add at least three heartbeats for something to be reliably
|
||||
// counted as an interval, but also need some wiggle room
|
||||
heartbeatCount := 3
|
||||
hbDuration := time.Duration(heartbeatCount) * httpapi.HeartbeatInterval
|
||||
timeout := hbDuration + (5 * time.Second)
|
||||
|
||||
ctx := testutil.Context(t, timeout)
|
||||
req := newBaseRequest(ctx)
|
||||
writer := newServerSentWriter(t)
|
||||
_, _, err := httpapi.ServerSentEventSender(writer, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
type Result struct {
|
||||
Err error
|
||||
Success bool
|
||||
}
|
||||
resultC := make(chan Result)
|
||||
go func() {
|
||||
err := writer.
|
||||
clientConn.
|
||||
SetReadDeadline(time.Now().Add(timeout))
|
||||
if err != nil {
|
||||
resultC <- Result{err, false}
|
||||
return
|
||||
}
|
||||
for range heartbeatCount {
|
||||
pingBuffer := make([]byte, 1)
|
||||
pingSize, err := writer.clientConn.Read(pingBuffer)
|
||||
if err != nil || pingSize != 1 {
|
||||
resultC <- Result{err, false}
|
||||
return
|
||||
}
|
||||
}
|
||||
resultC <- Result{nil, true}
|
||||
}()
|
||||
|
||||
result := <-resultC
|
||||
require.NoError(t, result.Err)
|
||||
require.True(t, result.Success)
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user