## 378. Kth Smallest Element in a Sorted Matrix

Given an n x n matrix where each of the rows and columns is sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

``````
matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

return 13.
``````

Note:
You may assume k is always valid, 1 ≤ k ≤ n^2.
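A brute-force reference makes the "not the kth distinct element" rule concrete: flatten the matrix, sort it with duplicates kept, and index at k - 1. This is only a checking sketch, not one of the approaches in the notes below; the name `kth_smallest_brute_force` is made up for illustration, and it costs O(n^2 log n) time and O(n^2) space.

``````python
from typing import List


def kth_smallest_brute_force(matrix: List[List[int]], k: int) -> int:
    """Hypothetical reference: sort all n*n values and take index k - 1.

    Duplicates are kept, so this is the kth smallest in sorted order,
    not the kth distinct value.
    """
    values = sorted(v for row in matrix for v in row)
    return values[k - 1]


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(kth_smallest_brute_force(matrix, 8))  # 13
print(kth_smallest_brute_force(matrix, 7))  # 13 (13 appears twice and counts twice)
``````

Sorted, the example is 1, 5, 9, 10, 11, 12, 13, 13, 15, so both k = 7 and k = 8 return 13.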

Thoughts:

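1. Heap:
    1. Build a min-heap from the elements of the first row.
    2. Repeat k - 1 times: pop the smallest element from the heap and read its row and column. If its row is not the last one (n - 1), push the element in the same column, one row down, onto the heap.
    3. Pop once more; that element is the kth smallest, so return its value.
2. Binary Search:
    1. Binary search over the value range [matrix[0][0], matrix[n-1][n-1]] rather than over indices. For each candidate `mid`, compute its rank: the number of matrix entries ≤ mid.
    2. If rank < k, the answer is larger than mid: lo = mid + 1.
    3. Otherwise the answer is at most mid: hi = mid.
    4. When the loop ends, lo == hi; return lo (or hi). It is the smallest value whose rank is ≥ k, and it must be a matrix entry: if it were not, lo - 1 would have the same rank, contradicting minimality.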
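The rank step of the binary-search plan can be isolated into its own helper. This is a minimal Python sketch; `count_less_equal` is a hypothetical name, and the walk mirrors the column-pointer loop in the Java solution further down.

``````python
from typing import List


def count_less_equal(matrix: List[List[int]], x: int) -> int:
    """Count entries <= x in an n x n matrix with sorted rows and columns.

    Staircase walk from the top-right corner: j only moves left, because
    columns are ascending, so a cell > x in row i is also > x in every later
    row. Each row and each column is visited once: O(n) per call.
    """
    n = len(matrix)
    count = 0
    j = n - 1  # rightmost column that may still hold values <= x
    for i in range(n):
        while j >= 0 and matrix[i][j] > x:
            j -= 1
        count += j + 1  # row i has j + 1 entries <= x
    return count


matrix = [
    [1, 5, 9],
    [10, 11, 13],
    [12, 13, 15],
]
print(count_less_equal(matrix, 12))  # 6
print(count_less_equal(matrix, 13))  # 8
print(count_less_equal(matrix, 14))  # 8
``````

For k = 8, both 13 and 14 have rank 8, which is why the search keeps shrinking hi instead of stopping at the first mid with rank ≥ k: the smallest such value, 13, is the one that actually appears in the matrix.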

Code: Heap T: O((n + k) * log n); S: O(n)

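``````
import heapq


class Solution(object):
    def kthSmallest(self, matrix, k):
        """
        :type matrix: List[List[int]]
        :type k: int
        :rtype: int
        """
        n = len(matrix)
        pq = []
        # Seed the min-heap with the whole first row.
        for j in range(n):
            heapq.heappush(pq, Element(0, j, matrix[0][j]))

        # Pop the k - 1 smallest elements; each popped cell is replaced by
        # the cell below it in the same column (columns are ascending).
        for _ in range(k - 1):
            e = heapq.heappop(pq)
            if e.row == n - 1:
                continue  # bottom row: nothing below to push
            heapq.heappush(pq, Element(e.row + 1, e.col, matrix[e.row + 1][e.col]))

        # The smallest remaining element is the kth smallest overall.
        return heapq.heappop(pq).val


class Element(object):
    """Heap entry that remembers its cell and is ordered by value only."""

    def __init__(self, row, col, val):
        self.row = row
        self.col = col
        self.val = val

    def __lt__(self, other):
        return self.val < other.val

    def __eq__(self, other):
        return self.val == other.val
``````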
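A variant of the same heap approach, sketched under the assumption that plain `(value, row, col)` tuples are acceptable in place of the `Element` class: tuples already compare by value first, and `heapq.heapify` builds the initial heap in O(n) instead of n separate pushes. The function name `kth_smallest_heap` is hypothetical.

``````python
import heapq
from typing import List


def kth_smallest_heap(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    # Seed with the first row as (value, row, col) tuples. Tuples compare by
    # value first; ties fall back to row/col, which are plain ints.
    heap = [(matrix[0][j], 0, j) for j in range(n)]
    heapq.heapify(heap)  # O(n) build

    # Discard the k - 1 smallest values, replacing each with the next cell
    # down the same column when one exists.
    for _ in range(k - 1):
        _, row, col = heapq.heappop(heap)
        if row + 1 < n:
            heapq.heappush(heap, (matrix[row + 1][col], row + 1, col))

    return heap[0][0]  # the kth smallest is now at the top
``````

The heap never empties early: it always holds the topmost unpopped cell of every column that still has one, and k ≤ n^2 guarantees at least one cell remains after k - 1 pops.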

Code: Binary Search T: O(n * log(max - min)); S: O(1)

``````
class Solution {
    public int kthSmallest(int[][] matrix, int k) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) return 0;
        int m = matrix.length, n = matrix[0].length;
        // The answer lies in the value range [smallest, largest].
        int l = matrix[0][0], h = matrix[m - 1][n - 1];
        while (l < h) {
            int mid = l + ((h - l) >> 1);  // floor of (l + h) / 2 without computing l + h
            // Rank mid: count the entries <= mid with a staircase walk.
            int cnt = 0, j = n - 1;
            for (int i = 0; i < m; i++) {
                while (j >= 0 && matrix[i][j] > mid) j--;
                // Columns are ascending, so matrix[i + 1][j] >= matrix[i][j]:
                // a column that is > mid in row i stays > mid in later rows,
                // so j never has to be reset.
                cnt += j + 1;
            }

            if (cnt < k) l = mid + 1;  // fewer than k entries <= mid
            else h = mid;              // at least k entries <= mid
        }

        return l; // l == h here
    }
}
``````
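For readers following along in Python, here is a sketch of the same value-range binary search translated from the Java above; `kth_smallest_binary_search` is a hypothetical name. Python integers do not overflow, so `(lo + hi) // 2` is safe, and floor division rounds down for negative values too, so `mid < hi` holds whenever `lo < hi` and the loop always makes progress.

``````python
from typing import List


def kth_smallest_binary_search(matrix: List[List[int]], k: int) -> int:
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[n - 1][n - 1]  # answer lies in [lo, hi]
    while lo < hi:
        mid = (lo + hi) // 2  # floor, also for negative values
        # Staircase count of entries <= mid, O(n) per iteration.
        count, j = 0, n - 1
        for i in range(n):
            while j >= 0 and matrix[i][j] > mid:
                j -= 1
            count += j + 1
        if count < k:
            lo = mid + 1  # fewer than k entries <= mid: answer is larger
        else:
            hi = mid      # at least k entries <= mid: answer <= mid
    return lo
``````

On the example matrix with k = 8 this returns 13, matching the expected output in the problem statement.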