feat: implement comprehensive CSRF protection

This commit is contained in:
2025-07-11 18:06:51 +02:00
committed by Kaj Kowalski
parent e7818f5e4f
commit 3e9e75e854
44 changed files with 14964 additions and 6413 deletions

280
lib/csrf.ts Normal file
View File

@ -0,0 +1,280 @@
/**
* CSRF Protection Utilities
*
* This module provides CSRF protection for the application using the csrf library.
* It handles token generation, validation, and provides utilities for both server and client.
*/
import csrf from "csrf";
import { cookies } from "next/headers";
import type { NextRequest } from "next/server";
import { env } from "./env";
const tokens = new csrf();
/**
* CSRF configuration
*/
export const CSRF_CONFIG = {
cookieName: "csrf-token",
headerName: "x-csrf-token",
secret: env.CSRF_SECRET,
cookie: {
httpOnly: true,
secure: env.NODE_ENV === "production",
sameSite: "lax" as const,
maxAge: 60 * 60 * 24, // 24 hours
},
} as const;
/**
* Generate a new CSRF token
*/
export function generateCSRFToken(): string {
const secret = tokens.secretSync();
const token = tokens.create(secret);
return `${secret}:${token}`;
}
/**
* Verify a CSRF token
*/
export function verifyCSRFToken(token: string, secret?: string): boolean {
try {
if (token.includes(":")) {
const [tokenSecret, tokenValue] = token.split(":");
return tokens.verify(tokenSecret, tokenValue);
}
if (secret) {
return tokens.verify(secret, token);
}
return false;
} catch {
return false;
}
}
/**
* Extract CSRF token from request
*/
export function extractCSRFToken(request: NextRequest): string | null {
// Check header first
const headerToken = request.headers.get(CSRF_CONFIG.headerName);
if (headerToken) {
return headerToken;
}
// Check form data for POST requests
if (request.method === "POST") {
try {
const formData = request.formData();
return formData.then((data) => data.get("csrf_token") as string | null);
} catch {
// If formData fails, try JSON body
try {
const body = request.json();
return body.then((data) => data.csrfToken || null);
} catch {
return null;
}
}
}
return null;
}
/**
* Get CSRF token from cookies (server-side)
*/
export async function getCSRFTokenFromCookies(): Promise<string | null> {
try {
const cookieStore = cookies();
const token = cookieStore.get(CSRF_CONFIG.cookieName);
return token?.value || null;
} catch {
return null;
}
}
/**
* Server-side utilities for API routes
*/
export class CSRFProtection {
/**
* Generate and set CSRF token in response
*/
static generateTokenResponse(): {
token: string;
cookie: {
name: string;
value: string;
options: {
httpOnly: boolean;
secure: boolean;
sameSite: "lax";
maxAge: number;
path: string;
};
};
} {
const token = generateCSRFToken();
return {
token,
cookie: {
name: CSRF_CONFIG.cookieName,
value: token,
options: {
...CSRF_CONFIG.cookie,
path: "/",
},
},
};
}
/**
* Validate CSRF token from request
*/
static async validateRequest(request: NextRequest): Promise<{
valid: boolean;
error?: string;
}> {
try {
// Skip CSRF validation for GET, HEAD, OPTIONS
if (["GET", "HEAD", "OPTIONS"].includes(request.method)) {
return { valid: true };
}
// Get token from request
const requestToken = await this.getTokenFromRequest(request);
if (!requestToken) {
return {
valid: false,
error: "CSRF token missing from request",
};
}
// Get stored token from cookies
const cookieToken = request.cookies.get(CSRF_CONFIG.cookieName)?.value;
if (!cookieToken) {
return {
valid: false,
error: "CSRF token missing from cookies",
};
}
// Verify tokens match
if (requestToken !== cookieToken) {
return {
valid: false,
error: "CSRF token mismatch",
};
}
// Verify token is valid
if (!verifyCSRFToken(requestToken)) {
return {
valid: false,
error: "Invalid CSRF token",
};
}
return { valid: true };
} catch (error) {
return {
valid: false,
error: `CSRF validation error: ${error instanceof Error ? error.message : "Unknown error"}`,
};
}
}
/**
* Extract token from request (handles different content types)
*/
private static async getTokenFromRequest(request: NextRequest): Promise<string | null> {
// Check header first
const headerToken = request.headers.get(CSRF_CONFIG.headerName);
if (headerToken) {
return headerToken;
}
// Check form data or JSON body
try {
const contentType = request.headers.get("content-type");
if (contentType?.includes("application/json")) {
const body = await request.clone().json();
return body.csrfToken || body.csrf_token || null;
} else if (contentType?.includes("multipart/form-data") || contentType?.includes("application/x-www-form-urlencoded")) {
const formData = await request.clone().formData();
return formData.get("csrf_token") as string | null;
}
} catch (error) {
// If parsing fails, return null
console.warn("Failed to parse request body for CSRF token:", error);
}
return null;
}
}
/**
* Client-side utilities
*/
export const CSRFClient = {
/**
* Get CSRF token from cookies (client-side)
*/
getToken(): string | null {
if (typeof document === "undefined") return null;
const cookies = document.cookie.split(";");
for (const cookie of cookies) {
const [name, value] = cookie.trim().split("=");
if (name === CSRF_CONFIG.cookieName) {
return decodeURIComponent(value);
}
}
return null;
},
/**
* Add CSRF token to fetch options
*/
addTokenToFetch(options: RequestInit = {}): RequestInit {
const token = this.getToken();
if (!token) return options;
return {
...options,
headers: {
...options.headers,
[CSRF_CONFIG.headerName]: token,
},
};
},
/**
* Add CSRF token to form data
*/
addTokenToFormData(formData: FormData): FormData {
const token = this.getToken();
if (token) {
formData.append("csrf_token", token);
}
return formData;
},
/**
* Add CSRF token to object (for JSON requests)
*/
addTokenToObject<T extends Record<string, unknown>>(obj: T): T & { csrfToken: string } {
const token = this.getToken();
return {
...obj,
csrfToken: token || "",
};
},
};

