teableio/teable

View on GitHub
apps/nestjs-backend/src/features/oauth/oauth-server.service.ts

Summary

Maintainability
B
4 hrs
Test Coverage
import {
  BadRequestException,
  HttpException,
  Injectable,
  NotFoundException,
  UnauthorizedException,
} from '@nestjs/common';
import { JwtService } from '@nestjs/jwt';
import { getRandomString, nullsToUndefined } from '@teable/core';
import { PrismaService } from '@teable/db-main-prisma';
import type { DecisionInfoGetVo } from '@teable/openapi';
import type { Response, Request } from 'express';
import { difference, pick } from 'lodash';
import ms from 'ms';
import type {
  IssueGrantCodeFunction,
  IssueExchangeCodeFunction,
  ValidateFunctionArity4,
  ImmediateFunction,
  ExchangeDoneFunction,
  OAuth2,
} from 'oauth2orize';
import oauth2orize, { AuthorizationError } from 'oauth2orize';
import { CacheService } from '../../cache/cache.service';
import { BaseConfig, IBaseConfig } from '../../configs/base.config';
import { IOAuthConfig, OAuthConfig } from '../../configs/oauth.config';
import { second } from '../../utils/second';
import { AccessTokenService } from '../access-token/access-token.service';
import { OAuthTxStore } from './oauth-tx-store';
import type { IAuthorizeClient, IExchangeClient, IOAuth2Server } from './types';

@Injectable()
export class OAuthServerService {
  server: IOAuth2Server;

  constructor(
    private readonly prismaService: PrismaService,
    private readonly cacheService: CacheService,
    private readonly accessTokenService: AccessTokenService,
    private readonly jwtService: JwtService,
    private readonly oauthTxStore: OAuthTxStore,
    @BaseConfig() private readonly baseConfig: IBaseConfig,
    @OAuthConfig() private readonly oauth2Config: IOAuthConfig
  ) {
    this.server = oauth2orize.createServer({
      store: this.oauthTxStore,
    });
    this.server.grant(oauth2orize.grant.code(this.codeGrant));
    this.server.exchange(oauth2orize.exchange.code(this.codeExchange));
    (this.server as unknown as IOAuth2Server<IExchangeClient>).exchange(
      oauth2orize.exchange.refreshToken(this.refreshTokenExchange)
    );
  }

  private async getAuthorizedTime(userId: string, clientId: string) {
    const authorizedTime = await this.prismaService
      .txClient()
      .oAuthAppAuthorized.findUnique({
        where: {
          // eslint-disable-next-line @typescript-eslint/naming-convention
          clientId_userId: {
            clientId,
            userId,
          },
        },
        select: {
          authorizedTime: true,
        },
      })
      .then((data) => data?.authorizedTime);
    // validate authorized time is not expired
    return (
      authorizedTime &&
      new Date(authorizedTime).getTime() + ms(this.oauth2Config.authorizedExpireIn) > Date.now()
    );
  }

  private handleError(error: unknown | undefined) {
    if (error instanceof AuthorizationError) {
      return new HttpException(error.message, Number(error.status));
    }
    return error;
  }

  private authorizeValidate: ValidateFunctionArity4<IAuthorizeClient> = async (
    clientId,
    queryRedirectUri,
    queryScopes,
    done
  ) => {
    try {
      const { redirectUris, scopes } = await this.getOAuthApp(clientId);
      // validate scopes if get scopes from user
      const invalidScopes = difference(queryScopes, scopes);
      if (invalidScopes.length > 0) {
        return done(new BadRequestException('Invalid scopes: ' + invalidScopes.join(',')));
      }

      // valid redirectUri
      if (!redirectUris.length) {
        return done(new BadRequestException('Redirect uri not configured'));
      }
      const redirectUri = queryRedirectUri || redirectUris[0];
      // valid redirectUri
      if (!redirectUris.includes(redirectUri)) {
        return done(new BadRequestException('Redirect uri not found'));
      }
      const clientScopes = queryScopes ?? scopes;
      done(
        null,
        {
          clientId,
          scopes: clientScopes,
          redirectUri,
        },
        redirectUri
      );
    } catch (error) {
      done(error as Error);
    }
  };

  private authorizeImmediate: ImmediateFunction<IAuthorizeClient> = async (
    client,
    user,
    _scope,
    _type,
    _areq,
    done
  ) => {
    const isTrusted = await this.getAuthorizedTime(user.id, client.clientId);
    if (isTrusted) {
      await this.touchAuthorize(client.clientId, user.id);
      return done(null, true, undefined, undefined);
    }
    return done(null, false, undefined, undefined);
  };

