Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 90 additions & 41 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ const (
HashArg = "sha1"
StoreLocationArg = "store-location" // 'machine', 'user', etc
StoreNameArg = "store" // 'MY', 'CA', 'ROOT', etc
FriendlyNameArg = "friendly-name"
DescriptionArg = "description"
IntermediateStoreLocationArg = "intermediate-store-location"
IntermediateStoreNameArg = "intermediate-store"
KeyIDArg = "key-id"
Expand Down Expand Up @@ -91,6 +93,8 @@ type uriAttributes struct {
subjectCN string
serialNumber *big.Int
issuerName string
friendlyName string
description string
keySpec string
skipFindCertificateKey bool
pin string
Expand Down Expand Up @@ -132,6 +136,8 @@ func parseURI(rawuri string) (*uriAttributes, error) {
subjectCN: u.Get(SubjectCNArg),
serialNumber: serialNumber,
issuerName: u.Get(IssuerNameArg),
friendlyName: u.Get(FriendlyNameArg),
description: u.Get(DescriptionArg),
keySpec: u.Get(KeySpec),
skipFindCertificateKey: u.GetBool(SkipFindCertificateKey),
pin: u.Pin(),
Expand Down Expand Up @@ -392,11 +398,17 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(u.storeName))))
uintptr(unsafe.Pointer(wide(u.storeName))),
)
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", u.storeLocation, u.storeName, err)
}

// if issuer + any of the other fields in the list below is provided, then attempt a second certificate lookup when
// lookup by KeyID fails (not found). This fix an issue when looking up device certificates, as in that case the KeyID is
// derived from a randomly generate string each time agent runs, thus not being able to find certificates installed from
// a previous run.
canLookupByIssuer := u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != "" || u.friendlyName != "" || u.description != "")
var handle *windows.CertContext

switch {
Expand All @@ -421,44 +433,9 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
}
case len(u.keyID) > 0:
if handle, err = findCertificateBySubjectKeyID(st, u.keyID); err != nil {
return nil, err
}
case u.issuerName != "" && (u.serialNumber != nil || u.subjectCN != ""):
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
}

x509Cert, err := certContextToX509(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
if !errors.Is(err, apiv1.NotFoundError{}) || !canLookupByIssuer {
return nil, err
}

switch {
case u.serialNumber != nil:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
}

prevCert = handle
}
case u.containerName != "":
key, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Expand All @@ -474,13 +451,75 @@ func (k *CAPIKMS) getCertContext(u *uriAttributes) (*windows.CertContext, error)
return nil, fmt.Errorf("error generating SubjectKeyID: %w", err)
}
if handle, err = findCertificateBySubjectKeyID(st, keyID); err != nil {
return nil, err
if !errors.Is(err, apiv1.NotFoundError{}) || !canLookupByIssuer {
return nil, err
}
}
default:
}

if handle != nil {
return handle, err
}

if !canLookupByIssuer {
return nil, fmt.Errorf("%q, %q, or %q and one of %q or %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg, SubjectCNArg)
}

return handle, err
// lookup certificate by issuer + another field (serial, CN, friendlyName, description)
var prevCert *windows.CertContext
for {
handle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(u.issuerName))), prevCert)
if err != nil {
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}

if handle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q not found", IssuerNameArg, u.issuerName)}
}

x509Cert, err := certContextToX509(handle)
if err != nil {
return nil, fmt.Errorf("could not unmarshal certificate to DER: %w", err)
}

switch {
case u.serialNumber != nil:
// TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
if x509Cert.SerialNumber.Cmp(u.serialNumber) == 0 {
return handle, nil
}
case len(u.subjectCN) > 0:
if x509Cert.Subject.CommonName == u.subjectCN {
return handle, nil
}
case len(u.friendlyName) > 0:
val, err := cryptFindCertificateFriendlyName(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateFriendlyName failed: %w", err)
}

if val == u.friendlyName {
return handle, nil
}
case len(u.description) > 0:
val, err := cryptFindCertificateDescription(handle)
if err != nil {
return nil, fmt.Errorf("cryptFindCertificateDescription failed: %w", err)
}

if val == u.description {
return handle, nil
}
}

prevCert = handle
}
}

