Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement support for Wireguard in PIA #1836

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
scratch.txt
.idea/
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove before merging

1 change: 1 addition & 0 deletions internal/configuration/settings/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
providers.Ivpn,
providers.Mullvad,
providers.Nordvpn,
providers.PrivateInternetAccess,
providers.Surfshark,
providers.Windscribe,
}
Expand Down
4 changes: 2 additions & 2 deletions internal/models/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var (
ErrWireguardPublicKeyEmpty = errors.New("wireguard public key field is empty")
)

func (s *Server) HasMinimumInformation() (err error) {
func (s *Server) HasMinimumInformation(providerWireguardKeyUnknown bool) (err error) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead check the Hostname field and determine if it's a PIA server from this? So there is no need for the extra argument 🤔

switch {
case s.VPN == "":
return fmt.Errorf("%w", ErrVPNFieldEmpty)
Expand All @@ -54,7 +54,7 @@ func (s *Server) HasMinimumInformation() (err error) {
return fmt.Errorf("%w", ErrNetworkProtocolSet)
case s.VPN == vpn.OpenVPN && !s.TCP && !s.UDP:
return fmt.Errorf("%w", ErrNoNetworkProtocol)
case s.VPN == vpn.Wireguard && s.WgPubKey == "":
case s.VPN == vpn.Wireguard && (s.WgPubKey == "" && !providerWireguardKeyUnknown):
return fmt.Errorf("%w", ErrWireguardPublicKeyEmpty)
default:
return nil
Expand Down
1 change: 1 addition & 0 deletions internal/provider/airvpn/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/custom/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
type Provider struct {
extractor Extractor
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/cyberghost/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/example/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/expressvpn/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/fastestvpn/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/hidemyass/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/ipvanish/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/ivpn/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/mullvad/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/nordvpn/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/perfectprivacy/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/privado/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type Provider struct {
storage common.Storage
randSource rand.Source
utils.NoPortForwarder
utils.NoWireguardConfigurator
common.Fetcher
}

Expand Down
1 change: 1 addition & 0 deletions internal/provider/privateinternetaccess/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func (p *Provider) GetConnection(selection settings.ServerSelection, ipv6Support
defaults.OpenVPNTCPPort = 501
defaults.OpenVPNUDPPort = 1197
}
defaults.WireguardPort = 1337

return utils.GetConnection(p.Name(),
p.storage, selection, defaults, ipv6Supported, p.randSource)
Expand Down
90 changes: 5 additions & 85 deletions internal/provider/privateinternetaccess/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (p *Provider) KeepPortForward(ctx context.Context, _ uint16,

func refreshPIAPortForwardData(ctx context.Context, client, privateIPClient *http.Client,
gateway netip.Addr, portForwardPath, authFilePath string) (data piaPortForwardData, err error) {
data.Token, err = fetchToken(ctx, client, authFilePath)
data.Token, err = fetchToken(ctx, client, "client", authFilePath)
if err != nil {
return data, fmt.Errorf("fetching token: %w", err)
}
Expand Down Expand Up @@ -233,57 +233,6 @@ var (
errEmptyToken = errors.New("token received is empty")
)

func fetchToken(ctx context.Context, client *http.Client,
authFilePath string) (token string, err error) {
username, password, err := getOpenvpnCredentials(authFilePath)
if err != nil {
return "", fmt.Errorf("getting username and password: %w", err)
}

errSubstitutions := map[string]string{
url.QueryEscape(username): "<username>",
url.QueryEscape(password): "<password>",
}

form := url.Values{}
form.Add("username", username)
form.Add("password", password)
url := url.URL{
Scheme: "https",
Host: "www.privateinternetaccess.com",
Path: "/api/client/v2/token",
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", replaceInErr(err, errSubstitutions)
}

request.Header.Add("Content-Type", "application/x-www-form-urlencoded")

response, err := client.Do(request)
if err != nil {
return "", replaceInErr(err, errSubstitutions)
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return "", makeNOKStatusError(response, errSubstitutions)
}

decoder := json.NewDecoder(response.Body)
var result struct {
Token string `json:"token"`
}
if err := decoder.Decode(&result); err != nil {
return "", fmt.Errorf("decoding response: %w", err)
}

if result.Token == "" {
return "", errEmptyToken
}
return result.Token, nil
}

var (
errAuthFileMalformed = errors.New("authentication file is malformed")
)
Expand Down Expand Up @@ -329,13 +278,13 @@ func fetchPortForwardData(ctx context.Context, client *http.Client, gateway neti
}
request, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil {
err = replaceInErr(err, errSubstitutions)
err = ReplaceInErr(err, errSubstitutions)
return 0, "", expiration, fmt.Errorf("obtaining signature payload: %w", err)
}

response, err := client.Do(request)
if err != nil {
err = replaceInErr(err, errSubstitutions)
err = ReplaceInErr(err, errSubstitutions)
return 0, "", expiration, fmt.Errorf("obtaining signature payload: %w", err)
}
defer response.Body.Close()
Expand Down Expand Up @@ -392,12 +341,12 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data

request, err := http.NewRequestWithContext(ctx, http.MethodGet, bindPortURL.String(), nil)
if err != nil {
return replaceInErr(err, errSubstitutions)
return ReplaceInErr(err, errSubstitutions)
}

response, err := client.Do(request)
if err != nil {
return replaceInErr(err, errSubstitutions)
return ReplaceInErr(err, errSubstitutions)
}
defer response.Body.Close()

Expand All @@ -421,33 +370,4 @@ func bindPort(ctx context.Context, client *http.Client, gateway netip.Addr, data
return nil
}

// replaceInErr is used to remove sensitive information from errors.
func replaceInErr(err error, substitutions map[string]string) error {
s := replaceInString(err.Error(), substitutions)
return errors.New(s) //nolint:goerr113
}

// replaceInString is used to remove sensitive information.
func replaceInString(s string, substitutions map[string]string) string {
for old, new := range substitutions {
s = strings.ReplaceAll(s, old, new)
}
return s
}

var ErrHTTPStatusCodeNotOK = errors.New("HTTP status code is not OK")

func makeNOKStatusError(response *http.Response, substitutions map[string]string) (err error) {
url := response.Request.URL.String()
url = replaceInString(url, substitutions)

b, _ := io.ReadAll(response.Body)
shortenMessage := string(b)
shortenMessage = strings.ReplaceAll(shortenMessage, "\n", "")
shortenMessage = strings.ReplaceAll(shortenMessage, " ", " ")
shortenMessage = replaceInString(shortenMessage, substitutions)

return fmt.Errorf("%w: %s: %d %s: response received: %s",
ErrHTTPStatusCodeNotOK, url, response.StatusCode,
response.Status, shortenMessage)
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func Test_replaceInString(t *testing.T) {
testCase := testCase
t.Run(name, func(t *testing.T) {
t.Parallel()
result := replaceInString(testCase.s, testCase.substitutions)
result := ReplaceInString(testCase.s, testCase.substitutions)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please keep this unexported (lowercase first letter), there is no need to export it

assert.Equal(t, testCase.result, result)
})
}
Expand Down
72 changes: 72 additions & 0 deletions internal/provider/privateinternetaccess/tokenutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package privateinternetaccess
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename file to token.go, avoid utils in filename or directory, it's rather meaningless


import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
)

func fetchToken(ctx context.Context, client *http.Client,
tokenType string, authFilePath string) (token string, err error) {
Comment on lines +12 to +13
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this is used outside openvpn-only, we should inject the username and password and not use the openvpn auth file anymore (especially if it's for Wireguard, that makes it super strange)

username, password, err := getOpenvpnCredentials(authFilePath)
if err != nil {
return "", fmt.Errorf("getting username and password: %w", err)
}

errSubstitutions := map[string]string{
url.QueryEscape(username): "<username>",
url.QueryEscape(password): "<password>",
}

var path string

switch tokenType {
case "client":
path = "/api/client/v2/token"
case "gtoken":
path = "/gtoken/generateToken"
default:
return "", fmt.Errorf("token type %q is not supported", tokenType)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could even panic here, since that would be a programming error:

Suggested change
return "", fmt.Errorf("token type %q is not supported", tokenType)
panic(fmt.Sprintf("token type %q is not supported", tokenType))

}

form := url.Values{}
form.Add("username", username)
form.Add("password", password)
url := url.URL{
Scheme: "https",
Host: "www.privateinternetaccess.com",
Path: path,
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, url.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", ReplaceInErr(err, errSubstitutions)
}

request.Header.Add("Content-Type", "application/x-www-form-urlencoded")

response, err := client.Do(request)
if err != nil {
return "", ReplaceInErr(err, errSubstitutions)
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return "", makeNOKStatusError(response, errSubstitutions)
}

decoder := json.NewDecoder(response.Body)
var result struct {
Token string `json:"token"`
}
if err := decoder.Decode(&result); err != nil {
return "", fmt.Errorf("decoding response: %w", err)
}

if result.Token == "" {
return "", errEmptyToken
}
return result.Token, nil
}
1 change: 1 addition & 0 deletions internal/provider/privateinternetaccess/updater/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type regionData struct {
Servers struct {
UDP []serverData `json:"ovpnudp"`
TCP []serverData `json:"ovpntcp"`
WG []serverData `json:"wg"`
} `json:"servers"`
}

Expand Down
Loading