  async authorize(req: Request, res: Response) {
    return new Promise<void>((resolve, reject) => {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      (this.server as any).authorization(this.authorizeValidate, this.authorizeImmediate)(
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        req as any,
        res,
        (error: unknown) => {
          if (error) {
            return reject(this.handleError(error));
          }
          res.redirect(
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            `${this.baseConfig.publicOrigin}/oauth/decision?transaction_id=${(req as any).oauth2.transactionID}`
          );
          resolve();
        }
      );
    });
  }

  async token(req: Request, res: Response) {
    return new Promise<void>((resolve, reject) => {
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      this.server.token()(req as any, res, (error) => {
        if (error) {
          return reject(this.handleError(error));
        }
        resolve();
      });
    });
  }

  private decisionComplete = async (_req: unknown, oauth2: OAuth2, cb: (err?: unknown) => void) => {
    // complete the transaction
    await this.touchAuthorize(oauth2.req.clientID, oauth2.user.id)
      .then(() => cb())
      .catch(cb);
  };

  private touchAuthorize = async (clientId: string, userId: string) => {
    // update authorized time
    console.log('touchAuthorize', clientId, userId);
    await this.prismaService.oAuthAppAuthorized.upsert({
      where: {
        // eslint-disable-next-line @typescript-eslint/naming-convention
        clientId_userId: {
          clientId: clientId,
          userId: userId,
        },
      },
      create: {
        clientId: clientId,
        userId: userId,
        authorizedTime: new Date().toISOString(),
      },
      update: {
        authorizedTime: new Date().toISOString(),
      },
    });
  };

  async decision(req: Request, res: Response) {
    return new Promise<void>((resolve, reject) => {
      // this.decision() return an array of middleware
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      const fns: Array<ReturnType<IOAuth2Server['decision']>> = (this.server as any).decision(
        undefined,
        undefined,
        this.decisionComplete
      );
      // transactionLoader loads oauth data into req.oauth2
      const transactionLoader = fns[0];
      const decisionFn = fns[1];
      // eslint-disable-next-line @typescript-eslint/no-explicit-any
      transactionLoader(req as any, res, (error) => {
        if (error) {
          return reject(this.handleError(error));
        }
        // eslint-disable-next-line @typescript-eslint/no-explicit-any
        decisionFn(req as any, res, async (error) => {
          if (error) {
            return reject(this.handleError(error));
          }
          resolve();
        });
      });
    });
  }

  private async getOAuthApp(clientId: string) {
    const data = await this.prismaService
      .txClient()
      .oAuthApp.findUniqueOrThrow({
        where: {
          clientId,
        },
      })
      .catch((error) => {
        throw new UnauthorizedException(error.message);
      });
    return nullsToUndefined({
      ...data,
      redirectUris: data.redirectUris ? (JSON.parse(data.redirectUris) as string[]) : [],
      scopes: data.scopes ? (JSON.parse(data.scopes) as string[]) : [],
    });
  }

  private codeGrant: IssueGrantCodeFunction = async (client, _redirectUri, user, ares, done) => {
    const { clientId } = await this.getOAuthApp(client.clientId);
    const code = getRandomString(16);
    // save code
    await this.cacheService.set(
      `oauth:code:${code}`,
      {
        clientId,
        redirectUri: client.redirectUri,
        scopes: client.scopes,
        user: pick(user, ['id', 'email', 'name']),
      },
      this.oauth2Config.codeExpireIn
    );
    done(null, code);
  };

  private generateAccessToken({
    userId,
    scopes,
    clientId,
    clientName,
  }: {
    userId: string;
    scopes: string[];
    clientId: string;
    clientName: string;
  }) {
    return this.accessTokenService.createAccessToken({
      clientId,
      name: `oauth:${clientName}`,
      scopes,
      userId,
      // 10 minutes
      expiredTime: new Date(Date.now() + ms(this.oauth2Config.accessTokenExpireIn)).toISOString(),
    });
  }

  private getRefreshToken(client: IExchangeClient, accessTokenId: string, sign: string) {
    const { clientId, clientSecret } = client;
    return this.jwtService.signAsync(
      {
        clientId,
        secret: clientSecret,
        accessTokenId,
        sign: sign,
      },
      { expiresIn: this.oauth2Config.refreshTokenExpireIn }
    );
  }

  private getRefreshTokenExpireTime() {
    return new Date(Date.now() + ms(this.oauth2Config.refreshTokenExpireIn)).toISOString();
  }

