chunking away at it
parent
0a244be523
commit
f58f77bf1f
pkg/mastotypes
|
@ -21,6 +21,7 @@ package db
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/go-fed/activity/pub"
|
||||
|
@ -145,6 +146,10 @@ type DB interface {
|
|||
// C) something went wrong in the db
|
||||
IsEmailAvailable(email string) error
|
||||
|
||||
// NewSignup creates a new user in the database with the given parameters, with an *unconfirmed* email address.
|
||||
// By the time this function is called, it should be assumed that all the parameters have passed validation!
|
||||
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error)
|
||||
|
||||
/*
|
||||
USEFUL CONVERSION FUNCTIONS
|
||||
*/
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
@ -82,6 +83,8 @@ type Account struct {
|
|||
SubscriptionExpiresAt time.Time `pg:"type:timestamp"`
|
||||
// Does this account identify itself as a bot?
|
||||
Bot bool
|
||||
// What reason was given for signing up when this account was created?
|
||||
Reason string
|
||||
|
||||
/*
|
||||
PRIVACY SETTINGS
|
||||
|
@ -123,9 +126,9 @@ type Account struct {
|
|||
|
||||
Secret string
|
||||
// Privatekey for validating activitypub requests, will obviously only be defined for local accounts
|
||||
PrivateKey string
|
||||
PrivateKey *rsa.PrivateKey
|
||||
// Publickey for encoding activitypub requests, will be defined for both local and remote accounts
|
||||
PublicKey string
|
||||
PublicKey *rsa.PublicKey
|
||||
|
||||
/*
|
||||
ADMIN FIELDS
|
||||
|
|
|
@ -35,13 +35,13 @@ type DomainBlock struct {
|
|||
// Account ID of the creator of this block
|
||||
CreatedByAccountID string `pg:",notnull"`
|
||||
// TODO: define this
|
||||
Severity int
|
||||
Severity int
|
||||
// Reject media from this domain?
|
||||
RejectMedia bool
|
||||
RejectMedia bool
|
||||
// Reject reports from this domain?
|
||||
RejectReports bool
|
||||
RejectReports bool
|
||||
// Private comment on this block, viewable to admins
|
||||
PrivateComment string
|
||||
PrivateComment string
|
||||
// Public comment on this block, viewable (optionally) by everyone
|
||||
PublicComment string
|
||||
PublicComment string
|
||||
}
|
||||
|
|
|
@ -20,8 +20,11 @@ package db
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
@ -35,6 +38,7 @@ import (
|
|||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// postgresService satisfies the DB interface
|
||||
|
@ -305,7 +309,6 @@ func (ps *postgresService) GetAccountByUserID(userID string, account *model.Acco
|
|||
return err
|
||||
}
|
||||
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
|
||||
fmt.Println(account)
|
||||
if err == pg.ErrNoRows {
|
||||
return ErrNoEntries{}
|
||||
}
|
||||
|
@ -400,7 +403,7 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
|||
// fail because we got an unexpected error
|
||||
return fmt.Errorf("db error: %s", err)
|
||||
}
|
||||
|
||||
|
||||
// check if this email is associated with an account already
|
||||
if err := ps.conn.Model(&model.Account{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
|
||||
// fail because we found something
|
||||
|
@ -412,6 +415,43 @@ func (ps *postgresService) IsEmailAvailable(email string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string) (*model.User, error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
ps.log.Errorf("error creating new rsa key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
a := &model.Account{
|
||||
Username: username,
|
||||
DisplayName: username,
|
||||
Reason: reason,
|
||||
PrivateKey: key,
|
||||
PublicKey: &key.PublicKey,
|
||||
ActorType: "Person",
|
||||
}
|
||||
if _, err = ps.conn.Model(a).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error hashing password: %s", err)
|
||||
}
|
||||
u := &model.User{
|
||||
AccountID: a.ID,
|
||||
EncryptedPassword: string(pw),
|
||||
SignUpIP: signUpIP,
|
||||
Locale: locale,
|
||||
UnconfirmedEmail: email,
|
||||
}
|
||||
if _, err = ps.conn.Model(u).Insert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
/*
|
||||
CONVERSION FUNCTIONS
|
||||
*/
|
||||
|
@ -433,7 +473,6 @@ func (ps *postgresService) AccountToMastoSensitive(a *model.Account) (*mastotype
|
|||
}
|
||||
fields = append(fields, mField)
|
||||
}
|
||||
fmt.Printf("fields: %+v", fields)
|
||||
|
||||
// count followers
|
||||
followers := []model.Follow{}
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
// TODO: write tests for postgres
|
|
@ -1,3 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
import (
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
// TODO: write tests for pgfed
|
|
@ -19,6 +19,8 @@
|
|||
package account
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
@ -26,9 +28,10 @@ import (
|
|||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/module"
|
||||
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
@ -39,9 +42,10 @@ const (
|
|||
)
|
||||
|
||||
type accountModule struct {
|
||||
config *config.Config
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
config *config.Config
|
||||
db db.DB
|
||||
oauthServer oauth.Server
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new account module
|
||||
|
@ -60,15 +64,15 @@ func (m *accountModule) Route(r router.Router) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// accountCreatePOSTHandler handles create account requests, validates them,
|
||||
// and puts them in the database if they're valid.
|
||||
// It should be served as a POST at /api/v1/accounts
|
||||
func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AccountCreatePOSTHandler")
|
||||
// TODO: check whether a valid app token has been presented!!
|
||||
// See: https://docs.joinmastodon.org/methods/accounts/
|
||||
|
||||
l.Trace("checking if registration is open")
|
||||
if !m.config.AccountsConfig.OpenRegistration {
|
||||
l.Debug("account registration is closed, returning error to client")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "account registration is closed"})
|
||||
l := m.log.WithField("func", "accountCreatePOSTHandler")
|
||||
authed, err := oauth.GetAuthed(c)
|
||||
if err != nil {
|
||||
l.Debugf("couldn't auth: %s", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -81,15 +85,34 @@ func (m *accountModule) accountCreatePOSTHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
l.Tracef("validating form %+v", form)
|
||||
if err := validateCreateAccount(form, m.config.AccountsConfig.ReasonRequired, m.db); err != nil {
|
||||
if err := validateCreateAccount(form, m.config.AccountsConfig, m.db); err != nil {
|
||||
l.Debugf("error validating form: %s", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
l.Tracef("attempting to parse client ip address %s", clientIP)
|
||||
signUpIP := net.ParseIP(clientIP)
|
||||
if signUpIP == nil {
|
||||
l.Debugf("error validating sign up ip address %s", clientIP)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "ip address could not be parsed from request"})
|
||||
return
|
||||
}
|
||||
|
||||
ti, err := m.accountCreate(form, signUpIP, authed.Token, authed.Application)
|
||||
if err != nil {
|
||||
l.Errorf("internal server error while creating new account: %s", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, ti)
|
||||
}
|
||||
|
||||
// accountVerifyGETHandler serves a user's account details to them IF they reached this
|
||||
// handler while in possession of a valid token, according to the oauth middleware.
|
||||
// It should be served as a GET at /api/v1/accounts/verify_credentials
|
||||
func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AccountVerifyGETHandler")
|
||||
|
||||
|
@ -120,3 +143,39 @@ func (m *accountModule) accountVerifyGETHandler(c *gin.Context) {
|
|||
l.Tracef("conversion successful, returning OK and mastosensitive account %+v", acctSensitive)
|
||||
c.JSON(http.StatusOK, acctSensitive)
|
||||
}
|
||||
|
||||
/*
|
||||
HELPER FUNCTIONS
|
||||
*/
|
||||
|
||||
// accountCreate does the dirty work of making an account and user in the database.
|
||||
// It then returns a token to the caller, for use with the new account, as per the
|
||||
// spec here: https://docs.joinmastodon.org/methods/accounts/
|
||||
func (m *accountModule) accountCreate(form *mastotypes.AccountCreateRequest, signUpIP net.IP, token oauth2.TokenInfo, app *model.Application) (*mastotypes.Token, error) {
|
||||
l := m.log.WithField("func", "accountCreate")
|
||||
|
||||
// don't store a reason if we don't require one
|
||||
reason := form.Reason
|
||||
if !m.config.AccountsConfig.ReasonRequired {
|
||||
reason = ""
|
||||
}
|
||||
|
||||
l.Trace("creating new username and account")
|
||||
user, err := m.db.NewSignup(form.Username, reason, m.config.AccountsConfig.RequireApproval, form.Email, form.Password, signUpIP, form.Locale)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new signup in the database: %s", err)
|
||||
}
|
||||
|
||||
l.Tracef("generating a token for user %s with account %s and application %s", user.ID, user.AccountID, app.ID)
|
||||
ti, err := m.oauthServer.GenerateUserAccessToken(token, app.ClientSecret, user.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new access token for user %s: %s", user.ID, err)
|
||||
}
|
||||
|
||||
return &mastotypes.Token{
|
||||
AccessToken: ti.GetCode(),
|
||||
TokenType: "Bearer",
|
||||
Scope: ti.GetScope(),
|
||||
CreatedAt: ti.GetCodeCreateAt().Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -20,34 +20,33 @@ package account
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/module/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AccountTestSuite struct {
|
||||
suite.Suite
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
testAccountLocal *model.Account
|
||||
testAccountRemote *model.Account
|
||||
testUser *model.User
|
||||
config *config.Config
|
||||
db db.DB
|
||||
accountModule *accountModule
|
||||
}
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *AccountTestSuite) SetupSuite() {
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
suite.log = log
|
||||
|
||||
c := config.Empty()
|
||||
c.DBConfig = &config.DBConfig{
|
||||
Type: "postgres",
|
||||
|
@ -58,118 +57,126 @@ func (suite *AccountTestSuite) SetupSuite() {
|
|||
Database: "postgres",
|
||||
ApplicationName: "gotosocial",
|
||||
}
|
||||
suite.config = c
|
||||
|
||||
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
database, err := db.New(context.Background(), c, log)
|
||||
if err != nil {
|
||||
logrus.Panicf("error encrypting user pass: %s", err)
|
||||
suite.FailNow(err.Error())
|
||||
}
|
||||
suite.db = database
|
||||
|
||||
suite.accountModule = &accountModule{
|
||||
config: c,
|
||||
db: database,
|
||||
log: log,
|
||||
}
|
||||
|
||||
localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
|
||||
if err != nil {
|
||||
logrus.Panicf("error parsing localavatar url: %s", err)
|
||||
}
|
||||
localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
|
||||
if err != nil {
|
||||
logrus.Panicf("error parsing localheader url: %s", err)
|
||||
}
|
||||
// encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error encrypting user pass: %s", err)
|
||||
// }
|
||||
|
||||
acctID := uuid.NewString()
|
||||
suite.testAccountLocal = &model.Account{
|
||||
ID: acctID,
|
||||
Username: "local_account_of_some_kind",
|
||||
AvatarRemoteURL: localAvatar,
|
||||
HeaderRemoteURL: localHeader,
|
||||
DisplayName: "michael caine",
|
||||
Fields: []model.Field{
|
||||
{
|
||||
Name: "come and ave a go",
|
||||
Value: "if you think you're hard enough",
|
||||
},
|
||||
{
|
||||
Name: "website",
|
||||
Value: "https://imdb.com",
|
||||
VerifiedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
Note: "My name is Michael Caine and i'm a local user.",
|
||||
Discoverable: true,
|
||||
}
|
||||
// localAvatar, err := url.Parse("https://localhost:8080/media/aaaaaaaaa.png")
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error parsing localavatar url: %s", err)
|
||||
// }
|
||||
// localHeader, err := url.Parse("https://localhost:8080/media/ffffffffff.png")
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error parsing localheader url: %s", err)
|
||||
// }
|
||||
|
||||
avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
|
||||
if err != nil {
|
||||
logrus.Panicf("error parsing avatarURL: %s", err)
|
||||
}
|
||||
// acctID := uuid.NewString()
|
||||
// suite.testAccountLocal = &model.Account{
|
||||
// ID: acctID,
|
||||
// Username: "local_account_of_some_kind",
|
||||
// AvatarRemoteURL: localAvatar,
|
||||
// HeaderRemoteURL: localHeader,
|
||||
// DisplayName: "michael caine",
|
||||
// Fields: []model.Field{
|
||||
// {
|
||||
// Name: "come and ave a go",
|
||||
// Value: "if you think you're hard enough",
|
||||
// },
|
||||
// {
|
||||
// Name: "website",
|
||||
// Value: "https://imdb.com",
|
||||
// VerifiedAt: time.Now(),
|
||||
// },
|
||||
// },
|
||||
// Note: "My name is Michael Caine and i'm a local user.",
|
||||
// Discoverable: true,
|
||||
// }
|
||||
|
||||
headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
|
||||
if err != nil {
|
||||
logrus.Panicf("error parsing avatarURL: %s", err)
|
||||
}
|
||||
suite.testAccountRemote = &model.Account{
|
||||
ID: uuid.NewString(),
|
||||
Username: "neato_bombeato",
|
||||
Domain: "example.org",
|
||||
// avatarURL, err := url.Parse("http://example.org/accounts/avatars/000/207/122/original/089-1098-09.png")
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error parsing avatarURL: %s", err)
|
||||
// }
|
||||
|
||||
AvatarFileName: "avatar.png",
|
||||
AvatarContentType: "image/png",
|
||||
AvatarFileSize: 1024,
|
||||
AvatarUpdatedAt: time.Now(),
|
||||
AvatarRemoteURL: avatarURL,
|
||||
// headerURL, err := url.Parse("http://example.org/accounts/headers/000/207/122/original/111111111111.png")
|
||||
// if err != nil {
|
||||
// logrus.Panicf("error parsing avatarURL: %s", err)
|
||||
// }
|
||||
// suite.testAccountRemote = &model.Account{
|
||||
// ID: uuid.NewString(),
|
||||
// Username: "neato_bombeato",
|
||||
// Domain: "example.org",
|
||||
|
||||
HeaderFileName: "avatar.png",
|
||||
HeaderContentType: "image/png",
|
||||
HeaderFileSize: 1024,
|
||||
HeaderUpdatedAt: time.Now(),
|
||||
HeaderRemoteURL: headerURL,
|
||||
// AvatarFileName: "avatar.png",
|
||||
// AvatarContentType: "image/png",
|
||||
// AvatarFileSize: 1024,
|
||||
// AvatarUpdatedAt: time.Now(),
|
||||
// AvatarRemoteURL: avatarURL,
|
||||
|
||||
DisplayName: "one cool dude 420",
|
||||
Fields: []model.Field{
|
||||
{
|
||||
Name: "pronouns",
|
||||
Value: "he/they",
|
||||
},
|
||||
{
|
||||
Name: "website",
|
||||
Value: "https://imcool.edu",
|
||||
VerifiedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
Note: "<p>I'm cool as heck!</p>",
|
||||
Discoverable: true,
|
||||
URI: "https://example.org/users/neato_bombeato",
|
||||
URL: "https://example.org/@neato_bombeato",
|
||||
LastWebfingeredAt: time.Now(),
|
||||
InboxURL: "https://example.org/users/neato_bombeato/inbox",
|
||||
OutboxURL: "https://example.org/users/neato_bombeato/outbox",
|
||||
SharedInboxURL: "https://example.org/inbox",
|
||||
FollowersURL: "https://example.org/users/neato_bombeato/followers",
|
||||
FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
|
||||
}
|
||||
suite.testUser = &model.User{
|
||||
ID: uuid.NewString(),
|
||||
EncryptedPassword: string(encryptedPassword),
|
||||
Email: "user@example.org",
|
||||
AccountID: acctID,
|
||||
// HeaderFileName: "avatar.png",
|
||||
// HeaderContentType: "image/png",
|
||||
// HeaderFileSize: 1024,
|
||||
// HeaderUpdatedAt: time.Now(),
|
||||
// HeaderRemoteURL: headerURL,
|
||||
|
||||
// DisplayName: "one cool dude 420",
|
||||
// Fields: []model.Field{
|
||||
// {
|
||||
// Name: "pronouns",
|
||||
// Value: "he/they",
|
||||
// },
|
||||
// {
|
||||
// Name: "website",
|
||||
// Value: "https://imcool.edu",
|
||||
// VerifiedAt: time.Now(),
|
||||
// },
|
||||
// },
|
||||
// Note: "<p>I'm cool as heck!</p>",
|
||||
// Discoverable: true,
|
||||
// URI: "https://example.org/users/neato_bombeato",
|
||||
// URL: "https://example.org/@neato_bombeato",
|
||||
// LastWebfingeredAt: time.Now(),
|
||||
// InboxURL: "https://example.org/users/neato_bombeato/inbox",
|
||||
// OutboxURL: "https://example.org/users/neato_bombeato/outbox",
|
||||
// SharedInboxURL: "https://example.org/inbox",
|
||||
// FollowersURL: "https://example.org/users/neato_bombeato/followers",
|
||||
// FeaturedCollectionURL: "https://example.org/users/neato_bombeato/collections/featured",
|
||||
// }
|
||||
// suite.testUser = &model.User{
|
||||
// ID: uuid.NewString(),
|
||||
// EncryptedPassword: string(encryptedPassword),
|
||||
// Email: "user@example.org",
|
||||
// AccountID: acctID,
|
||||
// }
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TearDownSuite() {
|
||||
if err := suite.db.Stop(context.Background()); err != nil {
|
||||
logrus.Panicf("error closing db connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||
// SetupTest creates a db connection and creates necessary tables before each test
|
||||
func (suite *AccountTestSuite) SetupTest() {
|
||||
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
db, err := db.New(context.Background(), suite.config, log)
|
||||
if err != nil {
|
||||
logrus.Panicf("error creating database connection: %s", err)
|
||||
}
|
||||
|
||||
suite.db = db
|
||||
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
|
@ -177,70 +184,31 @@ func (suite *AccountTestSuite) SetupTest() {
|
|||
logrus.Panicf("db connection error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := suite.db.Put(suite.testAccountLocal); err != nil {
|
||||
logrus.Panicf("could not insert test account into db: %s", err)
|
||||
}
|
||||
if err := suite.db.Put(suite.testUser); err != nil {
|
||||
logrus.Panicf("could not insert test user into db: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
// TearDownTest drops tables to make sure there's no data in the db
|
||||
func (suite *AccountTestSuite) TearDownTest() {
|
||||
models := []interface{}{
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Follow{},
|
||||
&model.Status{},
|
||||
&model.Application{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
||||
logrus.Panicf("error dropping table: %s", err)
|
||||
}
|
||||
}
|
||||
if err := suite.db.Stop(context.Background()); err != nil {
|
||||
logrus.Panicf("error closing db connection: %s", err)
|
||||
}
|
||||
suite.db = nil
|
||||
}
|
||||
|
||||
func (suite *AccountTestSuite) TestAPIInitialize() {
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
||||
r, err := router.New(suite.config, log)
|
||||
if err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error creating router: %s", err))
|
||||
}
|
||||
|
||||
r.AttachMiddleware(func(c *gin.Context) {
|
||||
account := &model.Account{}
|
||||
if err := suite.db.GetAccountByUserID(suite.testUser.ID, account); err != nil || account == nil {
|
||||
suite.T().Log(err)
|
||||
suite.FailNowf("no account found for user %s, continuing with unauthenticated request: %+v", "", suite.testUser.ID, account)
|
||||
fmt.Println(account)
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(oauth.SessionAuthorizedAccount, account)
|
||||
c.Set(oauth.SessionAuthorizedUser, suite.testUser.ID)
|
||||
})
|
||||
|
||||
acct := New(suite.config, suite.db, log)
|
||||
if err := acct.Route(r); err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||
}
|
||||
|
||||
r.Start()
|
||||
defer func() {
|
||||
if err := r.Stop(context.Background()); err != nil {
|
||||
panic(fmt.Errorf("error stopping router: %s", err))
|
||||
}
|
||||
}()
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
func (suite *AccountTestSuite) TestAccountCreatePOSTHandler() {
|
||||
// TODO: figure out how to test this properly
|
||||
recorder := httptest.NewRecorder()
|
||||
recorder.Header().Set("X-Forwarded-For", "127.0.0.1")
|
||||
ctx, _ := gin.CreateTestContext(recorder)
|
||||
// ctx.Set()
|
||||
suite.accountModule.accountCreatePOSTHandler(ctx)
|
||||
}
|
||||
|
||||
func TestAccountTestSuite(t *testing.T) {
|
||||
|
|
|
@ -21,12 +21,17 @@ package account
|
|||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/util"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
)
|
||||
|
||||
func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired bool, database db.DB) error {
|
||||
func validateCreateAccount(form *mastotypes.AccountCreateRequest, c *config.AccountsConfig, database db.DB) error {
|
||||
if !c.OpenRegistration {
|
||||
return errors.New("registration is not open for this server")
|
||||
}
|
||||
|
||||
if err := util.ValidateSignUpUsername(form.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -47,7 +52,7 @@ func validateCreateAccount(form *mastotypes.AccountCreateRequest, reasonRequired
|
|||
return err
|
||||
}
|
||||
|
||||
if err := util.ValidateSignUpReason(form.Reason, reasonRequired); err != nil {
|
||||
if err := util.ValidateSignUpReason(form.Reason, c.ReasonRequired); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/module"
|
||||
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const appsPath = "/api/v1/apps"
|
||||
|
||||
type appModule struct {
|
||||
server oauth.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
// New returns a new auth module
|
||||
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||
return &appModule{
|
||||
server: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *appModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||
return nil
|
||||
}
|
||||
|
||||
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||
func (m *appModule) appsPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||
l.Trace("entering AppsPOSTHandler")
|
||||
|
||||
form := &mastotypes.ApplicationPOSTRequest{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// permitted length for most fields
|
||||
permittedLength := 64
|
||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||
permittedRedirect := 256
|
||||
|
||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||
if len(form.ClientName) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.Website) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.RedirectURIs) > permittedRedirect {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||
return
|
||||
}
|
||||
if len(form.Scopes) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
|
||||
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||
var scopes string
|
||||
if form.Scopes == "" {
|
||||
scopes = "read"
|
||||
} else {
|
||||
scopes = form.Scopes
|
||||
}
|
||||
|
||||
// generate new IDs for this application and its associated client
|
||||
clientID := uuid.NewString()
|
||||
clientSecret := uuid.NewString()
|
||||
vapidKey := uuid.NewString()
|
||||
|
||||
// generate the application to put in the database
|
||||
app := &model.Application{
|
||||
Name: form.ClientName,
|
||||
Website: form.Website,
|
||||
RedirectURI: form.RedirectURIs,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: scopes,
|
||||
VapidKey: vapidKey,
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// now we need to model an oauth client from the application that the oauth library can use
|
||||
oc := &oauth.Client{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Domain: form.RedirectURIs,
|
||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(oc); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||
c.JSON(http.StatusOK, app.ToMasto())
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package app
|
||||
|
||||
// TODO: write tests
|
|
@ -1,4 +1,4 @@
|
|||
# oauth
|
||||
# auth
|
||||
|
||||
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
|
||||
|
|
@ -16,57 +16,42 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
// Package oauth is a module that provides oauth functionality to a router.
|
||||
// Package auth is a module that provides oauth functionality to a router.
|
||||
// It adds the following paths:
|
||||
// /api/v1/apps
|
||||
// /auth/sign_in
|
||||
// /oauth/token
|
||||
// /oauth/authorize
|
||||
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
||||
package oauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/module"
|
||||
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/gotosocial/oauth2/v4/errors"
|
||||
"github.com/gotosocial/oauth2/v4/manage"
|
||||
"github.com/gotosocial/oauth2/v4/server"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
appsPath = "/api/v1/apps"
|
||||
authSignInPath = "/auth/sign_in"
|
||||
oauthTokenPath = "/oauth/token"
|
||||
oauthAuthorizePath = "/oauth/authorize"
|
||||
// SessionAuthorizedUser is the key set in the gin context for the id of
|
||||
// a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a string.
|
||||
SessionAuthorizedUser = "authorized_user"
|
||||
// SessionAuthorizedAccount is the key set in the gin context for the Account
|
||||
// of a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
|
||||
SessionAuthorizedAccount = "authorized_account"
|
||||
)
|
||||
|
||||
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
|
||||
type oauthModule struct {
|
||||
oauthManager *manage.Manager
|
||||
oauthServer *server.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
type authModule struct {
|
||||
server oauth.Server
|
||||
db db.DB
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
type login struct {
|
||||
|
@ -74,52 +59,17 @@ type login struct {
|
|||
Password string `form:"password"`
|
||||
}
|
||||
|
||||
// New returns a new oauth module
|
||||
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||
// New returns a new auth module
|
||||
func New(srv oauth.Server, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||
return &authModule{
|
||||
server: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||
log.Errorf("internal oauth error: %s", err)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||
log.Errorf("internal response error: %s", re.Error)
|
||||
})
|
||||
|
||||
m := &oauthModule{
|
||||
oauthManager: manager,
|
||||
oauthServer: srv,
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
|
||||
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
|
||||
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return m
|
||||
}
|
||||
|
||||
// Route satisfies the RESTAPIModule interface
|
||||
func (m *oauthModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||
|
||||
func (m *authModule) Route(s router.Router) error {
|
||||
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
||||
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
||||
|
||||
|
@ -129,7 +79,6 @@ func (m *oauthModule) Route(s router.Router) error {
|
|||
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
||||
|
||||
s.AttachMiddleware(m.oauthTokenMiddleware)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -137,93 +86,10 @@ func (m *oauthModule) Route(s router.Router) error {
|
|||
MAIN HANDLERS -- serve these through a server/router
|
||||
*/
|
||||
|
||||
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||
l.Trace("entering AppsPOSTHandler")
|
||||
|
||||
form := &mastotypes.ApplicationPOSTRequest{}
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// permitted length for most fields
|
||||
permittedLength := 64
|
||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||
permittedRedirect := 256
|
||||
|
||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||
if len(form.ClientName) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.Website) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
if len(form.RedirectURIs) > permittedRedirect {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||
return
|
||||
}
|
||||
if len(form.Scopes) > permittedLength {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||
return
|
||||
}
|
||||
|
||||
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||
var scopes string
|
||||
if form.Scopes == "" {
|
||||
scopes = "read"
|
||||
} else {
|
||||
scopes = form.Scopes
|
||||
}
|
||||
|
||||
// generate new IDs for this application and its associated client
|
||||
clientID := uuid.NewString()
|
||||
clientSecret := uuid.NewString()
|
||||
vapidKey := uuid.NewString()
|
||||
|
||||
// generate the application to put in the database
|
||||
app := &model.Application{
|
||||
Name: form.ClientName,
|
||||
Website: form.Website,
|
||||
RedirectURI: form.RedirectURIs,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: scopes,
|
||||
VapidKey: vapidKey,
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(app); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// now we need to model an oauth client from the application that the oauth library can use
|
||||
oc := &oauthClient{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
Domain: form.RedirectURIs,
|
||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||
}
|
||||
|
||||
// chuck it in the db
|
||||
if err := m.db.Put(oc); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||
c.JSON(http.StatusOK, app.ToMasto())
|
||||
}
|
||||
|
||||
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
||||
func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
||||
func (m *authModule) signInGETHandler(c *gin.Context) {
|
||||
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
||||
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
||||
}
|
||||
|
@ -231,7 +97,7 @@ func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
|||
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||
// The handler will then redirect to the auth handler served at /auth
|
||||
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
||||
func (m *authModule) signInPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "SignInPOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
form := &login{}
|
||||
|
@ -260,10 +126,10 @@ func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
|||
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
||||
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
||||
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
||||
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
||||
func (m *authModule) tokenPOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "TokenPOSTHandler")
|
||||
l.Trace("entered TokenPOSTHandler")
|
||||
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||
if err := m.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
@ -271,7 +137,7 @@ func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
|||
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
||||
// The idea here is to present an oauth authorize page to the user, with a button
|
||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||
func (m *authModule) authorizeGETHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizeGETHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
|
@ -349,7 +215,7 @@ func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
|||
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
||||
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
||||
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
||||
func (m *authModule) authorizePOSTHandler(c *gin.Context) {
|
||||
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
||||
s := sessions.Default(c)
|
||||
|
||||
|
@ -404,7 +270,7 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
|||
l.Tracef("values on request set to %+v", c.Request.Form)
|
||||
|
||||
// and proceed with authorization using the oauth2 library
|
||||
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||
if err := m.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
@ -418,25 +284,50 @@ func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
|||
// the request. Then, it will look up the account for that user, and set that in the request too.
|
||||
// If user or account can't be found, then the handler won't *fail*, in case the server wants to allow
|
||||
// public requests that don't have a Bearer token set (eg., for public instance information and so on).
|
||||
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
||||
func (m *authModule) oauthTokenMiddleware(c *gin.Context) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
l.Trace("entering OauthTokenMiddleware")
|
||||
|
||||
ti, err := m.oauthServer.ValidationBearerToken(c.Request)
|
||||
ti, err := m.server.ValidationBearerToken(c.Request)
|
||||
if err != nil {
|
||||
l.Trace("no valid token presented: continuing with unauthenticated request")
|
||||
return
|
||||
}
|
||||
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
||||
c.Set(oauth.SessionAuthorizedToken, ti)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedToken, ti)
|
||||
|
||||
acct := &model.Account{}
|
||||
if err := m.db.GetAccountByUserID(ti.GetUserID(), acct); err != nil || acct == nil {
|
||||
l.Tracef("no account found for user %s, continuing with unauthenticated request", ti.GetUserID())
|
||||
return
|
||||
// check for user-level token
|
||||
if uid := ti.GetUserID(); uid != "" {
|
||||
l.Tracef("authenticated user %s with bearer token, scope is %s", uid, ti.GetScope())
|
||||
|
||||
// fetch user's and account for this user id
|
||||
user := &model.User{}
|
||||
if err := m.db.GetByID(uid, user); err != nil || user == nil {
|
||||
l.Warnf("no user found for validated uid %s", uid)
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedUser, user)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedUser, user)
|
||||
|
||||
acct := &model.Account{}
|
||||
if err := m.db.GetByID(user.AccountID, acct); err != nil || acct == nil {
|
||||
l.Warnf("no account found for validated user %s", uid)
|
||||
return
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedAccount, acct)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedAccount, acct)
|
||||
}
|
||||
|
||||
c.Set(SessionAuthorizedAccount, acct)
|
||||
c.Set(SessionAuthorizedUser, ti.GetUserID())
|
||||
// check for application token
|
||||
if cid := ti.GetClientID(); cid != "" {
|
||||
l.Tracef("authenticated client %s with bearer token, scope is %s", cid, ti.GetScope())
|
||||
app := &model.Application{}
|
||||
if err := m.db.GetWhere("client_id", cid, app); err != nil {
|
||||
l.Tracef("no app found for client %s", cid)
|
||||
}
|
||||
c.Set(oauth.SessionAuthorizedApplication, app)
|
||||
l.Tracef("set gin context %s to %+v", oauth.SessionAuthorizedApplication, app)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -447,7 +338,7 @@ func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
|||
// The goal is to authenticate the password against the one for that email
|
||||
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
|
||||
func (m *authModule) validatePassword(email string, password string) (userid string, err error) {
|
||||
l := m.log.WithField("func", "ValidatePassword")
|
||||
|
||||
// make sure an email/password was provided and bail if not
|
||||
|
@ -487,18 +378,6 @@ func incorrectPassword() (string, error) {
|
|||
return "", errors.New("password/email combination was incorrect")
|
||||
}
|
||||
|
||||
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
|
||||
// or redirects to the /auth/sign_in page, if this key is not present.
|
||||
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
|
||||
l := m.log.WithField("func", "UserAuthorizationHandler")
|
||||
userID = r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty, redirecting to sign in page")
|
||||
}
|
||||
l.Tracef("returning userID %s", userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
||||
// the values in the form into the session.
|
||||
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
|
@ -16,38 +16,38 @@
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gotosocial/gotosocial/internal/config"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/gotosocial/internal/oauth"
|
||||
"github.com/gotosocial/gotosocial/internal/router"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type OauthTestSuite struct {
|
||||
type AuthTestSuite struct {
|
||||
suite.Suite
|
||||
tokenStore oauth2.TokenStore
|
||||
clientStore oauth2.ClientStore
|
||||
oauthServer oauth.Server
|
||||
db db.DB
|
||||
testAccount *model.Account
|
||||
testApplication *model.Application
|
||||
testUser *model.User
|
||||
testClient *oauthClient
|
||||
testClient *oauth.Client
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||
func (suite *OauthTestSuite) SetupSuite() {
|
||||
func (suite *AuthTestSuite) SetupSuite() {
|
||||
c := config.Empty()
|
||||
// we're running on localhost without https so set the protocol to http
|
||||
c.Protocol = "http"
|
||||
|
@ -84,7 +84,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
|||
Email: "user@example.org",
|
||||
AccountID: acctID,
|
||||
}
|
||||
suite.testClient = &oauthClient{
|
||||
suite.testClient = &oauth.Client{
|
||||
ID: "a-known-client-id",
|
||||
Secret: "some-secret",
|
||||
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
||||
|
@ -101,7 +101,7 @@ func (suite *OauthTestSuite) SetupSuite() {
|
|||
}
|
||||
|
||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||
func (suite *OauthTestSuite) SetupTest() {
|
||||
func (suite *AuthTestSuite) SetupTest() {
|
||||
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
@ -113,8 +113,8 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
suite.db = db
|
||||
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&oauthToken{},
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
|
@ -126,8 +126,7 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
}
|
||||
}
|
||||
|
||||
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
|
||||
suite.clientStore = newClientStore(suite.db)
|
||||
suite.oauthServer = oauth.New(suite.db, log)
|
||||
|
||||
if err := suite.db.Put(suite.testAccount); err != nil {
|
||||
logrus.Panicf("could not insert test account into db: %s", err)
|
||||
|
@ -145,10 +144,10 @@ func (suite *OauthTestSuite) SetupTest() {
|
|||
}
|
||||
|
||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
func (suite *OauthTestSuite) TearDownTest() {
|
||||
func (suite *AuthTestSuite) TearDownTest() {
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&oauthToken{},
|
||||
&oauth.Client{},
|
||||
&oauth.Token{},
|
||||
&model.User{},
|
||||
&model.Account{},
|
||||
&model.Application{},
|
||||
|
@ -164,7 +163,7 @@ func (suite *OauthTestSuite) TearDownTest() {
|
|||
suite.db = nil
|
||||
}
|
||||
|
||||
func (suite *OauthTestSuite) TestAPIInitialize() {
|
||||
func (suite *AuthTestSuite) TestAPIInitialize() {
|
||||
log := logrus.New()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
||||
|
@ -173,17 +172,18 @@ func (suite *OauthTestSuite) TestAPIInitialize() {
|
|||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||
}
|
||||
|
||||
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
|
||||
api := New(suite.oauthServer, suite.db, log)
|
||||
if err := api.Route(r); err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||
}
|
||||
|
||||
r.Start()
|
||||
time.Sleep(60 * time.Second)
|
||||
if err := r.Stop(context.Background()); err != nil {
|
||||
suite.FailNow(fmt.Sprintf("error stopping router: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOauthTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(OauthTestSuite))
|
||||
func TestAuthTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(AuthTestSuite))
|
||||
}
|
|
@ -38,7 +38,7 @@ func newClientStore(db db.DB) oauth2.ClientStore {
|
|||
}
|
||||
|
||||
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: clientID,
|
||||
}
|
||||
if err := cs.db.GetByID(clientID, poc); err != nil {
|
||||
|
@ -48,7 +48,7 @@ func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.Cli
|
|||
}
|
||||
|
||||
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: cli.GetID(),
|
||||
Secret: cli.GetSecret(),
|
||||
Domain: cli.GetDomain(),
|
||||
|
@ -58,13 +58,13 @@ func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo
|
|||
}
|
||||
|
||||
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
||||
poc := &oauthClient{
|
||||
poc := &Client{
|
||||
ID: id,
|
||||
}
|
||||
return cs.db.DeleteByID(id, poc)
|
||||
}
|
||||
|
||||
type oauthClient struct {
|
||||
type Client struct {
|
||||
ID string
|
||||
Secret string
|
||||
Domain string
|
|
@ -69,7 +69,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
|||
suite.db = db
|
||||
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&Client{},
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
|
@ -82,7 +82,7 @@ func (suite *PgClientStoreTestSuite) SetupTest() {
|
|||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
||||
models := []interface{}{
|
||||
&oauthClient{},
|
||||
&Client{},
|
||||
}
|
||||
for _, m := range models {
|
||||
if err := suite.db.DropTable(m); err != nil {
|
|
@ -0,0 +1,212 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gotosocial/gotosocial/internal/db"
|
||||
"github.com/gotosocial/gotosocial/internal/db/model"
|
||||
"github.com/gotosocial/oauth2/v4"
|
||||
"github.com/gotosocial/oauth2/v4/errors"
|
||||
"github.com/gotosocial/oauth2/v4/manage"
|
||||
"github.com/gotosocial/oauth2/v4/server"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
SessionAuthorizedToken = "authorized_token"
|
||||
// SessionAuthorizedUser is the key set in the gin context for the id of
|
||||
// a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.User
|
||||
SessionAuthorizedUser = "authorized_user"
|
||||
// SessionAuthorizedAccount is the key set in the gin context for the Account
|
||||
// of a User who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Account
|
||||
SessionAuthorizedAccount = "authorized_account"
|
||||
// SessionAuthorizedAccount is the key set in the gin context for the Application
|
||||
// of a Client who has successfully passed Bearer token authorization.
|
||||
// The interface returned from grabbing this key should be parsed as a *gtsmodel.Application
|
||||
SessionAuthorizedApplication = "authorized_app"
|
||||
)
|
||||
|
||||
// Server wraps some oauth2 server functions in an interface, exposing only what is needed
|
||||
type Server interface {
|
||||
HandleTokenRequest(w http.ResponseWriter, r *http.Request) error
|
||||
HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error
|
||||
ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error)
|
||||
GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error)
|
||||
}
|
||||
|
||||
// s fulfils the Server interface using the underlying oauth2 server
|
||||
type s struct {
|
||||
server *server.Server
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
type Authed struct {
|
||||
Token oauth2.TokenInfo
|
||||
Application *model.Application
|
||||
User *model.User
|
||||
Account *model.Account
|
||||
}
|
||||
|
||||
// GetAuthed is a convenience function for returning an Authed struct from a gin context.
|
||||
// In essence, it tries to extract a token, application, user, and account from the context,
|
||||
// and then sets them on a struct for convenience.
|
||||
//
|
||||
// If any are not present in the context, they will be set to nil on the returned Authed struct.
|
||||
//
|
||||
// If *ALL* are not present, then nil and an error will be returned.
|
||||
//
|
||||
// If something goes wrong during parsing, then nil and an error will be returned (consider this not authed).
|
||||
func GetAuthed(c *gin.Context) (*Authed, error) {
|
||||
ctx := c.Copy()
|
||||
a := &Authed{}
|
||||
var i interface{}
|
||||
var ok bool
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedToken)
|
||||
if ok {
|
||||
parsed, ok := i.(oauth2.TokenInfo)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse token from session context")
|
||||
}
|
||||
a.Token = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedApplication)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.Application)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse application from session context")
|
||||
}
|
||||
a.Application = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedUser)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.User)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse user from session context")
|
||||
}
|
||||
a.User = parsed
|
||||
}
|
||||
|
||||
i, ok = ctx.Get(SessionAuthorizedAccount)
|
||||
if ok {
|
||||
parsed, ok := i.(*model.Account)
|
||||
if !ok {
|
||||
return nil, errors.New("could not parse account from session context")
|
||||
}
|
||||
a.Account = parsed
|
||||
}
|
||||
|
||||
if a.Token == nil && a.Application == nil && a.User == nil && a.Account == nil {
|
||||
return nil, errors.New("not authorized")
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// HandleTokenRequest wraps the oauth2 library's HandleTokenRequest function
|
||||
func (s *s) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.server.HandleTokenRequest(w, r)
|
||||
}
|
||||
|
||||
// HandleAuthorizeRequest wraps the oauth2 library's HandleAuthorizeRequest function
|
||||
func (s *s) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
|
||||
return s.server.HandleAuthorizeRequest(w, r)
|
||||
}
|
||||
|
||||
// ValidationBearerToken wraps the oauth2 library's ValidationBearerToken function
|
||||
func (s *s) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
|
||||
return s.server.ValidationBearerToken(r)
|
||||
}
|
||||
|
||||
// GenerateUserAccessToken shortcuts the normal oauth flow to create an user-level
|
||||
// bearer token *without* requiring that user to log in. This is useful when we
|
||||
// need to create a token for new users who haven't validated their email or logged in yet.
|
||||
//
|
||||
// The ti parameter refers to an existing Application token that was used to make the upstream
|
||||
// request. This token needs to be validated and exist in database in order to create a new token.
|
||||
func (s *s) GenerateUserAccessToken(ti oauth2.TokenInfo, clientSecret string, userID string) (accessToken oauth2.TokenInfo, err error) {
|
||||
|
||||
tgr := &oauth2.TokenGenerateRequest{
|
||||
ClientID: ti.GetClientID(),
|
||||
ClientSecret: clientSecret,
|
||||
UserID: userID,
|
||||
RedirectURI: ti.GetRedirectURI(),
|
||||
Scope: ti.GetScope(),
|
||||
Code: ti.GetCode(),
|
||||
CodeChallenge: ti.GetCodeChallenge(),
|
||||
CodeChallengeMethod: ti.GetCodeChallengeMethod(),
|
||||
}
|
||||
|
||||
return s.server.Manager.GenerateAccessToken(context.Background(), oauth2.AuthorizationCode, tgr)
|
||||
}
|
||||
|
||||
func New(database db.DB, log *logrus.Logger) Server {
|
||||
ts := newTokenStore(context.Background(), database, log)
|
||||
cs := newClientStore(database)
|
||||
|
||||
manager := manage.NewDefaultManager()
|
||||
manager.MapTokenStorage(ts)
|
||||
manager.MapClientStorage(cs)
|
||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||
sc := &server.Config{
|
||||
TokenType: "Bearer",
|
||||
// Must follow the spec.
|
||||
AllowGetAccessRequest: false,
|
||||
// Support only the non-implicit flow.
|
||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||
// Allow:
|
||||
// - Authorization Code (for first & third parties)
|
||||
// - Client Credentials (for applications)
|
||||
AllowedGrantTypes: []oauth2.GrantType{
|
||||
oauth2.AuthorizationCode,
|
||||
oauth2.ClientCredentials,
|
||||
},
|
||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||
}
|
||||
|
||||
srv := server.NewServer(sc, manager)
|
||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||
log.Errorf("internal oauth error: %s", err)
|
||||
return nil
|
||||
})
|
||||
|
||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||
log.Errorf("internal response error: %s", re.Error)
|
||||
})
|
||||
|
||||
srv.SetUserAuthorizationHandler(func(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
userID := r.FormValue("userid")
|
||||
if userID == "" {
|
||||
return "", errors.New("userid was empty")
|
||||
}
|
||||
return userID, nil
|
||||
})
|
||||
srv.SetClientInfoHandler(server.ClientFormHandler)
|
||||
return &s{
|
||||
server: srv,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
// TODO: write tests
|
|
@ -70,7 +70,7 @@ func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.Tok
|
|||
func (pts *tokenStore) sweep() error {
|
||||
// select *all* tokens from the db
|
||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||
tokens := new([]*oauthToken)
|
||||
tokens := new([]*Token)
|
||||
if err := pts.db.GetAll(tokens); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -106,22 +106,22 @@ func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error
|
|||
|
||||
// RemoveByCode deletes a token from the DB based on the Code field
|
||||
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
||||
return pts.db.DeleteWhere("code", code, &oauthToken{})
|
||||
return pts.db.DeleteWhere("code", code, &Token{})
|
||||
}
|
||||
|
||||
// RemoveByAccess deletes a token from the DB based on the Access field
|
||||
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
||||
return pts.db.DeleteWhere("access", access, &oauthToken{})
|
||||
return pts.db.DeleteWhere("access", access, &Token{})
|
||||
}
|
||||
|
||||
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
||||
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
||||
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
|
||||
return pts.db.DeleteWhere("refresh", refresh, &Token{})
|
||||
}
|
||||
|
||||
// GetByCode selects a token from the DB based on the Code field
|
||||
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
pgt := &Token{
|
||||
Code: code,
|
||||
}
|
||||
if err := pts.db.GetWhere("code", code, pgt); err != nil {
|
||||
|
@ -132,7 +132,7 @@ func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.Token
|
|||
|
||||
// GetByAccess selects a token from the DB based on the Access field
|
||||
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
pgt := &Token{
|
||||
Access: access,
|
||||
}
|
||||
if err := pts.db.GetWhere("access", access, pgt); err != nil {
|
||||
|
@ -143,7 +143,7 @@ func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.T
|
|||
|
||||
// GetByRefresh selects a token from the DB based on the Refresh field
|
||||
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||
pgt := &oauthToken{
|
||||
pgt := &Token{
|
||||
Refresh: refresh,
|
||||
}
|
||||
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
|
||||
|
@ -156,7 +156,7 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
|
|||
The following models are basically helpers for the postgres token store implementation, they should only be used internally.
|
||||
*/
|
||||
|
||||
// oauthToken is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
||||
// Token is a translation of the gotosocial token with the ExpiresIn fields replaced with ExpiresAt.
|
||||
//
|
||||
// Explanation for this: gotosocial assumes an in-memory or file database of some kind, where a time-to-live parameter (TTL) can be defined,
|
||||
// and tokens with expired TTLs are automatically removed. Since Postgres doesn't have that feature, it's easier to set an expiry time and
|
||||
|
@ -164,9 +164,9 @@ func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2
|
|||
//
|
||||
// Note that this struct does *not* satisfy the token interface shown here: https://github.com/gotosocial/oauth2/blob/master/model.go#L22
|
||||
// and implemented here: https://github.com/gotosocial/oauth2/blob/master/models/token.go.
|
||||
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||
// As such, manual translation is always required between Token and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||
// and pgTokenToOauthToken can be used for that.
|
||||
type oauthToken struct {
|
||||
type Token struct {
|
||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||
ClientID string
|
||||
UserID string
|
||||
|
@ -186,7 +186,7 @@ type oauthToken struct {
|
|||
}
|
||||
|
||||
// oauthTokenToPGToken is a lil util function that takes a gotosocial token and gives back a token for inserting into postgres
|
||||
func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
||||
func oauthTokenToPGToken(tkn *models.Token) *Token {
|
||||
now := time.Now()
|
||||
|
||||
// For the following, we want to make sure we're not adding a time.Now() to an *empty* ExpiresIn, otherwise that's
|
||||
|
@ -208,7 +208,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
|||
rea = now.Add(tkn.RefreshExpiresIn)
|
||||
}
|
||||
|
||||
return &oauthToken{
|
||||
return &Token{
|
||||
ClientID: tkn.ClientID,
|
||||
UserID: tkn.UserID,
|
||||
RedirectURI: tkn.RedirectURI,
|
||||
|
@ -228,7 +228,7 @@ func oauthTokenToPGToken(tkn *models.Token) *oauthToken {
|
|||
}
|
||||
|
||||
// pgTokenToOauthToken is a lil util function that takes a postgres token and gives back a gotosocial token
|
||||
func pgTokenToOauthToken(pgt *oauthToken) *models.Token {
|
||||
func pgTokenToOauthToken(pgt *Token) *models.Token {
|
||||
now := time.Now()
|
||||
|
||||
return &models.Token{
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package oauth
|
||||
|
||||
// TODO: write tests
|
|
@ -36,7 +36,7 @@ import (
|
|||
// Router provides the REST interface for gotosocial, using gin.
|
||||
type Router interface {
|
||||
// Attach a gin handler to the router with the given method and path
|
||||
AttachHandler(method string, path string, handler gin.HandlerFunc)
|
||||
AttachHandler(method string, path string, f gin.HandlerFunc)
|
||||
// Attach a gin middleware to the router that will be used globally
|
||||
AttachMiddleware(handler gin.HandlerFunc)
|
||||
// Start the router
|
||||
|
@ -59,6 +59,8 @@ func (r *router) Start() {
|
|||
r.logger.Fatalf("listen: %s", err)
|
||||
}
|
||||
}()
|
||||
// c := &gin.Context{}
|
||||
// c.Get()
|
||||
}
|
||||
|
||||
// Stop shuts down the router nicely
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package mastotypes
|
||||
|
||||
// Token represents an OAuth token used for authenticating with the API and performing actions.. See https://docs.joinmastodon.org/entities/token/
|
||||
type Token struct {
|
||||
// An OAuth token to be used for authorization.
|
||||
AccessToken string `json:"access_token"`
|
||||
// The OAuth token type. Mastodon uses Bearer tokens.
|
||||
TokenType string `json:"token_type"`
|
||||
// The OAuth scopes granted by this token, space-separated.
|
||||
Scope string `json:"scope"`
|
||||
// When the token was generated. (UNIX timestamp seconds)
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
Loading…
Reference in New Issue