Skip to content

Commit

Permalink
Close PeerConnection on DTLS CloseNotify
Browse files Browse the repository at this point in the history
Resolves #1767
Resolves pion/dtls#151
  • Loading branch information
Sean-Der committed Sep 4, 2023
1 parent ea23dec commit f76dc63
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 7 deletions.
9 changes: 6 additions & 3 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ type DTLSTransport struct {
state DTLSTransportState
srtpProtectionProfile srtp.ProtectionProfile

onStateChangeHandler func(DTLSTransportState)
onStateChangeHandler func(DTLSTransportState)
internalOnCloseHandler func()

conn *dtls.Conn

Expand Down Expand Up @@ -320,8 +321,6 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
}, nil
}

var dtlsConn *dtls.Conn
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
role, dtlsConfig, err := prepareTransport()
if err != nil {
return err
Expand All @@ -344,6 +343,10 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
dtlsConfig.RootCAs = t.api.settingEngine.dtls.rootCAs
dtlsConfig.KeyLogWriter = t.api.settingEngine.dtls.keyLogWriter

var dtlsConn *dtls.Conn
dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
dtlsEndpoint.SetOnClose(t.internalOnCloseHandler)

// Connect as DTLS Client/Server, function is blocking and we
// must not hold the DTLSTransport lock
if role == DTLSRoleClient {
Expand Down
18 changes: 14 additions & 4 deletions internal/mux/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@ import (

// Endpoint implements net.Conn. It is used to read muxed packets.
type Endpoint struct {
mux *Mux
buffer *packetio.Buffer
mux *Mux
buffer *packetio.Buffer
onClose func()
}

// Close unregisters the endpoint from the Mux
func (e *Endpoint) Close() (err error) {
err = e.close()
if err != nil {
if e.onClose != nil {
e.onClose()
}

Check warning on line 27 in internal/mux/endpoint.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/endpoint.go#L25-L27

Added lines #L25 - L27 were not covered by tests

if err = e.close(); err != nil {

Check warning on line 29 in internal/mux/endpoint.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/endpoint.go#L29

Added line #L29 was not covered by tests
return err
}

Expand Down Expand Up @@ -76,3 +80,9 @@ func (e *Endpoint) SetReadDeadline(time.Time) error {
func (e *Endpoint) SetWriteDeadline(time.Time) error {
return nil
}

// SetOnClose is a user set callback that
// will be executed when `Close` is called
func (e *Endpoint) SetOnClose(onClose func()) {
e.onClose = onClose

Check warning on line 87 in internal/mux/endpoint.go

View check run for this annotation

Codecov / codecov/patch

internal/mux/endpoint.go#L86-L87

Added lines #L86 - L87 were not covered by tests
}
10 changes: 10 additions & 0 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,16 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
return
}

pc.dtlsTransport.internalOnCloseHandler = func() {
pc.log.Info("Closing PeerConnection from DTLS CloseNotify")

go func() {
if pcClosErr := pc.Close(); pcClosErr != nil {
pc.log.Warnf("Failed to close PeerConnection from DTLS CloseNotify: %s", pcClosErr)
}
}()
}

// Start the dtls transport
err = pc.dtlsTransport.Start(DTLSParameters{
Role: dtlsRole,
Expand Down
38 changes: 38 additions & 0 deletions peerconnection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,3 +754,41 @@ func TestTransportChain(t *testing.T) {

closePairNow(t, offer, answer)
}

// Assert that the PeerConnection closes via DTLS (and not ICE)
func TestDTLSClose(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

report := test.CheckRoutines(t)
defer report()

pcOffer, pcAnswer, err := newPair()
assert.NoError(t, err)

_, err = pcOffer.AddTransceiverFromKind(RTPCodecTypeVideo)
assert.NoError(t, err)

peerConnectionsConnected := untilConnectionState(PeerConnectionStateConnected, pcOffer, pcAnswer)

offer, err := pcOffer.CreateOffer(nil)
assert.NoError(t, err)

offerGatheringComplete := GatheringCompletePromise(pcOffer)
assert.NoError(t, pcOffer.SetLocalDescription(offer))
<-offerGatheringComplete

assert.NoError(t, pcAnswer.SetRemoteDescription(*pcOffer.LocalDescription()))

answer, err := pcAnswer.CreateAnswer(nil)
assert.NoError(t, err)

answerGatheringComplete := GatheringCompletePromise(pcAnswer)
assert.NoError(t, pcAnswer.SetLocalDescription(answer))
<-answerGatheringComplete

assert.NoError(t, pcOffer.SetRemoteDescription(*pcAnswer.LocalDescription()))

peerConnectionsConnected.Wait()
assert.NoError(t, pcOffer.Close())
}

0 comments on commit f76dc63

Please sign in to comment.