Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions packages/sdk/src/client/modules/attest/attest-client.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import { describe, it, expect, vi } from 'vitest';
import { AttestClient } from './attest-client';
import http from 'node:http';
import { EventEmitter } from 'node:events';

function makeMockHttp(responseData: Buffer, statusCode = 200) {
const mockRes = new EventEmitter() as any;
mockRes.statusCode = statusCode;
const mockReq = new EventEmitter() as any;
const capturedBodies: string[] = [];
mockReq.write = (body: string) => { capturedBodies.push(body); };
mockReq.end = () => {
setTimeout(() => {
mockRes.emit('data', responseData);
mockRes.emit('end');
}, 0);
};
return { mockRes, mockReq, capturedBodies };
}

describe('AttestClient extraData', () => {
it('passes extra_data to TEE server when extraData is provided', async () => {
const { mockRes, mockReq, capturedBodies } = makeMockHttp(Buffer.from('fake-attestation-bytes'));
vi.spyOn(http, 'request').mockImplementation((_opts: any, cb: any) => { cb(mockRes); return mockReq; });
const mockFetch = vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ data: { encryptedToken: 'fake-jwe' }, signature: Buffer.from('fake-sig').toString('base64') }),
});
vi.stubGlobal('fetch', mockFetch);

const client = new AttestClient({ kmsServerURL: 'http://localhost:8080', kmsPublicKey: 'fake-key', audience: 'test' });
const extraData = Buffer.alloc(32, 0xab);

try { await client.attest(extraData); } catch { /* signature will fail */ }

expect(capturedBodies.length).toBeGreaterThan(0);
const teeBody = JSON.parse(capturedBodies[0]);
expect(teeBody.extra_data).toBe(extraData.toString('base64'));

vi.restoreAllMocks();
});

it('omits extra_data when extraData is not provided', async () => {
const { mockRes, mockReq, capturedBodies } = makeMockHttp(Buffer.from('fake-attestation-bytes'));
vi.spyOn(http, 'request').mockImplementation((_opts: any, cb: any) => { cb(mockRes); return mockReq; });
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
ok: true,
json: async () => ({ data: { encryptedToken: 'fake-jwe' }, signature: Buffer.from('fake-sig').toString('base64') }),
}));

const client = new AttestClient({ kmsServerURL: 'http://localhost:8080', kmsPublicKey: 'fake-key', audience: 'test' });
try { await client.attest(); } catch { /* expected */ }

const teeBody = JSON.parse(capturedBodies[0]);
expect(teeBody.extra_data).toBeUndefined();

vi.restoreAllMocks();
});

it('passes extra_data to KMS when extraData is provided', async () => {
const { mockRes, mockReq } = makeMockHttp(Buffer.from('fake-attestation-bytes'));
vi.spyOn(http, 'request').mockImplementation((_opts: any, cb: any) => { cb(mockRes); return mockReq; });

const kmsCapture: string[] = [];
const mockFetch = vi.fn().mockImplementation(async (_url: string, opts: any) => {
kmsCapture.push(opts.body);
return { ok: true, json: async () => ({ data: { encryptedToken: 'fake-jwe' }, signature: Buffer.from('fake-sig').toString('base64') }) };
});
vi.stubGlobal('fetch', mockFetch);

const client = new AttestClient({ kmsServerURL: 'http://localhost:8080', kmsPublicKey: 'fake-key', audience: 'test' });
const extraData = Buffer.alloc(32, 0xcd);
try { await client.attest(extraData); } catch { /* expected */ }

expect(kmsCapture.length).toBeGreaterThan(0);
const kmsBody = JSON.parse(kmsCapture[0]);
expect(kmsBody.extra_data).toBe(extraData.toString('base64'));

vi.restoreAllMocks();
});

it('throws if extraData exceeds 1MB', async () => {
const client = new AttestClient({ kmsServerURL: 'http://localhost:8080', kmsPublicKey: 'fake-key', audience: 'test' });
const tooLarge = Buffer.alloc(1_048_576 + 1);
await expect(client.attest(tooLarge)).rejects.toThrow('extraData exceeds 1MB limit');
});
});
29 changes: 22 additions & 7 deletions packages/sdk/src/client/modules/attest/attest-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ export class AttestClient {
this.config = config;
}

