diff --git a/cmd/fscrypt/errors.go b/cmd/fscrypt/errors.go index c6986739..f11ff124 100644 --- a/cmd/fscrypt/errors.go +++ b/cmd/fscrypt/errors.go @@ -46,7 +46,6 @@ var ( ErrCanceled = errors.New("operation canceled") ErrNoDesctructiveOps = errors.New("operation would be destructive") ErrMaxPassphrase = util.SystemError("max passphrase length exceeded") - ErrPAMPassphrase = errors.New("incorrect login passphrase") ErrInvalidSource = errors.New("invalid source type") ErrPassphraseMismatch = errors.New("entered passphrases do not match") ErrSpecifyProtector = errors.New("multiple protectors available") @@ -59,6 +58,7 @@ var ( ErrBadOwners = errors.New("you do not own this directory") ErrNotEmptyDir = errors.New("not an empty directory") ErrNotPassphrase = errors.New("protector does not use a passphrase") + ErrUnknownUser = errors.New("unknown user") ) var loadHelpText = fmt.Sprintf("You may need to mount a linked filesystem. Run with %s for more information.", shortDisplay(verboseFlag)) diff --git a/cmd/fscrypt/keys.go b/cmd/fscrypt/keys.go index 820ddec4..65360a9c 100644 --- a/cmd/fscrypt/keys.go +++ b/cmd/fscrypt/keys.go @@ -125,7 +125,7 @@ func makeKeyFunc(supportRetry, shouldConfirm bool, prefix string) actions.KeyFun switch info.Source() { case metadata.SourceType_pam_passphrase: prompt := fmt.Sprintf("Enter %slogin passphrase for %s: ", - prefix, getUsername(info.UID())) + prefix, formatUsername(info.UID())) key, err := getPassphraseKey(prompt) if err != nil { return nil, err @@ -134,15 +134,16 @@ func makeKeyFunc(supportRetry, shouldConfirm bool, prefix string) actions.KeyFun // To confirm, check that the passphrase is the user's // login passphrase. if shouldConfirm { - username := getUsername(info.UID()) - ok, err := pam.IsUserLoginToken(username, key) + username, err := usernameFromID(info.UID()) if err != nil { key.Wipe() return nil, err } - if !ok { + + err = pam.IsUserLoginToken(username, key, quietFlag.Value) + if err != nil { key.Wipe() - return nil, ErrPAMPassphrase + return nil, err } } return key, nil diff --git a/cmd/fscrypt/prompt.go b/cmd/fscrypt/prompt.go index fdbef814..52f8c478 100644 --- a/cmd/fscrypt/prompt.go +++ b/cmd/fscrypt/prompt.go @@ -27,6 +27,8 @@ import ( "strconv" "strings" + "github.com/pkg/errors" + "github.com/google/fscrypt/actions" "github.com/google/fscrypt/metadata" "github.com/google/fscrypt/util" @@ -106,21 +108,31 @@ func askConfirmation(question string, defaultChoice bool, warning string) error return nil } -// getUsername returns the username for the provided UID. If the UID does not -// correspond to a user or the username is blank, "UID=" is returned. -func getUsername(uid int64) string { +// usernameFromID returns the username for the provided UID. If the UID does not +// correspond to a user or the username is blank, an error is returned. +func usernameFromID(uid int64) (string, error) { u, err := user.LookupId(strconv.Itoa(int(uid))) if err != nil || u.Username == "" { - return fmt.Sprintf("UID=%d", uid) + return "", errors.Wrapf(ErrUnknownUser, "uid %d", uid) + } + return u.Username, nil +} + +// formatUsername either returns the username for the provided UID, or a string +// containing the error for unknown UIDs. +func formatUsername(uid int64) string { + username, err := usernameFromID(uid) + if err != nil { + return fmt.Sprintf("[%v]", err) } - return u.Username + return username } // formatInfo gives a string description of metadata.ProtectorData. func formatInfo(data actions.ProtectorInfo) string { switch data.Source() { case metadata.SourceType_pam_passphrase: - return "login protector for " + getUsername(data.UID()) + return "login protector for " + formatUsername(data.UID()) case metadata.SourceType_custom_passphrase: return fmt.Sprintf("custom protector %q", data.Name()) case metadata.SourceType_raw_key: diff --git a/crypto/key.go b/crypto/key.go index cffe2b4d..e440ca1a 100644 --- a/crypto/key.go +++ b/crypto/key.go @@ -20,6 +20,12 @@ package crypto +/* +#include +#include +*/ +import "C" + import ( "bytes" "crypto/subtle" @@ -148,13 +154,6 @@ func (key *Key) Len() int { return len(key.data) } -// UnsafeData exposes the underlying protected slice. This is unsafe because the -// data can be paged to disk if the buffer is copied, or the slice may be -// wiped while being used. -func (key *Key) UnsafeData() []byte { - return key.data -} - // Equals compares the contents of two keys, returning true if they have the same // key data. This function runs in constant time. func (key *Key) Equals(key2 *Key) bool { @@ -178,6 +177,30 @@ func (key *Key) resize(requestedSize int) (*Key, error) { return resizedKey, nil } +// UnsafeToCString makes a copy of the string's data into a null-terminated C +// string allocated by C. Note that this method is unsafe as this C copy has no +// locking or wiping functionality. The key shouldn't contain any `\0` bytes. +func (key *Key) UnsafeToCString() unsafe.Pointer { + // Memory for the key must be moved into a C string allocated by C. + size := C.size_t(key.Len()) + data := C.calloc(size+1, 1) + C.memcpy(data, util.Ptr(key.data), size) + return data +} + +// NewKeyFromCString creates of a copy of some C string's data in a key. Note +// that the original C string is not modified at all, so steps must be taken to +// ensure that this original copy is secured. +func NewKeyFromCString(str unsafe.Pointer) (*Key, error) { + size := C.strlen((*C.char)(str)) + key, err := newBlankKey(int(size)) + if err != nil { + return nil, err + } + C.memcpy(util.Ptr(key.data), str, size) + return key, nil +} + // NewKeyFromReader constructs a key of abritary length by reading from reader // until hitting EOF. func NewKeyFromReader(reader io.Reader) (*Key, error) { diff --git a/pam/constants.go b/pam/constants.go new file mode 100644 index 00000000..5c57e063 --- /dev/null +++ b/pam/constants.go @@ -0,0 +1,110 @@ +/* + * constants.go - PAM flags and item types from github.com/msteinert/pam + * + * Modifications Copyright 2017 Google Inc. + * Modifications Author: Joe Richey (joerichey@google.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +/* + * Copyright 2011, krockot + * Copyright 2015, Michael Steinert + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * * Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +package pam + +/* +#cgo LDFLAGS: -lpam + +#include +*/ +import "C" + +// Item is a an PAM information type. +type Item int + +// PAM Item types. +const ( + // Service is the name which identifies the PAM stack. + Service Item = C.PAM_SERVICE + // User identifies the username identity used by a service. + User = C.PAM_USER + // Tty is the terminal name. + Tty = C.PAM_TTY + // Rhost is the requesting host name. + Rhost = C.PAM_RHOST + // Authtok is the currently active authentication token. + Authtok = C.PAM_AUTHTOK + // Oldauthtok is the old authentication token. + Oldauthtok = C.PAM_OLDAUTHTOK + // Ruser is the requesting user name. + Ruser = C.PAM_RUSER + // UserPrompt is the string use to prompt for a username. + UserPrompt = C.PAM_USER_PROMPT +) + +// Flag is used as input to various PAM functions. Flags can be combined with a +// bitwise or. Refer to the official PAM documentation for which flags are +// accepted by which functions. +type Flag int + +// PAM Flag types. +const ( + // Silent indicates that no messages should be emitted. + Silent Flag = C.PAM_SILENT + // DisallowNullAuthtok indicates that authorization should fail + // if the user does not have a registered authentication token. + DisallowNullAuthtok = C.PAM_DISALLOW_NULL_AUTHTOK + // EstablishCred indicates that credentials should be established + // for the user. + EstablishCred = C.PAM_ESTABLISH_CRED + // DeleteCred inidicates that credentials should be deleted. + DeleteCred = C.PAM_DELETE_CRED + // ReinitializeCred indicates that credentials should be fully + // reinitialized. + ReinitializeCred = C.PAM_REINITIALIZE_CRED + // RefreshCred indicates that the lifetime of existing credentials + // should be extended. + RefreshCred = C.PAM_REFRESH_CRED + // ChangeExpiredAuthtok indicates that the authentication token + // should be changed if it has expired. + ChangeExpiredAuthtok = C.PAM_CHANGE_EXPIRED_AUTHTOK + // PrelimCheck indicates that the modules are being probed as to their + // ready status for altering the user's authentication token. + PrelimCheck = C.PAM_PRELIM_CHECK + // UpdateAuthtok informs the module that this is the call it should + // change the authorization tokens. + UpdateAuthtok = C.PAM_UPDATE_AUTHTOK +) diff --git a/pam/login.go b/pam/login.go index 2d792233..e89ee019 100644 --- a/pam/login.go +++ b/pam/login.go @@ -23,17 +23,12 @@ // See http://www.linux-pam.org/Linux-PAM-html/ for more information. package pam -/* -#cgo LDFLAGS: -lpam -#include -#include "pam.h" -*/ import "C" import ( + "fmt" "log" "sync" - "unsafe" "github.com/pkg/errors" @@ -41,82 +36,78 @@ import ( "github.com/google/fscrypt/util" ) +// Pam error values +var ( + ErrPAMPassphrase = errors.New("incorrect login passphrase") +) + // Global state is needed for the PAM callback, so we guard this function with a -// lock. tokenToCheck is only ever non-nil when loginLock is held. +// lock. tokenToCheck is only ever non-nil when tokenLock is held. var ( - ErrPamInternal = util.SystemError("internal pam error") - loginLock sync.Mutex - tokenToCheck *crypto.Key + tokenLock sync.Mutex + tokenToCheck *crypto.Key ) -// unexpectedMessage logs an error encountered in the PAM callback. -//export unexpectedMessage -func unexpectedMessage(msg *C.char) { - log.Printf("pam encountered unexpected %q", C.GoString(msg)) +// userInput is run when the the callback needs some input from the user. We +// prompt the user for information and return their answer. A return value of +// nil indicates an error occurred. +//export userInput +func userInput(prompt *C.char) *C.char { + fmt.Print(C.GoString(prompt)) + input, err := util.ReadLine() + if err != nil { + log.Printf("getting input for PAM: %s", err) + return nil + } + return C.CString(input) } -// pamInput is run when the PAM module needs some input from the user. The -// message parameter is the prompt that would be displayed to the user. -//export pamInput -func pamInput(msg *C.char) *C.char { - log.Printf("requesting secret data with %q", C.GoString(msg)) - - // Memory for the key must be moved into a C string allocated by C. - cLen := C.size_t(tokenToCheck.Len()) - cData := C.malloc(cLen + 1) +// passphraseInput is run when the callback needs a passphrase from the user. We +// pass along the tokenToCheck without prompting. A return value of nil +// indicates an error occurred. +//export passphraseInput +func passphraseInput(prompt *C.char) *C.char { + log.Printf("getting secret data for PAM: %q", C.GoString(prompt)) + if tokenToCheck == nil { + log.Print("secret data requested multiple times") + return nil + } - // View the cData as a go slice - goData := (*[1 << 30]byte)(cData) - copy(goData[:cLen], tokenToCheck.UnsafeData()) - goData[cLen] = 0 // Null terminator - return (*C.char)(cData) + // Subsequent calls to passphrase input should fail + input := (*C.char)(tokenToCheck.UnsafeToCString()) + tokenToCheck = nil + return input } -// IsUserLoginToken returns true if the presented token is the user's login key, -// false if it is not their login key, and an error if this cannot be -// determined. Note that unless the currently running process is root, this -// check will only work for the user running this process. -func IsUserLoginToken(username string, token *crypto.Key) (_ bool, err error) { +// IsUserLoginToken returns nil if the presented token is the user's login key, +// and returns an error otherwise. Note that unless we are currently running as +// root, this check will only work for the user running this process. +func IsUserLoginToken(username string, token *crypto.Key, quiet bool) error { log.Printf("Checking login token for %s", username) + // We require global state for the function. This function never takes // ownership of the token, so it is not responsible for wiping it. - loginLock.Lock() + tokenLock.Lock() tokenToCheck = token defer func() { tokenToCheck = nil - loginLock.Unlock() + tokenLock.Unlock() }() - cUsername := C.CString(username) - defer C.free(unsafe.Pointer(cUsername)) - - var conv C.struct_pam_conv - var handle *C.struct_pam_handle - C.pam_init(&conv) - - // Start the pam transaction with the desired conversation and handle. - returnCode := C.pam_start(C.fscrypt_service, cUsername, &conv, &handle) - if returnCode != C.PAM_SUCCESS { - return false, errors.Wrapf(ErrPamInternal, "pam_start() = %d", returnCode) + transaction, err := Start("fscrypt", username) + if err != nil { + return err } + defer transaction.End() - defer func() { - // End the PAM transaction, setting the error if appropriate. - returnCode = C.pam_end(handle, returnCode) - if returnCode != C.PAM_SUCCESS && err == nil { - err = errors.Wrapf(ErrPamInternal, "pam_end() = %d", returnCode) - } - }() + // Ask PAM to authenticate the token. + authenticated, err := transaction.Authenticate(quiet) + if err != nil { + return err + } - // Ask PAM to authenticate the token. We either get an answer or an error - returnCode = C.pam_authenticate(handle, 0) - switch returnCode { - case C.PAM_SUCCESS: - return true, nil - case C.PAM_AUTH_ERR: - return false, nil - default: - // PAM didn't give us an answer to the authentication question - return false, errors.Wrapf(ErrPamInternal, "pam_authenticate() = %d", returnCode) + if !authenticated { + return ErrPAMPassphrase } + return nil } diff --git a/pam/pam.c b/pam/pam.c index ce640e81..e32770fa 100644 --- a/pam/pam.c +++ b/pam/pam.c @@ -21,13 +21,15 @@ #include #include +#include -#include "_cgo_export.h" // for pamInput callback +#include +#include // mlock/munlock -const char* fscrypt_service = "fscrypt"; +#include "_cgo_export.h" // for input callbacks -static int pam_conv(int num_msg, const struct pam_message** msg, - struct pam_response** resp, void* appdata_ptr) { +static int conversation(int num_msg, const struct pam_message** msg, + struct pam_response** resp, void* appdata_ptr) { if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG) { return PAM_CONV_ERR; } @@ -49,16 +51,14 @@ static int pam_conv(int num_msg, const struct pam_message** msg, // we just print the error messages or text info to standard output. switch (msg[i]->msg_style) { case PAM_PROMPT_ECHO_OFF: - callback_resp = pamInput(callback_msg); + callback_resp = passphraseInput(callback_msg); break; case PAM_PROMPT_ECHO_ON: - // We should never have a request for non-secret data - unexpectedMessage(callback_msg); - callback_resp = NULL; + callback_resp = userInput(callback_msg); break; case PAM_ERROR_MSG: case PAM_TEXT_INFO: - printf("%s\n", callback_msg); + fprintf(stderr, "%s\n", callback_msg); continue; } @@ -69,12 +69,41 @@ static int pam_conv(int num_msg, const struct pam_message** msg, free((*resp)[i].resp); } free(*resp); + *resp = NULL; return PAM_CONV_ERR; } (*resp)[i].resp = callback_resp; } + return PAM_SUCCESS; } -void pam_init(struct pam_conv* conv) { conv->conv = pam_conv; } +const struct pam_conv conv = {conversation, NULL}; + +void freeData(pam_handle_t* pamh, void* data, int error_status) { free(data); } + +void freeArray(pam_handle_t* pamh, void** array, int error_status) { + int i; + for (i = 0; array[i]; ++i) { + free(array[i]); + } + free(array); +} + +void* copyIntoSecret(void* data) { + size_t size = strlen(data) + 1; // include null terminator + void* copy = malloc(size); + mlock(copy, size); + memcpy(copy, data, size); + return copy; +} + +void freeSecret(pam_handle_t* pamh, char* data, int error_status) { + size_t size = strlen(data) + 1; // Include null terminator + // Use volitile function pointer to actually clear the memory. + static void* (*const volatile memset_sec)(void*, int, size_t) = &memset; + memset_sec(data, 0, size); + munlock(data, size); + free(data); +} \ No newline at end of file diff --git a/pam/pam.go b/pam/pam.go new file mode 100644 index 00000000..010d4d27 --- /dev/null +++ b/pam/pam.go @@ -0,0 +1,190 @@ +/* + * pam.go - Utility functions for interfacing with the PAM libraries. + * + * Copyright 2017 Google Inc. + * Author: Joe Richey (joerichey@google.com) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package pam + +/* +#cgo LDFLAGS: -lpam +#include "pam.h" + +#include +#include +#include +*/ +import "C" +import ( + "errors" + "fmt" + "unsafe" + + "github.com/google/fscrypt/util" +) + +// Handle wraps the C pam_handle_t type. This is used from within modules. +type Handle struct { + handle *C.pam_handle_t + status C.int +} + +// NewHandle creates a Handle from a raw pointer. +func NewHandle(pamh unsafe.Pointer) *Handle { + return &Handle{ + handle: (*C.pam_handle_t)(pamh), + status: C.PAM_SUCCESS, + } +} + +func (h *Handle) setData(name string, data unsafe.Pointer, cleanup C.CleanupFunc) error { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + h.status = C.pam_set_data(h.handle, cName, data, cleanup) + return h.err() +} + +func (h *Handle) getData(name string) (unsafe.Pointer, error) { + var data unsafe.Pointer + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + h.status = C.pam_get_data(h.handle, cName, &data) + return data, h.err() +} + +func (h *Handle) SetSecret(name string, secret unsafe.Pointer) error { + return h.setData(name, C.copyIntoSecret(secret), C.CleanupFunc(C.freeSecret)) +} + +func (h *Handle) GetSecret(name string) (unsafe.Pointer, error) { + return h.getData(name) +} + +func (h *Handle) ClearSecret(name string) error { + return h.setData(name, unsafe.Pointer(C.CString("")), C.CleanupFunc(C.freeData)) +} + +func (h *Handle) SetString(name string, s string) error { + return h.setData(name, unsafe.Pointer(C.CString(s)), C.CleanupFunc(C.freeData)) +} + +func (h *Handle) GetString(name string) (string, error) { + data, err := h.getData(name) + if err != nil { + return "", err + } + return C.GoString((*C.char)(data)), nil +} + +func (h *Handle) SetSlice(name string, slice []string) error { + sliceLength := uintptr(len(slice)) + memorySize := (sliceLength + 1) * unsafe.Sizeof(uintptr(0)) + data := C.malloc(C.size_t(memorySize)) + + cSlice := util.PointerSlice(data) + for i, str := range slice { + cSlice[i] = unsafe.Pointer(C.CString(str)) + } + cSlice[sliceLength] = nil + + return h.setData(name, data, C.CleanupFunc(C.freeArray)) +} + +func (h *Handle) GetSlice(name string) ([]string, error) { + data, err := h.getData(name) + if err != nil { + return nil, err + } + + var slice []string + for _, cString := range util.PointerSlice(data) { + if cString == nil { + return slice, nil + } + slice = append(slice, C.GoString((*C.char)(cString))) + } + panic("We will never get here") +} + +// GetItem retrieves a PAM information item. This a pointer directory to the +// data, so it shouldn't be modified. +func (h *Handle) GetItem(i Item) (unsafe.Pointer, error) { + var data unsafe.Pointer + h.status = C.pam_get_item(h.handle, C.int(i), &data) + return data, h.err() +} + +// GetUID retrieves the UID of the corresponding PAM_USER. +func (h *Handle) GetUID() (int64, error) { + var pamUsername *C.char + h.status = C.pam_get_user(h.handle, &pamUsername, nil) + if err := h.err(); err != nil { + return 0, err + } + + pwd := C.getpwnam(pamUsername) + if pwd == nil { + return 0, fmt.Errorf("unknown user %q", C.GoString(pamUsername)) + } + return int64(pwd.pw_uid), nil +} + +func (h *Handle) err() error { + if h.status == C.PAM_SUCCESS { + return nil + } + s := C.GoString(C.pam_strerror(h.handle, C.int(h.status))) + return errors.New(s) +} + +// Transaction represents a wrapped pam_handle_t type created with pam_start +// form an application. +type Transaction Handle + +// Start initializes a pam Transaction. End() should be called after the +// Transaction is no longer needed. +func Start(service, username string) (*Transaction, error) { + cService := C.CString(service) + defer C.free(unsafe.Pointer(cService)) + cUsername := C.CString(username) + defer C.free(unsafe.Pointer(cUsername)) + + t := &Transaction{ + handle: nil, + status: C.PAM_SUCCESS, + } + t.status = C.pam_start(cService, cUsername, &C.conv, &t.handle) + return t, (*Handle)(t).err() +} + +// End finalizes a pam Transaction with pam_end(). +func (t *Transaction) End() { + C.pam_end(t.handle, t.status) +} + +// Authenticate returns a boolean indicating if the user authenticated correctly +// or not. If the authentication check did not complete, an error is returned. +func (t *Transaction) Authenticate(quiet bool) (bool, error) { + var flags C.int = C.PAM_DISALLOW_NULL_AUTHTOK + if quiet { + flags |= C.PAM_SILENT + } + t.status = C.pam_authenticate(t.handle, flags) + if t.status == C.PAM_AUTH_ERR { + return false, nil + } + return true, (*Handle)(t).err() +} diff --git a/pam/pam.h b/pam/pam.h index 83ef2a9d..9f3cdb2c 100644 --- a/pam/pam.h +++ b/pam/pam.h @@ -22,10 +22,23 @@ #include -// fscrypt_service is the display name of the service requesting the passphrase. -const char* fscrypt_service; +// Conversation that will call back into Go code when appropriate. +const struct pam_conv conv; -// pam_init initializes the pam_conv structure for use with our Go callbacks. -void pam_init(struct pam_conv* conv); +// CleaupFuncs are used to cleanup specific PAM data. +typedef void (*CleanupFunc)(pam_handle_t *pamh, void *data, int error_status); -#endif +// CleaupFunc that calls free() on data. +void freeData(pam_handle_t *pamh, void *data, int error_status); + +// CleaupFunc that frees each item in a null terminated array of pointers and +// then frees the array itself. +void freeArray(pam_handle_t *pamh, void **array, int error_status); + +// Creates a copy of a C string, which resides in an locked buffer. +void *copyIntoSecret(void *data); + +// CleaupFunc that Zeros wipes a C string and unlocks and frees its memory. +void freeSecret(pam_handle_t *pamh, char *data, int error_status); + +#endif // FSCRYPT_PAM_H diff --git a/util/util.go b/util/util.go index 32e5c065..792b66c0 100644 --- a/util/util.go +++ b/util/util.go @@ -37,6 +37,18 @@ func Ptr(slice []byte) unsafe.Pointer { return unsafe.Pointer(&slice[0]) } +// ByteSlice takes a pointer to some data and views it as a slice of bytes. +// Note, indexing into this slice is unsafe. +func ByteSlice(ptr unsafe.Pointer) []byte { + return (*[1 << 30]byte)(ptr)[:] +} + +// PointerSlice takes a pointer to an array of pointers and views it as a slice +// of pointers. Note, indexing into this slice is unsafe. +func PointerSlice(ptr unsafe.Pointer) []unsafe.Pointer { + return (*[1 << 30]unsafe.Pointer)(ptr)[:] +} + // Index returns the first index i such that inVal == inArray[i]. // ok is true if we find a match, false otherwise. func Index(inVal int64, inArray []int64) (index int, ok bool) { diff --git a/util/util_test.go b/util/util_test.go index 33ce2ffd..7739edd0 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -20,19 +20,52 @@ package util import ( + "bytes" "testing" + "unsafe" ) const offset = 3 +var ( + byteArr = []byte{'a', 'b', 'c', 'd'} + ptrArr = []*int{&a, &b, &c, &d} + a = 1 + b = 2 + c = 3 + d = 4 +) + // Make sure the address behaves well under slicing func TestPtrOffset(t *testing.T) { - arr := []byte{'a', 'b', 'c', 'd'} - i1 := uintptr(Ptr(arr[offset:])) - i2 := uintptr(Ptr(arr)) + i1 := uintptr(Ptr(byteArr[offset:])) + i2 := uintptr(Ptr(byteArr)) if i1 != i2+offset { - t.Fatalf("pointers %v and %v do not have an offset of %v", i1, i2, offset) + t.Errorf("pointers %v and %v do not have an offset of %v", i1, i2, offset) + } +} + +// Tests that the ByteSlice method essentially reverses the Ptr method +func TestByteSlice(t *testing.T) { + ptr := Ptr(byteArr) + generatedArr := ByteSlice(ptr)[:len(byteArr)] + + if !bytes.Equal(byteArr, generatedArr) { + t.Errorf("generated array (%v) and original array (%v) do not agree", + generatedArr, byteArr) + } +} + +// Tests that the PointerSlice method correctly handles Go Pointers +func TestPointerSlice(t *testing.T) { + arrPtr := unsafe.Pointer(&ptrArr[0]) + + // Convert an array of unsafe pointers to int pointers. + for i, ptr := range PointerSlice(arrPtr)[:len(ptrArr)] { + if ptrArr[i] != (*int)(ptr) { + t.Errorf("generated array and original array disagree at %d", i) + } } }