View File

@ -79,6 +79,9 @@ export const env = {
NEXTAUTH_SECRET: parseEnvValue(process.env.NEXTAUTH_SECRET) || "",
NODE_ENV: parseEnvValue(process.env.NODE_ENV) || "development",
// CSRF Protection
CSRF_SECRET: parseEnvValue(process.env.CSRF_SECRET) || parseEnvValue(process.env.NEXTAUTH_SECRET) || "fallback-csrf-secret",
// OpenAI
OPENAI_API_KEY: parseEnvValue(process.env.OPENAI_API_KEY) || "",
OPENAI_MOCK_MODE: parseEnvValue(process.env.OPENAI_MOCK_MODE) === "true",

191
lib/hooks/useCSRF.ts Normal file
View File

@ -0,0 +1,191 @@
/**
* CSRF React Hooks
*
* Client-side hooks for managing CSRF tokens in React components.
*/
"use client";
import { useCallback, useEffect, useState } from "react";
import { CSRFClient } from "../csrf";
/**
* Hook for managing CSRF tokens
*/
export function useCSRF() {
const [token, setToken] = useState<string | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
/**
* Fetch a new CSRF token from the server
*/
const fetchToken = useCallback(async () => {
try {
setLoading(true);
setError(null);
const response = await fetch("/api/csrf-token", {
method: "GET",
credentials: "include",
});
if (!response.ok) {
throw new Error(`Failed to fetch CSRF token: ${response.status}`);
}
const data = await response.json();
if (data.success && data.token) {
setToken(data.token);
} else {
throw new Error("Invalid response from CSRF endpoint");
}
} catch (err) {
const errorMessage = err instanceof Error ? err.message : "Failed to fetch CSRF token";
setError(errorMessage);
console.error("CSRF token fetch error:", errorMessage);
} finally {
setLoading(false);
}
}, []);
/**
* Get token from cookies or fetch new one
*/
const getToken = useCallback(async (): Promise<string | null> => {
// Try to get existing token from cookies
const existingToken = CSRFClient.getToken();
if (existingToken) {
setToken(existingToken);
setLoading(false);
return existingToken;
}
// If no token exists, fetch a new one
await fetchToken();
return CSRFClient.getToken();
}, [fetchToken]);
/**
* Initialize token on mount
*/
useEffect(() => {
getToken();
}, [getToken]);
return {
token,
loading,
error,
fetchToken,
getToken,
refreshToken: fetchToken,
};
}
/**
* Hook for adding CSRF protection to fetch requests
*/
export function useCSRFFetch() {
const { token, getToken } = useCSRF();
/**
* Enhanced fetch with automatic CSRF token inclusion
*/
const csrfFetch = useCallback(
async (url: string, options: RequestInit = {}): Promise<Response> => {
// Ensure we have a token for state-changing requests
const method = options.method || "GET";
if (["POST", "PUT", "DELETE", "PATCH"].includes(method.toUpperCase())) {
const currentToken = token || (await getToken());
if (currentToken) {
options = CSRFClient.addTokenToFetch(options);
}
}
return fetch(url, {
...options,
credentials: "include", // Ensure cookies are sent
});
},
[token, getToken]
);
return {
csrfFetch,
token,
addTokenToFetch: CSRFClient.addTokenToFetch,
addTokenToFormData: CSRFClient.addTokenToFormData,
addTokenToObject: CSRFClient.addTokenToObject,
};
}
/**
* Hook for form submissions with CSRF protection
*/
export function useCSRFForm() {
const { token, getToken } = useCSRF();
/**
* Submit form with CSRF protection
*/
const submitForm = useCallback(
async (
url: string,
formData: FormData,
options: RequestInit = {}
): Promise<Response> => {
// Ensure we have a token
const currentToken = token || (await getToken());
if (currentToken) {
CSRFClient.addTokenToFormData(formData);
}
return fetch(url, {
method: "POST",
body: formData,
credentials: "include",
...options,
});
},
[token, getToken]
);
/**
* Submit JSON data with CSRF protection
*/
const submitJSON = useCallback(
async (
url: string,
data: Record<string, unknown>,
options: RequestInit = {}
): Promise<Response> => {
// Ensure we have a token
const currentToken = token || (await getToken());
if (currentToken) {
data = CSRFClient.addTokenToObject(data);
}
return fetch(url, {
method: "POST",
headers: {
"Content-Type": "application/json",
...options.headers,
},
body: JSON.stringify(data),
credentials: "include",
...options,
});
},
[token, getToken]
);
return {
token,
submitForm,
submitJSON,
addTokenToFormData: CSRFClient.addTokenToFormData,
addTokenToObject: CSRFClient.addTokenToObject,
};
}

View File

@ -9,6 +9,7 @@ import { httpBatchLink } from "@trpc/client";
import { createTRPCNext } from "@trpc/next";
import superjson from "superjson";
import type { AppRouter } from "@/server/routers/_app";
import { CSRFClient } from "./csrf";
function getBaseUrl() {
if (typeof window !== "undefined") {
@ -54,10 +55,25 @@ export const trpc = createTRPCNext<AppRouter>({
* @link https://trpc.io/docs/v10/header
*/
headers() {
return {
// Include credentials for authentication
const headers: Record<string, string> = {};
// Add CSRF token for state-changing operations
const csrfToken = CSRFClient.getToken();
if (csrfToken) {
headers["x-csrf-token"] = csrfToken;
}
return headers;
},
/**
* Custom fetch implementation to include credentials
*/
fetch(url, options) {
return fetch(url, {
...options,
credentials: "include",
};
});
},
}),
],

View File

@ -15,6 +15,7 @@ import type { z } from "zod";
import { authOptions } from "./auth";
import { prisma } from "./prisma";
import { validateInput } from "./validation";
import { CSRFProtection } from "./csrf";
/**
* Create context for tRPC requests
@ -151,6 +152,38 @@ export const companyProcedure = publicProcedure.use(enforceCompanyAccess);
export const adminProcedure = publicProcedure.use(enforceAdminAccess);
export const validatedProcedure = createValidatedProcedure;
/**
* CSRF protection middleware for state-changing operations
*/
const enforceCSRFProtection = t.middleware(async ({ ctx, next }) => {
// Extract request from context
const request = ctx.req as Request;
// Skip CSRF validation for GET requests
if (request.method === "GET") {
return next({ ctx });
}
// Convert to NextRequest for validation
const nextRequest = new Request(request.url, {
method: request.method,
headers: request.headers,
body: request.body,
}) as any;
// Validate CSRF token
const validation = await CSRFProtection.validateRequest(nextRequest);
if (!validation.valid) {
throw new TRPCError({
code: "FORBIDDEN",
message: validation.error || "CSRF validation failed",
});
}
return next({ ctx });
});
/**
* Rate limiting middleware for sensitive operations
*/
@ -161,3 +194,11 @@ export const rateLimitedProcedure = publicProcedure.use(
return next({ ctx });
}
);
/**
* CSRF-protected procedures for state-changing operations
*/
export const csrfProtectedProcedure = publicProcedure.use(enforceCSRFProtection);
export const csrfProtectedAuthProcedure = csrfProtectedProcedure.use(enforceUserIsAuthed);
export const csrfProtectedCompanyProcedure = csrfProtectedProcedure.use(enforceCompanyAccess);
export const csrfProtectedAdminProcedure = csrfProtectedProcedure.use(enforceAdminAccess);