// CreateSigner returns a crypto.Signer that will sign using the key passed in via the URI.
Expand Down Expand Up @@ -818,6 +857,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
cryptFindCertificateKeyProvInfo(certContext)
}

if u.friendlyName != "" {
cryptSetCertificateFriendlyName(certContext, u.friendlyName)
}

if u.description != "" {
cryptSetCertificateDescription(certContext, u.description)
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
Expand Down Expand Up @@ -853,6 +900,8 @@ func (k *CAPIKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest)
HashArg: []string{fp},
StoreLocationArg: []string{u.storeLocation},
StoreNameArg: []string{u.storeName},
FriendlyNameArg: []string{u.friendlyName},
DescriptionArg: []string{u.description},
SkipFindCertificateKey: []string{strconv.FormatBool(u.skipFindCertificateKey)},
}).String(),
Certificate: leaf,
Expand Down
101 changes: 101 additions & 0 deletions kms/capi/ncrypt_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,11 @@ const (
compareShift = 16 // CERT_COMPARE_SHIFT
compareSHA1Hash = 1 // CERT_COMPARE_SHA1_HASH
compareCertID = 16 // CERT_COMPARE_CERT_ID
compareProp = 5 // CERT_COMPARE_CERT_ID
findIssuerStr = compareNameStrW<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_STR_W
findIssuerName = compareName<<compareShift | infoIssuerFlag // CERT_FIND_ISSUER_NAME
findHash = compareSHA1Hash << compareShift // CERT_FIND_HASH
findProperty = compareProp << compareShift // CERT_FIND_PROPERTY
findCertID = compareCertID << compareShift // CERT_FIND_CERT_ID

signatureKeyUsage = 0x80 // CERT_DIGITAL_SIGNATURE_KEY_USAGE
Expand All @@ -84,6 +86,8 @@ const (
CERT_ID_SHA1_HASH = uint32(3)

CERT_KEY_PROV_INFO_PROP_ID = uint32(2)
CERT_FRIENDLY_NAME_PROP_ID = uint32(11)
CERT_DESCRIPTION_PROP_ID = uint32(13)

CERT_NAME_STR_COMMA_FLAG = uint32(0x04000000)
CERT_SIMPLE_NAME_STR = uint32(1)
Expand Down Expand Up @@ -153,6 +157,7 @@ var (
procCertFindCertificateInStore = crypt32.MustFindProc("CertFindCertificateInStore")
procCryptFindCertificateKeyProvInfo = crypt32.MustFindProc("CryptFindCertificateKeyProvInfo")
procCertGetCertificateContextProperty = crypt32.MustFindProc("CertGetCertificateContextProperty")
procCertSetCertificateContextProperty = crypt32.MustFindProc("CertSetCertificateContextProperty")
procCertStrToName = crypt32.MustFindProc("CertStrToNameW")
)

Expand Down Expand Up @@ -633,6 +638,102 @@ func cryptFindCertificateKeyContainerName(certContext *windows.CertContext) (str
return "", nil
}

func certSetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData uintptr) error {
r0, _, err := procCertSetCertificateContextProperty.Call(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is r0? I think this could use a more descriptive name.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually r1, this invokes a system call on windows (whose procedure is found in a DLL), the system call returns r1, r2, error, in general, where r1 represents the return value status from the procedure stored in a register (e.g. on Linux the %rax value), the semantic value of this depends on the procedure invoked, r2 is usually not used but kept for compatibility with platforms that return status on more registers, and the 3rd value return is actually an error, on windows the error is always non nil so we must check r0(r1).
I think we can leave this low level stuff as is, or perhaps considering refactor this in another pr.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that all makes sense. Can you add a comment to these funcs explaining it? With the information you provided above, it's clear, but without it, it's not clear what's happening.

uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
0,
pvData,
)

if r0 == 0 {
return err
}
return nil
Comment on lines +649 to +652

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is err assumed to be non-nil here?

If err != nil, and r0 is != 0, is it ok to drop the error?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This invokes CertSetCertificateContextProperty which returns a bool, on success it returns true.
But since this is system call invocation the bool is cast to int, false is zero, so in that case we actually return the error.
We handle this as syscall on windows always return a non nil error even if the function succeeds

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the above, I think comments helps a lot here.

}

func cryptSetCertificateFriendlyName(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func cryptSetCertificateDescription(certContext *windows.CertContext, val string) error {
data := CRYPTOAPI_BLOB{
len: uint32(len(val)+1) * 2,
data: uintptr(unsafe.Pointer(wide(val))),
}

return certSetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, uintptr(unsafe.Pointer(&data)))
}

func certGetCertificateContextProperty(certContext *windows.CertContext, propID uint32, pvData *byte, pcbData *uint32) error {
r0, _, err := procCertGetCertificateContextProperty.Call(
uintptr(unsafe.Pointer(certContext)),
uintptr(propID),
uintptr(unsafe.Pointer(pvData)),
uintptr(unsafe.Pointer(pcbData)),
)
if r0 == 0 {
return err
}
return nil
}

func cryptFindCertificateFriendlyName(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}

if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_FRIENDLY_NAME_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func cryptFindCertificateDescription(certContext *windows.CertContext) (string, error) {
var size uint32

err := certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, nil, &size)
if err != nil {
if errno, ok := err.(windows.Errno); ok && uint32(errno) == CRYPT_E_NOT_FOUND {
return "", nil
}

return "", err
}
if size == 0 {
return "", nil
}

buf := make([]byte, size)
err = certGetCertificateContextProperty(certContext, CERT_DESCRIPTION_PROP_ID, &buf[0], &size)
if err != nil {
return "", err
}

uc := bytes.ReplaceAll(buf, []byte{0x00}, []byte(""))
return string(uc), nil
}