async attest(): Promise<string> {
async attest(extraData?: Buffer): Promise<string> {
// go-tpm-tools hashes extraData (SHA-256/SHA-512) before binding it into the
// hardware nonce, so callers can pass arbitrary data up to 1MB.
if (extraData && extraData.length > 1_048_576) {
throw new Error(`extraData exceeds 1MB limit (${extraData.length} bytes)`);
}

const { publicKey, privateKey } = generateKeyPairSync('rsa', {
modulusLength: 4096,
publicKeyEncoding: { type: 'spki', format: 'pem' } as const,
Expand All @@ -35,8 +41,8 @@ export class AttestClient {
.digest();

const socketPath = this.config.socketPath ?? DEFAULT_SOCKET_PATH;
const attestationBytes = await this.getAttestation(socketPath, challengeHash);
const attestResponse = await this.postAttest(attestationBytes, publicKey);
const attestationBytes = await this.getAttestation(socketPath, challengeHash, extraData);
const attestResponse = await this.postAttest(attestationBytes, publicKey, extraData);

this.verifySignature(JSON.stringify(attestResponse.data), attestResponse.signature);

Expand Down Expand Up @@ -79,9 +85,13 @@ export class AttestClient {
}
}

private getAttestation(socketPath: string, challenge: Buffer): Promise<Buffer> {
private getAttestation(socketPath: string, challenge: Buffer, extraData?: Buffer): Promise<Buffer> {
return new Promise((resolve, reject) => {
const body = JSON.stringify({ challenge: challenge.toString('base64') });
const requestBody: Record<string, string> = { challenge: challenge.toString('base64') };
if (extraData && extraData.length > 0) {
requestBody.extra_data = extraData.toString('base64');
}
const body = JSON.stringify(requestBody);

const req = http.request(
{
Expand Down Expand Up @@ -115,14 +125,19 @@ export class AttestClient {
private async postAttest(
attestationBytes: Buffer,
rsaPublicKey: string,
extraData?: Buffer,
): Promise<{ data: { encryptedToken: string }; signature: string }> {
const url = `${this.config.kmsServerURL}/auth/attest`;
const body = JSON.stringify({
const requestBody: Record<string, unknown> = {
version: 3,
attestation: attestationBytes.toString('base64'),
rsaKey: rsaPublicKey,
audience: this.config.audience,
});
};
if (extraData && extraData.length > 0) {
requestBody.extra_data = extraData.toString('base64');
}
const body = JSON.stringify(requestBody);

const response = await fetch(url, {
method: 'POST',
Expand Down
44 changes: 44 additions & 0 deletions packages/sdk/src/client/modules/attest/jwt-provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,50 @@ describe('JwtProvider', () => {
expect(client.attest).toHaveBeenCalledTimes(2);
});

it('skips cache and fetches fresh token when extraData is provided', async () => {
const token1 = makeJwt(futureExp);
const token2 = makeJwt(futureExp + 1);
let callCount = 0;
const client = mockAttestClient(async () => callCount++ === 0 ? token1 : token2);
const provider = new JwtProvider(client);

const first = await provider.getToken();
expect(first).toBe(token1);

const extraData = Buffer.from('some-action-hash');
const second = await provider.getToken(extraData);
expect(second).toBe(token2);
expect(client.attest).toHaveBeenCalledTimes(2);
expect(client.attest).toHaveBeenLastCalledWith(extraData);
});

it('deduplicates concurrent extraData requests with same key', async () => {
let resolveAttest: (token: string) => void;
const client = mockAttestClient(
() => new Promise<string>((resolve) => { resolveAttest = resolve; }),
);
const provider = new JwtProvider(client);
const extraData = Buffer.from('same-action');

const p1 = provider.getToken(extraData);
const p2 = provider.getToken(extraData);
const token = makeJwt(futureExp);
resolveAttest!(token);

const results = await Promise.all([p1, p2]);
expect(results).toEqual([token, token]);
expect(client.attest).toHaveBeenCalledTimes(1);
});

it('does not deduplicate extraData requests with different keys', async () => {
const client = mockAttestClient(async () => makeJwt(futureExp));
const provider = new JwtProvider(client);

await provider.getToken(Buffer.from('action-a'));
await provider.getToken(Buffer.from('action-b'));
expect(client.attest).toHaveBeenCalledTimes(2);
});

it('clears pending promise on error so concurrent waiters also fail', async () => {
let rejectAttest: (err: Error) => void;
const client = mockAttestClient(
Expand Down
17 changes: 16 additions & 1 deletion packages/sdk/src/client/modules/attest/jwt-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,28 @@ export class JwtProvider {
private cachedToken?: string;
private expiresAt?: number;
private pending?: Promise<string>;
private pendingExtraData = new Map<string, Promise<string>>();

constructor(attestClient: AttestClient, bufferSeconds: number = 30) {
this.attestClient = attestClient;
this.bufferSeconds = bufferSeconds;
}

async getToken(): Promise<string> {
async getToken(extraData?: Buffer): Promise<string> {
// When extraData is provided, bypass long-lived cache but deduplicate
// concurrent requests for the same extraData to avoid thundering herd
// on TEE hardware calls.
if (extraData && extraData.length > 0) {
const key = extraData.toString('hex');
const existing = this.pendingExtraData.get(key);
if (existing) return existing;
const promise = this.attestClient.attest(extraData).finally(() => {
this.pendingExtraData.delete(key);
});
this.pendingExtraData.set(key, promise);
return promise;
}

if (this.cachedToken && !this.isExpiringSoon()) {
return this.cachedToken;
}
Expand Down