Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.


How to Solve

Algorithm

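How do we count the number of elements less than or equal to x efficiently? Because every row and every column is sorted, start at the top-right corner and process the rows from top to bottom, moving the column pointer left while the current element is greater than x; the pointer never moves right, so one count takes O(m + n) steps. Binary-searching x over the value range [matrix[0][0], matrix[m-1][n-1]] then finds the smallest x whose count is at least k, which is the answer.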


Test Cases

Input:

(int[][]) matrix = [[1,5,9],[10,11,13],[12,13,15]]
(int) k = 8

Output:

(int) 13

Explanation:

The elements of the matrix in sorted order are [1,5,9,10,11,12,13,13,15],
and the 8th smallest number is 13.

Solution

/// Finds the k-th smallest value in a matrix whose rows and columns are each
/// sorted in ascending order, by binary-searching over the value range.
class Solution {
public:
    /// Row and column counts of the matrix most recently given to kthSmallest.
    int m, n;

    /// @brief Returns the k-th smallest element (1-based, duplicates counted).
    /// @param matrix Matrix with every row and every column sorted ascending.
    /// @param k      Rank of the requested element.
    /// @return The smallest value v such that at least k elements are <= v,
    ///         or -1 when the matrix is empty or no such value exists
    ///         (k larger than the number of elements).
    int kthSmallest(std::vector<std::vector<int>>& matrix, int k) {
        m = static_cast<int>(matrix.size());
        n = matrix.empty() ? 0 : static_cast<int>(matrix[0].size());
        if (m == 0 || n == 0) return -1;

        // The bounds are kept in 64 bits: `left + right`, `mid - 1` and
        // `mid + 1` can all leave the int range when the matrix holds values
        // close to INT_MIN / INT_MAX, which is undefined behaviour for int.
        long long left = matrix[0][0];
        long long right = matrix[m - 1][n - 1];
        int ans = -1;
        while (left <= right) {
            // right - left >= 0 here, so this is floor((left + right) / 2).
            const long long mid = left + (right - left) / 2;
            // mid lies between two matrix values, so it always fits in int.
            const int candidate = static_cast<int>(mid);
            if (countLessOrEqual(matrix, candidate) >= k) {
                ans = candidate;   // enough elements <= candidate: try smaller
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return ans;
    }

    /// @brief Counts the elements of matrix that are <= x.
    ///
    /// Walks a staircase starting at the top-right corner: the column pointer
    /// only ever moves left while the rows advance, so the total work is
    /// rows + columns steps. Dimensions are taken from the matrix itself, so
    /// the function does not depend on a prior kthSmallest call.
    /// @param matrix Matrix with every row and every column sorted ascending.
    /// @param x      Threshold value.
    /// @return Number of elements less than or equal to x.
    int countLessOrEqual(std::vector<std::vector<int>>& matrix, int x) {
        if (matrix.empty()) return 0;
        int cnt = 0;
        int c = static_cast<int>(matrix[0].size()) - 1;
        for (const auto& row : matrix) {
            while (c >= 0 && row[c] > x) --c;
            cnt += c + 1;  // columns 0..c of this row are all <= x
        }
        return cnt;
    }
};
/**
 * Finds the k-th smallest value in a matrix whose rows and columns are each
 * sorted in ascending order, by binary-searching over the value range.
 */
class Solution {
    /** Row and column counts of the matrix most recently given to kthSmallest. */
    int m, n;

    /**
     * Returns the k-th smallest element (1-based, duplicates counted).
     *
     * @param matrix matrix with every row and every column sorted ascending
     * @param k      rank of the requested element
     * @return the smallest value v such that at least k elements are &lt;= v,
     *         or -1 when the matrix is null/empty or no such value exists
     *         (k larger than the number of elements)
     */
    public int kthSmallest(int[][] matrix, int k) {
        m = (matrix == null) ? 0 : matrix.length;
        n = (m == 0) ? 0 : matrix[0].length;
        if (m == 0 || n == 0) {
            return -1;
        }

        // The bounds are kept as long: with int, `left + right` silently wraps
        // for values near Integer.MIN_VALUE / MAX_VALUE and the search can then
        // move in the wrong direction.
        long left = matrix[0][0];
        long right = matrix[m - 1][n - 1];
        int ans = -1;
        while (left <= right) {
            // right - left >= 0 here, so this is floor((left + right) / 2).
            long mid = left + (right - left) / 2;
            // mid lies between two matrix values, so the narrowing cast is exact.
            int candidate = (int) mid;
            if (countLessOrEqual(matrix, candidate) >= k) {
                ans = candidate;   // enough elements <= candidate: try smaller
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return ans;
    }

    /**
     * Counts the elements of matrix that are &lt;= x.
     *
     * Walks a staircase starting at the top-right corner: the column pointer
     * only ever moves left while the rows advance, so the total work is
     * rows + columns steps. Dimensions are taken from the matrix itself, so
     * the method does not depend on a prior kthSmallest call.
     *
     * @param matrix matrix with every row and every column sorted ascending
     * @param x      threshold value
     * @return number of elements less than or equal to x
     */
    int countLessOrEqual(int[][] matrix, int x) {
        if (matrix == null || matrix.length == 0) {
            return 0;
        }
        int cnt = 0;
        int c = matrix[0].length - 1;
        for (int[] row : matrix) {
            while (c >= 0 && row[c] > x) {
                --c;
            }
            cnt += c + 1;  // columns 0..c of this row are all <= x
        }
        return cnt;
    }
}
class Solution:
    def kthSmallest(self, matrix, k):
        """Return the k-th smallest element of a row- and column-sorted matrix.

        Binary-searches the value range between the top-left and bottom-right
        entries for the smallest value that has at least ``k`` matrix entries
        less than or equal to it. Returns -1 if no value in that range
        qualifies.
        """
        col_count = len(matrix[0])

        def rank_of(value):
            """Count the entries <= value with a staircase walk from the top-right."""
            total = 0
            col = col_count  # number of leading columns still known to be <= value
            for row in matrix:
                while col > 0 and row[col - 1] > value:
                    col -= 1
                total += col
            return total

        lo = matrix[0][0]
        hi = matrix[-1][-1]
        best = -1
        while lo <= hi:
            mid = (lo + hi) // 2
            if rank_of(mid) < k:
                # Too few entries at or below mid: the answer is larger.
                lo = mid + 1
            else:
                # mid is a candidate; keep looking for a smaller one.
                best = mid
                hi = mid - 1

        return best
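Time Complexity: O((m + n) · log(max − min)), where max and min are the largest and smallest values in the matrix: each count takes O(m + n), and the binary search over the value range performs O(log(max − min)) counts.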
Space Complexity: O(1)