func certStrToName(x500Str string) ([]byte, error) {
var size uint32

Expand Down
19 changes: 13 additions & 6 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,9 @@ func (k *TPMKMS) loadCertificateChainFromWindowsCertificateStore(req *apiv1.Load
"store": []string{store},
"intermediate-store-location": []string{intermediateCAStoreLocation},
"intermediate-store": []string{intermediateCAStore},
"issuer": []string{o.issuer},
"friendly-name": []string{o.friendlyName},
"description": []string{o.description},
}).String(),
})
}
Expand Down Expand Up @@ -967,6 +970,8 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store
Name: uri.New("capi", url.Values{
"store-location": []string{location},
"store": []string{store},
"friendly-name": []string{o.friendlyName},
"description": []string{o.description},
"skip-find-certificate-key": []string{skipFindCertificateKey},
"intermediate-store-location": []string{intermediateCAStoreLocation},
"intermediate-store": []string{intermediateCAStore},
Expand Down Expand Up @@ -1544,9 +1549,11 @@ type deletingCertificateChainManager interface {
DeleteCertificate(req *apiv1.DeleteCertificateRequest) error
}

var _ apiv1.KeyManager = (*TPMKMS)(nil)
var _ apiv1.Attester = (*TPMKMS)(nil)
var _ apiv1.CertificateManager = (*TPMKMS)(nil)
var _ apiv1.CertificateChainManager = (*TPMKMS)(nil)
var _ deletingCertificateChainManager = (*TPMKMS)(nil)
var _ apiv1.AttestationClient = (*attestationClient)(nil)
var (
_ apiv1.KeyManager = (*TPMKMS)(nil)
_ apiv1.Attester = (*TPMKMS)(nil)
_ apiv1.CertificateManager = (*TPMKMS)(nil)
_ apiv1.CertificateChainManager = (*TPMKMS)(nil)
_ deletingCertificateChainManager = (*TPMKMS)(nil)
_ apiv1.AttestationClient = (*attestationClient)(nil)
)
6 changes: 6 additions & 0 deletions kms/tpmkms/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type objectProperties struct {
path string
storeLocation string
store string
friendlyName string
description string
intermediateStoreLocation string
intermediateStore string
skipFindCertificateKey bool
Expand Down Expand Up @@ -59,8 +61,12 @@ func parseNameURI(nameURI string) (o objectProperties, err error) {

// store location and store options are used on Windows to override
// which store(s) are used for storing and loading (intermediate) certificates
// friendly-name and description are used on Windows to populate additional certificate
// context properties to aid in retrieval
o.storeLocation = u.Get("store-location")
o.store = u.Get("store")
o.friendlyName = u.Get("friendly-name")
o.description = u.Get("description")
o.intermediateStoreLocation = u.Get("intermediate-store-location")
o.intermediateStore = u.Get("intermediate-store")
o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key")
Expand Down
Loading