  private codeExchange: IssueExchangeCodeFunction = async (client, code, redirectUri, done) => {
    await this.prismaService
      .$tx(async () => {
        // Verify the code
        const codeState = await this.cacheService.get(`oauth:code:${code}`);
        if (!codeState) {
          return done(new UnauthorizedException('Invalid code'));
        }
        await this.cacheService.del(`oauth:code:${code}`);

        if (codeState.clientId !== client.clientId) {
          return done(new UnauthorizedException('Invalid client'));
        }
        if (redirectUri && codeState.redirectUri !== redirectUri) {
          return done(new UnauthorizedException('Invalid redirectUri'));
        }

        // save access token
        const accessToken = await this.generateAccessToken({
          userId: codeState.user.id,
          scopes: codeState.scopes,
          clientId: client.clientId,
          clientName: client.name,
        });

        // save oauth access token
        const refreshTokenSign = getRandomString(16);
        const refreshToken = await this.getRefreshToken(client, accessToken.id, refreshTokenSign);
        await this.prismaService.txClient().oAuthAppToken.create({
          data: {
            refreshTokenSign,
            appSecretId: client.secretId,
            createdBy: codeState.user.id,
            expiredTime: this.getRefreshTokenExpireTime(),
          },
        });
        // Issue a token
        done(null, accessToken.token, refreshToken, {
          scopes: codeState.scopes,
          expires_in: second(this.oauth2Config.accessTokenExpireIn),
          refresh_expires_in: second(this.oauth2Config.refreshTokenExpireIn),
        });
      })
      .catch((error) => done(error));
  };

  private refreshTokenExchange: (
    client: IExchangeClient,
    refreshToken: string,
    issued: ExchangeDoneFunction
  ) => void = (client, refreshToken: string, done) => {
    return this.prismaService
      .$tx(async () => {
        const { clientSecret, name, secretId } = client;
        const { clientId, secret, accessTokenId, sign } = await this.jwtService.verifyAsync<{
          clientId: string;
          secret: string;
          accessTokenId: string;
          sign: string;
        }>(refreshToken);

        if (client.clientId !== clientId) {
          return done(new UnauthorizedException('Invalid client'));
        }
        if (clientSecret !== secret) {
          return done(new UnauthorizedException('Invalid secret'));
        }

        const oldAccessToken = await this.prismaService.txClient().accessToken.findUnique({
          where: { id: accessTokenId },
        });

        if (!oldAccessToken) {
          return done(new UnauthorizedException('Invalid access token'));
        }
        const scopes = oldAccessToken.scopes ? JSON.parse(oldAccessToken.scopes) : [];
        const accessToken = await this.generateAccessToken({
          userId: oldAccessToken.userId,
          scopes,
          clientId,
          clientName: name,
        });

        // validate refresh_token and refresh refresh_token
        const oauthAppToken = await this.prismaService
          .txClient()
          .oAuthAppToken.update({
            where: { refreshTokenSign: sign, appSecretId: secretId },
            data: {
              refreshTokenSign: getRandomString(16),
              expiredTime: this.getRefreshTokenExpireTime(),
            },
            select: {
              refreshTokenSign: true,
            },
          })
          .catch(() => {
            throw new UnauthorizedException('Invalid refresh token');
          });

        const newRefreshToken = await this.getRefreshToken(
          client,
          accessToken.id,
          oauthAppToken.refreshTokenSign
        );
        // Issue a token
        done(null, accessToken.token, newRefreshToken, {
          scopes,
          expires_in: second(this.oauth2Config.accessTokenExpireIn),
          refresh_expires_in: second(this.oauth2Config.refreshTokenExpireIn),
        });
      })
      .catch((error) => done(error));
  };

  async getDecisionInfo(req: Request, transactionId: string) {
    req.body['transaction_id'] = transactionId;
    return new Promise<DecisionInfoGetVo>((resolve, reject) => {
      this.oauthTxStore.load(req, async (err, txn) => {
        if (err) {
          reject(err);
        } else {
          const clientId = txn!.req.clientID;
          const oauthApp = await this.getOAuthApp(clientId);
          if (!oauthApp) {
            return reject(new NotFoundException('Client not found'));
          }
          resolve({
            name: oauthApp.name,
            description: oauthApp.description ?? undefined,
            homepage: oauthApp.homepage,
            logo: oauthApp.logo ?? undefined,
            scopes: txn!.req.scope,
          });
        }
      });
    });
  }
}