mirror of
https://github.com/coder/coder.git
synced 2025-03-14 10:09:57 +00:00
353 lines
9.7 KiB
Go
353 lines
9.7 KiB
Go
package provisionersdk
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"hash/crc32"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/spf13/afero"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog"
|
|
|
|
"github.com/coder/coder/v2/provisionersdk/proto"
|
|
)
|
|
|
|
const (
	// ReadmeFile is the location we look for to extract documentation from template versions.
	ReadmeFile = "README.md"

	// sessionDirPrefix is prepended to the session ID to form the
	// per-session work directory name (see SessionDir).
	sessionDirPrefix = "Session"
	// staleSessionRetention is how long an abandoned session work directory
	// may remain on disk before CleanStaleSessions removes it.
	staleSessionRetention = 7 * 24 * time.Hour
)
|
|
|
|
// protoServer is a wrapper that translates the dRPC protocol into a Session with method calls into the Server.
type protoServer struct {
	// server is the provisioner implementation that fulfills Parse/Plan/Apply.
	server Server
	// opts carries the logger and the root work directory under which
	// per-session directories are created.
	opts ServeOptions
}
|
|
|
|
func (p *protoServer) Session(stream proto.DRPCProvisioner_SessionStream) error {
|
|
sessID := uuid.New().String()
|
|
s := &Session{
|
|
Logger: p.opts.Logger.With(slog.F("session_id", sessID)),
|
|
stream: stream,
|
|
server: p.server,
|
|
}
|
|
|
|
err := CleanStaleSessions(s.Context(), p.opts.WorkDirectory, afero.NewOsFs(), time.Now(), s.Logger)
|
|
if err != nil {
|
|
return xerrors.Errorf("unable to clean stale sessions %q: %w", s.WorkDirectory, err)
|
|
}
|
|
|
|
s.WorkDirectory = filepath.Join(p.opts.WorkDirectory, SessionDir(sessID))
|
|
err = os.MkdirAll(s.WorkDirectory, 0o700)
|
|
if err != nil {
|
|
return xerrors.Errorf("create work directory %q: %w", s.WorkDirectory, err)
|
|
}
|
|
defer func() {
|
|
var err error
|
|
// Cleanup the work directory after execution.
|
|
for attempt := 0; attempt < 5; attempt++ {
|
|
err = os.RemoveAll(s.WorkDirectory)
|
|
if err != nil {
|
|
// On Windows, open files cannot be removed.
|
|
// When the provisioner daemon is shutting down,
|
|
// it may take a few milliseconds for processes to exit.
|
|
// See: https://github.com/golang/go/issues/50510
|
|
s.Logger.Debug(s.Context(), "failed to clean work directory; trying again", slog.Error(err))
|
|
time.Sleep(250 * time.Millisecond)
|
|
continue
|
|
}
|
|
s.Logger.Debug(s.Context(), "cleaned up work directory")
|
|
return
|
|
}
|
|
s.Logger.Error(s.Context(), "failed to clean up work directory after multiple attempts",
|
|
slog.F("path", s.WorkDirectory), slog.Error(err))
|
|
}()
|
|
req, err := stream.Recv()
|
|
if err != nil {
|
|
return xerrors.Errorf("receive config: %w", err)
|
|
}
|
|
config := req.GetConfig()
|
|
if config == nil {
|
|
return xerrors.New("first request must be Config")
|
|
}
|
|
s.Config = config
|
|
if s.Config.ProvisionerLogLevel != "" {
|
|
s.logLevel = proto.LogLevel_value[strings.ToUpper(s.Config.ProvisionerLogLevel)]
|
|
}
|
|
|
|
err = s.extractArchive()
|
|
if err != nil {
|
|
return xerrors.Errorf("extract archive: %w", err)
|
|
}
|
|
return s.handleRequests()
|
|
}
|
|
|
|
func (s *Session) requestReader(done <-chan struct{}) <-chan *proto.Request {
|
|
ch := make(chan *proto.Request)
|
|
go func() {
|
|
defer close(ch)
|
|
for {
|
|
req, err := s.stream.Recv()
|
|
if err != nil {
|
|
s.Logger.Info(s.Context(), "recv done on Session", slog.Error(err))
|
|
return
|
|
}
|
|
select {
|
|
case ch <- req:
|
|
continue
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
// handleRequests is the session's dispatch loop: it reads requests from the
// stream and routes Parse, Plan, and Apply to the underlying Server, sending
// one Response per request. It returns nil when the request channel closes
// (stream ended) or the first unrecoverable error otherwise.
func (s *Session) handleRequests() error {
	// Closing done stops the requestReader goroutine when we return.
	done := make(chan struct{})
	defer close(done)
	requests := s.requestReader(done)
	// planned tracks whether a Plan completed without error; Apply is only
	// permitted afterwards.
	planned := false
	for req := range requests {
		if req.GetCancel() != nil {
			// Cancels are only meaningful while a Parse/Plan/Apply is in
			// flight; request.do consumes them from this same channel then.
			s.Logger.Warn(s.Context(), "ignoring cancel before request or after complete")
			continue
		}
		resp := &proto.Response{}
		if parse := req.GetParse(); parse != nil {
			r := &request[*proto.ParseRequest, *proto.ParseComplete]{
				req:      parse,
				session:  s,
				serverFn: s.server.Parse,
				cancels:  requests,
			}
			complete, err := r.do()
			if err != nil {
				return err
			}
			// Handle README centrally, so that individual provisioners don't need to mess with it.
			readme, err := os.ReadFile(filepath.Join(s.WorkDirectory, ReadmeFile))
			if err == nil {
				complete.Readme = readme
			} else {
				// A missing README is expected; log at debug only.
				s.Logger.Debug(s.Context(), "failed to parse readme (missing ok)", slog.Error(err))
			}
			resp.Type = &proto.Response_Parse{Parse: complete}
		}
		if plan := req.GetPlan(); plan != nil {
			r := &request[*proto.PlanRequest, *proto.PlanComplete]{
				req:      plan,
				session:  s,
				serverFn: s.server.Plan,
				cancels:  requests,
			}
			complete, err := r.do()
			if err != nil {
				return err
			}
			resp.Type = &proto.Response_Plan{Plan: complete}
			// Only a plan that completed without error unlocks Apply.
			if complete.Error == "" {
				planned = true
			}
		}
		if apply := req.GetApply(); apply != nil {
			if !planned {
				return xerrors.New("cannot apply before successful plan")
			}
			r := &request[*proto.ApplyRequest, *proto.ApplyComplete]{
				req:      apply,
				session:  s,
				serverFn: s.server.Apply,
				cancels:  requests,
			}
			complete, err := r.do()
			if err != nil {
				return err
			}
			resp.Type = &proto.Response_Apply{Apply: complete}
		}
		err := s.stream.Send(resp)
		if err != nil {
			return xerrors.Errorf("send response: %w", err)
		}
	}
	return nil
}
|
|
|
|
// Session holds the state for one provisioner session: the scoped logger,
// the extracted template workspace, the session Config, and the transport
// used to exchange requests and responses with the daemon.
type Session struct {
	// Logger is scoped to this session (tagged with session_id).
	Logger slog.Logger
	// WorkDirectory is the per-session directory the template source
	// archive is extracted into; removed when the session ends.
	WorkDirectory string
	// Config is the first message received on the stream.
	Config *proto.Config

	server Server
	stream proto.DRPCProvisioner_SessionStream
	// logLevel is the minimum proto.LogLevel value forwarded by ProvisionLog.
	logLevel int32
}
|
|
|
|
// Context returns the context of the underlying dRPC session stream.
func (s *Session) Context() context.Context {
	return s.stream.Context()
}
|
|
|
|
func (s *Session) extractArchive() error {
|
|
ctx := s.Context()
|
|
|
|
s.Logger.Info(ctx, "unpacking template source archive",
|
|
slog.F("size_bytes", len(s.Config.TemplateSourceArchive)),
|
|
)
|
|
|
|
reader := tar.NewReader(bytes.NewBuffer(s.Config.TemplateSourceArchive))
|
|
// for safety, nil out the reference on Config, since the reader now owns it.
|
|
s.Config.TemplateSourceArchive = nil
|
|
for {
|
|
header, err := reader.Next()
|
|
if err != nil {
|
|
if xerrors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
return xerrors.Errorf("read template source archive: %w", err)
|
|
}
|
|
s.Logger.Debug(context.Background(), "read archive entry",
|
|
slog.F("name", header.Name),
|
|
slog.F("mod_time", header.ModTime),
|
|
slog.F("size", header.Size))
|
|
|
|
// Security: don't untar absolute or relative paths, as this can allow a malicious tar to overwrite
|
|
// files outside the workdir.
|
|
if !filepath.IsLocal(header.Name) {
|
|
return xerrors.Errorf("refusing to extract to non-local path")
|
|
}
|
|
// nolint: gosec
|
|
headerPath := filepath.Join(s.WorkDirectory, header.Name)
|
|
if !strings.HasPrefix(headerPath, filepath.Clean(s.WorkDirectory)) {
|
|
return xerrors.New("tar attempts to target relative upper directory")
|
|
}
|
|
mode := header.FileInfo().Mode()
|
|
if mode == 0 {
|
|
mode = 0o600
|
|
}
|
|
|
|
// Always check for context cancellation before reading the next header.
|
|
// This is mainly important for unit tests, since a canceled context means
|
|
// the underlying directory is going to be deleted. There still exists
|
|
// the small race condition that the context is canceled after this, and
|
|
// before the disk write.
|
|
if ctx.Err() != nil {
|
|
return xerrors.Errorf("context canceled: %w", ctx.Err())
|
|
}
|
|
switch header.Typeflag {
|
|
case tar.TypeDir:
|
|
err = os.MkdirAll(headerPath, mode)
|
|
if err != nil {
|
|
return xerrors.Errorf("mkdir %q: %w", headerPath, err)
|
|
}
|
|
s.Logger.Debug(context.Background(), "extracted directory",
|
|
slog.F("path", headerPath),
|
|
slog.F("mode", fmt.Sprintf("%O", mode)))
|
|
case tar.TypeReg:
|
|
file, err := os.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode)
|
|
if err != nil {
|
|
return xerrors.Errorf("create file %q (mode %s): %w", headerPath, mode, err)
|
|
}
|
|
|
|
hash := crc32.NewIEEE()
|
|
hashReader := io.TeeReader(reader, hash)
|
|
// Max file size of 10MiB.
|
|
size, err := io.CopyN(file, hashReader, 10<<20)
|
|
if xerrors.Is(err, io.EOF) {
|
|
err = nil
|
|
}
|
|
if err != nil {
|
|
_ = file.Close()
|
|
return xerrors.Errorf("copy file %q: %w", headerPath, err)
|
|
}
|
|
err = file.Close()
|
|
if err != nil {
|
|
return xerrors.Errorf("close file %q: %s", headerPath, err)
|
|
}
|
|
s.Logger.Debug(context.Background(), "extracted file",
|
|
slog.F("size_bytes", size),
|
|
slog.F("path", headerPath),
|
|
slog.F("mode", mode),
|
|
slog.F("checksum", fmt.Sprintf("%x", hash.Sum(nil))))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Session) ProvisionLog(level proto.LogLevel, output string) {
|
|
if int32(level) < s.logLevel {
|
|
return
|
|
}
|
|
|
|
err := s.stream.Send(&proto.Response{Type: &proto.Response_Log{Log: &proto.Log{
|
|
Level: level,
|
|
Output: output,
|
|
}}})
|
|
if err != nil {
|
|
s.Logger.Error(s.Context(), "failed to transmit log",
|
|
slog.F("level", level), slog.F("output", output))
|
|
}
|
|
}
|
|
|
|
// pRequest is the type-parameter constraint for the request message types a
// provisioner can service (Parse, Plan, or Apply).
type pRequest interface {
	*proto.ParseRequest | *proto.PlanRequest | *proto.ApplyRequest
}
|
|
|
|
// pComplete is the type-parameter constraint for the completion message types
// a provisioner returns, paired one-to-one with pRequest.
type pComplete interface {
	*proto.ParseComplete | *proto.PlanComplete | *proto.ApplyComplete
}
|
|
|
|
// request processes a single request call to the Server and returns its complete result, while also processing cancel
// requests from the daemon. Provisioner implementations read from canceledOrComplete to be asynchronously informed
// of cancel.
type request[R pRequest, C pComplete] struct {
	// req is the request payload forwarded to serverFn.
	req R
	// session is the owning Session.
	session *Session
	// cancels delivers subsequent stream requests while serverFn runs; a
	// Cancel received here aborts the in-flight call (see do).
	cancels <-chan *proto.Request
	// serverFn is the Server method (Parse, Plan, or Apply) to invoke; its
	// final argument is closed to signal cancellation or completion.
	serverFn func(*Session, R, <-chan struct{}) C
}
|
|
|
|
// do runs serverFn on its own goroutine and waits for either its completion
// or a message on the cancels channel. On cancel it closes canceledOrComplete
// to notify the provisioner, then still waits for serverFn to finish before
// returning, so two requests can never be in flight at once.
func (r *request[R, C]) do() (C, error) {
	canceledOrComplete := make(chan struct{})
	result := make(chan C)
	go func() {
		c := r.serverFn(r.session, r.req, canceledOrComplete)
		result <- c
	}()
	select {
	case req := <-r.cancels:
		close(canceledOrComplete)
		// wait for server to complete the request, even though we have canceled,
		// so that we can't start a new request, and so that if the job was close
		// to completion and the cancel was ignored, we return to complete.
		c := <-result
		// verify we got a cancel instead of another request or closed channel --- which is an error!
		if req.GetCancel() != nil {
			return c, nil
		}
		// A nil req means the cancels channel was closed (stream ended).
		if req == nil {
			return c, xerrors.New("got nil while old request still processing")
		}
		return c, xerrors.Errorf("got new request %T while old request still processing", req.Type)
	case c := <-result:
		close(canceledOrComplete)
		return c, nil
	}
}
|
|
|
|
// SessionDir returns the directory name with mandatory prefix.
|
|
func SessionDir(sessID string) string {
|
|
return sessionDirPrefix + sessID
|
|
}
|