/*
 * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpClient.Version;
import java.net.http.HttpRequest;
import java.net.http.HttpOption.Http3DiscoveryMode;
import java.net.http.HttpOption;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;

import jdk.httpclient.test.lib.common.HttpServerAdapters;
import jdk.internal.net.http.common.OperationTrackers.Tracker;
import jdk.test.lib.net.SimpleSSLContext;
import jdk.test.lib.net.URIBuilder;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import static java.net.http.HttpClient.Version.HTTP_1_1;
import static java.net.http.HttpClient.Version.HTTP_2;
import static java.net.http.HttpClient.Version.HTTP_3;
import static java.net.http.HttpOption.Http3DiscoveryMode.ANY;
import static java.net.http.HttpOption.Http3DiscoveryMode.HTTP_3_URI_ONLY;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

/*
 * @test
 * @bug 8316580
 * @library /test/lib /test/jdk/java/net/httpclient/lib
 * @build HttpGetInCancelledFuture ReferenceTracker
 * @run junit/othervm -DuseReferenceTracker=false
 *                   HttpGetInCancelledFuture
 * @run junit/othervm -DuseReferenceTracker=true
 *                   HttpGetInCancelledFuture
 * @summary This test verifies that cancelling a future that
 * does an HTTP request using the HttpClient doesn't cause
 * HttpClient::close to block forever.
 */
public class HttpGetInCancelledFuture {

    static final boolean useTracker = Boolean.getBoolean("useReferenceTracker");

    static final class TestException extends RuntimeException {
        public TestException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    static ReferenceTracker TRACKER = ReferenceTracker.INSTANCE;

    HttpClient makeClient(URI uri, Version version, Executor executor) {
        var builder = version == HTTP_3
                ? HttpServerAdapters.createClientBuilderForH3()
                : HttpClient.newBuilder();
        if (uri.getScheme().equalsIgnoreCase("https")) {
            builder.sslContext(SimpleSSLContext.findSSLContext());
        }
        return builder.connectTimeout(Duration.ofSeconds(1))
                .executor(executor)
                .version(version)
                .build();
    }

    record TestCase(String url, int reqCount, Version version, Http3DiscoveryMode config) {
        TestCase(String url, int reqCount, Version version) {
            this(url, reqCount, version, null);
        }
        TestCase(String url, int reqCount, Http3DiscoveryMode config) {
            this(url, reqCount, HTTP_3, null);
        }
    }


    // A server that doesn't accept
    static volatile ServerSocket   NOT_ACCEPTING;
    static volatile DatagramSocket NOT_RESPONDING;

    static List<TestCase> parameters() {
        ServerSocket ss = NOT_ACCEPTING;
        if (ss == null) {
            synchronized (HttpGetInCancelledFuture.class) {
                if ((ss = NOT_ACCEPTING) == null) {
                    try {
                        ss = new ServerSocket();
                        var loopback = InetAddress.getLoopbackAddress();
                        ss.bind(new InetSocketAddress(loopback, 0), 10);
                        NOT_ACCEPTING = ss;
                    } catch (IOException io) {
                        throw new UncheckedIOException(io);
                    }
                }
            }
        }

        DatagramSocket ds = NOT_RESPONDING;
        boolean sameport = false;
        if (ds == null) {
            synchronized (HttpGetInCancelledFuture.class) {
                if ((ds = NOT_RESPONDING) == null) {
                    try {
                        var loopback = InetAddress.getLoopbackAddress();
                        try {
                            ds = new DatagramSocket(new InetSocketAddress(loopback, ss.getLocalPort()));
                            sameport = true;
                        } catch (IOException io) {
                            ds = new DatagramSocket(new InetSocketAddress(loopback,0));
                        }
                        NOT_RESPONDING = ds;
                    } catch (IOException io) {
                        throw new UncheckedIOException(io);
                    }
                }
            }
        }

        URI http = URIBuilder.newBuilder()
                .loopback()
                .scheme("http")
                .port(ss.getLocalPort())
                .path("/not-accepting/")
                .buildUnchecked();
        URI https = URIBuilder.newBuilder()
                .loopback()
                .scheme("https")
                .port(ss.getLocalPort())
                .path("/not-accepting/")
                .buildUnchecked();
        URI https3 = URIBuilder.newBuilder()
                .loopback()
                .scheme("https")
                .port(ds.getLocalPort())
                .path("/not-responding/")
                .buildUnchecked();
        // use all HTTP versions, without and with TLS
        var def = Stream.of(
                new TestCase(https3.toString(), 200, HTTP_3_URI_ONLY),
                new TestCase(http.toString(), 200, HTTP_2),
                new TestCase(http.toString(), 200, HTTP_1_1),
                new TestCase(https.toString(), 200, HTTP_2),
                new TestCase(https.toString(), 200, HTTP_1_1)
                );
        var first = sameport
                ? Stream.of(new TestCase(https3.toString(), 200, ANY))
                : Stream.<TestCase>empty();
        var cases= Stream.concat(first, def);
        return cases.toList();
    }

