Find All Unique Triplets in an Array That Sum to Zero

To solve the problem of finding all unique triplets in an array that sum to zero, we can use a two-pointer technique after sorting the array. This approach efficiently finds triplets with a time complexity of $O(n^2)$.

  1. Sort the Array: Start by sorting the input array.

  2. Iterate through the Array: Use a loop to fix one element, and use two pointers (left and right) to find the other two elements.

  3. Two-Pointer Technique:

    • Set the left pointer to the next element after the fixed element and the right pointer to the end of the array.

    • Calculate the sum of the three elements.

    • If the sum is zero, add the triplet to the result list and move both pointers inward while skipping duplicates.

    • If the sum is less than zero, move the left pointer to the right to increase the sum.

    • If the sum is greater than zero, move the right pointer to the left to decrease the sum.

  4. Skip Duplicates: To ensure uniqueness, skip over duplicate elements.

Time and Space Complexity

  • Time Complexity: $O(n^2)$, where n is the number of elements in the array. The outer loop runs n times, and for each iteration, the inner two-pointer search takes linear time. The initial sort costs $O(n \log n)$, which is dominated by the $O(n^2)$ search.

  • Space Complexity: $O(k)$, where k is the number of unique triplets found. The space used for storing the results can vary depending on the input.

Here’s how to implement the algorithm described above in Python, C++, and Java.

Python Implementation

def three_sum(nums):
    """Return all unique triplets of values in ``nums`` that sum to zero.

    The list is sorted in place, then each element in turn is fixed as the
    smallest member of a candidate triplet while two indices close in from
    both ends of the remaining suffix.

    Args:
        nums: A list of numbers. It is sorted in place as a side effect.

    Returns:
        A list of 3-tuples ``(a, b, c)`` with ``a <= b <= c`` and
        ``a + b + c == 0``. Each distinct combination of values appears once.
    """
    nums.sort()
    triplets = []
    last = len(nums) - 1

    for i, anchor in enumerate(nums):
        # An anchor equal to the previous one would only rediscover the
        # triplets already recorded for that value.
        if i and anchor == nums[i - 1]:
            continue

        lo, hi = i + 1, last
        while lo < hi:
            total = anchor + nums[lo] + nums[hi]

            if total > 0:
                # Too large: only a smaller right-hand value can help.
                hi -= 1
            elif total < 0:
                # Too small: only a larger left-hand value can help.
                lo += 1
            else:
                triplets.append((anchor, nums[lo], nums[hi]))
                # Walk each index to the last copy of its current value ...
                while lo < hi and nums[lo] == nums[lo + 1]:
                    lo += 1
                while lo < hi and nums[hi] == nums[hi - 1]:
                    hi -= 1
                # ... then step past it so the next pair uses new values.
                lo += 1
                hi -= 1

    return triplets

# Example usage
print(three_sum([-1, 0, 1, 2, -1, -4]))

C++ Implementation

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

// Returns every unique triplet of values from nums whose sum is zero.
//
// nums is sorted in place (ascending) as a side effect. Each returned triplet
// is in non-decreasing order, and each distinct combination of values appears
// exactly once. Inputs with fewer than three elements yield an empty result.
vector<vector<int>> threeSum(vector<int>& nums) {
    sort(nums.begin(), nums.end());
    vector<vector<int>> result;
    const int n = static_cast<int>(nums.size());

    for (int i = 0; i < n; i++) {
        // The array is sorted, so once the smallest member of the triplet is
        // positive no three remaining values can sum to zero.
        if (nums[i] > 0) break;
        if (i > 0 && nums[i] == nums[i - 1]) continue;  // Skip duplicate fixed values
        int left = i + 1, right = n - 1;

        while (left < right) {
            // Sum in 64 bits: adding three ints can overflow int, which is
            // undefined behaviour and can make a non-zero sum look like zero.
            const long long total =
                static_cast<long long>(nums[i]) + nums[left] + nums[right];

            if (total < 0) {
                left++;   // Need a larger sum
            } else if (total > 0) {
                right--;  // Need a smaller sum
            } else {
                result.push_back({nums[i], nums[left], nums[right]});
                // Move each pointer to the last copy of its current value,
                // then step past it so the next pair uses new values.
                while (left < right && nums[left] == nums[left + 1]) left++;
                while (left < right && nums[right] == nums[right - 1]) right--;
                left++;
                right--;
            }
        }
    }

    return result;
}

// Example usage: runs threeSum on a sample array and prints each triplet
// on its own line in the form "[a, b, c]".
int main() {
    vector<int> sample{-1, 0, 1, 2, -1, -4};
    const vector<vector<int>> triplets = threeSum(sample);

    for (const vector<int>& row : triplets) {
        cout << "[" << row[0] << ", " << row[1] << ", " << row[2] << "]\n";
    }

    return 0;
}

Java Implementation

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Finds all unique triplets in an integer array that sum to zero, using a
 * sort followed by a two-pointer scan for each fixed element.
 */
public class ThreeSum {
    /**
     * Returns every unique triplet of values from {@code nums} whose sum is zero.
     *
     * <p>The array is sorted in place as a side effect. Each returned triplet is
     * in non-decreasing order, and each distinct combination of values appears
     * exactly once. Arrays with fewer than three elements yield an empty list.
     *
     * @param nums the input values; sorted in place by this call
     * @return the list of zero-sum triplets, empty if there are none
     * @throws NullPointerException if {@code nums} is {@code null}
     */
    public List<List<Integer>> threeSum(int[] nums) {
        Arrays.sort(nums);
        List<List<Integer>> result = new ArrayList<>();
        int n = nums.length;

        for (int i = 0; i < n; i++) {
            // The array is sorted, so once the smallest member of the triplet
            // is positive no three remaining values can sum to zero.
            if (nums[i] > 0) {
                break;
            }
            // Skip duplicate fixed values; they would repeat earlier triplets.
            if (i > 0 && nums[i] == nums[i - 1]) {
                continue;
            }
            int left = i + 1;
            int right = n - 1;

            while (left < right) {
                // Sum in 64 bits: int addition wraps silently, which can make
                // a non-zero sum (e.g. MIN_VALUE + MIN_VALUE + 0) look like zero.
                long total = (long) nums[i] + nums[left] + nums[right];

                if (total < 0) {
                    left++;   // Need a larger sum
                } else if (total > 0) {
                    right--;  // Need a smaller sum
                } else {
                    result.add(Arrays.asList(nums[i], nums[left], nums[right]));
                    // Move each pointer to the last copy of its current value,
                    // then step past it so the next pair uses new values.
                    while (left < right && nums[left] == nums[left + 1]) {
                        left++;
                    }
                    while (left < right && nums[right] == nums[right - 1]) {
                        right--;
                    }
                    left++;
                    right--;
                }
            }
        }

        return result;
    }

    /**
     * Example usage: prints each zero-sum triplet of a sample array on its own line.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ThreeSum ts = new ThreeSum();
        int[] nums = {-1, 0, 1, 2, -1, -4};
        List<List<Integer>> result = ts.threeSum(nums);

        for (List<Integer> triplet : result) {
            System.out.println(triplet);
        }
    }
}