diff --git a/crypto/ecc.c b/crypto/ecc.c index 8facafd67802..9d24e6522730 100644 --- a/crypto/ecc.c +++ b/crypto/ecc.c @@ -904,28 +904,41 @@ static inline void ecc_swap_digits(const u64 *in, u64 *out, out[i] = __swab64(in[ndigits - 1 - i]); } +static int __ecc_is_key_valid(const struct ecc_curve *curve, + const u64 *private_key, unsigned int ndigits) +{ + u64 one[ECC_MAX_DIGITS] = { 1, }; + u64 res[ECC_MAX_DIGITS]; + + if (!private_key) + return -EINVAL; + + if (curve->g.ndigits != ndigits) + return -EINVAL; + + /* Make sure the private key is in the range [2, n-3]. */ + if (vli_cmp(one, private_key, ndigits) != -1) + return -EINVAL; + vli_sub(res, curve->n, one, ndigits); + vli_sub(res, res, one, ndigits); + if (vli_cmp(res, private_key, ndigits) != 1) + return -EINVAL; + + return 0; +} + int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, const u64 *private_key, unsigned int private_key_len) { int nbytes; const struct ecc_curve *curve = ecc_get_curve(curve_id); - if (!private_key) - return -EINVAL; - nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; if (private_key_len != nbytes) return -EINVAL; - if (vli_is_zero(private_key, ndigits)) - return -EINVAL; - - /* Make sure the private key is in the range [1, n-1]. */ - if (vli_cmp(curve->n, private_key, ndigits) != 1) - return -EINVAL; - - return 0; + return __ecc_is_key_valid(curve, private_key, ndigits); } /* @@ -971,11 +984,8 @@ int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, u64 *privkey) if (err) return err; - if (vli_is_zero(priv, ndigits)) - return -EINVAL; - - /* Make sure the private key is in the range [1, n-1]. */ - if (vli_cmp(curve->n, priv, ndigits) != 1) + /* Make sure the private key is in the valid range. */ + if (__ecc_is_key_valid(curve, priv, ndigits)) return -EINVAL; ecc_swap_digits(priv, privkey, ndigits);