    @ParameterizedTest
    @MethodSource("parameters")
    void runTest(TestCase test) {
        System.out.println("Testing with: " + test);
        runTest(test.url, test.reqCount, test.version);
    }

    static class TestTaskScope implements AutoCloseable {
        final ExecutorService pool = new ForkJoinPool();
        final Map<Task<?>, Future<?>> tasks = new ConcurrentHashMap<>();
        final AtomicReference<ExecutionException> failed = new AtomicReference<>();

        class Task<T> implements Callable<T> {
            final Callable<T> task;
            final CompletableFuture<T> cf = new CompletableFuture<>();
            Task(Callable<T> task) {
                this.task = task;
            }
            @Override
            public T call() throws Exception {
                try {
                    var res = task.call();
                    cf.complete(res);
                    return res;
                } catch (Throwable t) {
                    cf.completeExceptionally(t);
                    throw t;
                }
            }
            CompletableFuture<T> cf() {
                return cf;
            }
        }


        static class ShutdownOnFailure extends TestTaskScope {
            public ShutdownOnFailure() {}

            @Override
            protected <T> void completed(Task<T> task, T result, Throwable throwable) {
                super.completed(task, result, throwable);
                if (throwable != null) {
                    if (failed.get() == null) {
                        ExecutionException ex = throwable instanceof ExecutionException x
                                ? x : new ExecutionException(throwable);
                        failed.compareAndSet(null, ex);
                    }
                    tasks.entrySet().forEach(this::cancel);
                }
            }

            void cancel(Map.Entry<Task<?>, Future<?>> entry) {
                entry.getValue().cancel(true);
                entry.getKey().cf().cancel(true);
                tasks.remove(entry.getKey(), entry.getValue());
            }

            @Override
            public <T> CompletableFuture<T> fork(Callable<T> callable) {
                var ex = failed.get();
                if (ex == null) {
                    return super.fork(callable);
                } // otherwise do nothing
                return CompletableFuture.failedFuture(new RejectedExecutionException());
            }
        }

        public <T> CompletableFuture<T> fork(Callable<T> callable) {
            var task = new Task<>(callable);
            var res = pool.submit(task);
            tasks.put(task, res);
            task.cf.whenComplete((r,t) -> completed(task, r, t));
            return task.cf;
        }

        protected <T> void completed(Task<T> task, T result, Throwable throwable) {
            tasks.remove(task);
        }

        public void join() throws InterruptedException {
            try {
                var cfs = tasks.keySet().stream()
                        .map(Task::cf).toArray(CompletableFuture[]::new);
                CompletableFuture.allOf(cfs).get();
            } catch (InterruptedException it) {
                throw it;
            } catch (ExecutionException ex) {
                failed.compareAndSet(null, ex);
            }
        }

        public void throwIfFailed() throws ExecutionException {
            ExecutionException x = failed.get();
            if (x != null) throw x;
        }

        public void close() {
            pool.close();
        }
    }

    ExecutorService testExecutor() {
        return Executors.newCachedThreadPool();
    }

    void runTest(String url, int reqCount, Version version) {
        final var dest = URI.create(url);
        try (final var executor = testExecutor()) {
            var httpClient = makeClient(dest, version, executor);
            TRACKER.track(httpClient);
            Tracker tracker = TRACKER.getTracker(httpClient);
            Throwable failed = null;
            try {
                try (final var scope = new TestTaskScope.ShutdownOnFailure()) {
                    launchAndProcessRequests(scope, httpClient, reqCount, version, dest);
                } finally {
                    System.out.printf("StructuredTaskScope closed: STARTED=%s, SUCCESS=%s, INTERRUPT=%s, FAILED=%s%n",
                            STARTED.get(), SUCCESS.get(), INTERRUPT.get(), FAILED.get());
                }
                System.out.println("ERROR: Expected TestException not thrown");
                throw new AssertionError("Expected TestException not thrown");
            } catch (TestException x) {
                System.out.println("Got expected exception: " + x);
            } catch (Throwable t) {
                System.out.println("ERROR: Unexpected exception: " + t);
                failed = t;
                throw t;
            } finally {
                // we can either use the tracker or call HttpClient::close
                if (useTracker) {
                    // using the tracker depends on GC but will give us some diagnostic
                    // if some operations are not properly cancelled and prevent the client
                    // from terminating
                    httpClient = null;
                    System.gc();
                    System.out.println(TRACKER.diagnose(tracker));
                    var error = TRACKER.check(tracker, 10000);
                    if (error != null) {
                        if (failed != null) error.addSuppressed(failed);
                        EXCEPTIONS.forEach(x -> {
                            System.out.println("FAILED: " + x);
                        });
                        EXCEPTIONS.forEach(x -> {
                            x.printStackTrace(System.out);
                        });
                        throw error;
                    }
                } else {
                    // if not all operations terminate, close() will block
                    // forever and the test will fail in jtreg timeout.
                    // there will be no diagnostic.
                    httpClient.close();
                }
                System.out.println("HttpClient closed");
            }
        } finally {
            System.out.println("ThreadExecutor closed");
        }
        // not all tasks may have been started before the scope was cancelled
        // due to the first connect/timeout exception, but all tasks that started
        // must have either succeeded, be interrupted, or failed
        assertTrue(STARTED.get() > 0);
        assertEquals(STARTED.get(), SUCCESS.get() + INTERRUPT.get() + FAILED.get());
        if (SUCCESS.get() > 0) {
            // we don't expect any server to be listening and responding
            System.out.println("WARNING: got some unexpected successful responses from "
                    + "\"" + NOT_ACCEPTING.getLocalSocketAddress() + "\": " + SUCCESS.get());
        }
    }

    private void launchAndProcessRequests(
            TestTaskScope.ShutdownOnFailure scope,
            HttpClient httpClient,
            int reqCount,
            Version version,
            URI dest) {
        for (int counter = 0; counter < reqCount; counter++) {
            scope.fork(() ->
                    getAndCheck(httpClient, dest, version)
            );
        }
        try {
            scope.join();
        } catch (InterruptedException e) {
            throw new AssertionError("scope.join() was interrupted", e);
        }
        try {
            scope.throwIfFailed();
        } catch (ExecutionException e) {
            throw new TestException("something threw an exception in StructuredTaskScope", e);
        }
    }

    final static AtomicLong ID = new AtomicLong();
    final AtomicLong SUCCESS = new AtomicLong();
    final AtomicLong INTERRUPT = new AtomicLong();
    final AtomicLong FAILED = new AtomicLong();
    final AtomicLong STARTED = new AtomicLong();
    final CopyOnWriteArrayList<Exception> EXCEPTIONS = new CopyOnWriteArrayList<>();
    private String getAndCheck(HttpClient httpClient, URI url, Version version) {
        STARTED.incrementAndGet();
        final var response = sendRequest(httpClient, url, version);
        String res = response.body();
        int statusCode = response.statusCode();
        assertEquals(200, statusCode);
        return res;
    }

    private HttpResponse<String> sendRequest(HttpClient httpClient, URI url, Version version) {
        var id = ID.incrementAndGet();
        try {
            var builder = HttpRequest.newBuilder(url).version(version).GET();
            if (version == HTTP_3) builder.setOption(HttpOption.H3_DISCOVERY, HTTP_3_URI_ONLY);
            var request = builder.build();
            var response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());
            // System.out.println("Got response for " + id + ": " + response);
            SUCCESS.incrementAndGet();
            return response;
        } catch (InterruptedException e) {
            INTERRUPT.incrementAndGet();
            // System.out.println("Got interrupted for " + id + ": " + e);
            throw new RuntimeException(e);
        } catch (Exception e) {
            FAILED.incrementAndGet();
            EXCEPTIONS.add(e);
            //System.out.println("Got exception for " + id + ": " + e);
            throw new RuntimeException(e);
        }
    }

    @AfterAll
    static void tearDown() {
        try {
            System.gc();
            var error = TRACKER.check(5000);
            if (error != null) throw error;
        } finally {
            ServerSocket ss;
            DatagramSocket ds;
            synchronized (HttpGetInCancelledFuture.class) {
                ss = NOT_ACCEPTING;
                NOT_ACCEPTING = null;
                ds = NOT_RESPONDING;
                NOT_RESPONDING = null;
            }
            try (var ss1 = ss; var ds1 = ds;) {
                System.out.printf("Cleaning up: ss=%s, ds=%s%n",
                        ss1.getLocalSocketAddress(), ds1.getLocalSocketAddress());
            } catch (IOException io) {
                throw new UncheckedIOException(io);
            }
        }
    }
}

